[mlpack-git] master, mlpack-1.0.x: Test incremental variance functionality. (5bf6ae3)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:46:25 EST 2015
Repository : https://github.com/mlpack/mlpack
On branches: master,mlpack-1.0.x
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit 5bf6ae3297b9b4cf05b747e8d7f4d49626ef881c
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Apr 15 15:24:09 2014 +0000
Test incremental variance functionality.
>---------------------------------------------------------------
5bf6ae3297b9b4cf05b747e8d7f4d49626ef881c
src/mlpack/tests/nbc_test.cpp | 55 +++++++++++++++++++++++++++++++++++++++++++
1 file changed, 55 insertions(+)
diff --git a/src/mlpack/tests/nbc_test.cpp b/src/mlpack/tests/nbc_test.cpp
index 7d66fbe..a8b6ad7 100644
--- a/src/mlpack/tests/nbc_test.cpp
+++ b/src/mlpack/tests/nbc_test.cpp
@@ -67,4 +67,59 @@ BOOST_AUTO_TEST_CASE(NaiveBayesClassifierTest)
BOOST_REQUIRE_EQUAL(testRes(i), calcVec(i));
}
+// The same test, but this one uses the incremental algorithm to calculate
+// variance.
+BOOST_AUTO_TEST_CASE(NaiveBayesClassifierIncrementalTest)
+{
+ const char* trainFilename = "trainSet.csv";
+ const char* testFilename = "testSet.csv";
+ const char* trainResultFilename = "trainRes.csv";
+ const char* testResultFilename = "testRes.csv";
+ size_t classes = 2;
+
+ arma::mat trainData, trainRes, calcMat;
+ data::Load(trainFilename, trainData, true);
+ data::Load(trainResultFilename, trainRes, true);
+
+ // Get the labels out.
+ arma::Col<size_t> labels(trainData.n_cols);
+ for (size_t i = 0; i < trainData.n_cols; ++i)
+ labels[i] = trainData(trainData.n_rows - 1, i);
+ trainData.shed_row(trainData.n_rows - 1);
+
+ NaiveBayesClassifier<> nbcTest(trainData, labels, classes, true);
+
+ size_t dimension = nbcTest.Means().n_rows;
+ calcMat.zeros(2 * dimension + 1, classes);
+
+ for (size_t i = 0; i < dimension; i++)
+ {
+ for (size_t j = 0; j < classes; j++)
+ {
+ calcMat(i, j) = nbcTest.Means()(i, j);
+ calcMat(i + dimension, j) = nbcTest.Variances()(i, j);
+ }
+ }
+
+ for (size_t i = 0; i < classes; i++)
+ calcMat(2 * dimension, i) = nbcTest.Probabilities()(i);
+
+ for (size_t i = 0; i < calcMat.n_rows; i++)
+ for (size_t j = 0; j < classes; j++)
+ BOOST_REQUIRE_CLOSE(trainRes(i, j) + .00001, calcMat(i, j), 0.01);
+
+ arma::mat testData;
+ arma::Mat<size_t> testRes;
+ arma::Col<size_t> calcVec;
+ data::Load(testFilename, testData, true);
+ data::Load(testResultFilename, testRes, true);
+
+ testData.shed_row(testData.n_rows - 1); // Remove the labels.
+
+ nbcTest.Classify(testData, calcVec);
+
+ for (size_t i = 0; i < testData.n_cols; i++)
+ BOOST_REQUIRE_EQUAL(testRes(i), calcVec(i));
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list