[mlpack-git] master: Add tests for Naive Bayes classifier. (3464c8e)

gitdub at big.cc.gt.atl.ga.us
Tue Sep 29 09:33:47 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/cbeb3ea17262b7c5115247dc217e316c529249b7...f85a9b22f3ce56143943a2488c05c2810d6b2bf3

>---------------------------------------------------------------

commit 3464c8ef7b0f57758b07fdba011b3062b81d5d94
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sat Sep 26 04:17:15 2015 +0000

    Add tests for Naive Bayes classifier.


>---------------------------------------------------------------

3464c8ef7b0f57758b07fdba011b3062b81d5d94
 src/mlpack/tests/nbc_test.cpp | 163 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 163 insertions(+)

diff --git a/src/mlpack/tests/nbc_test.cpp b/src/mlpack/tests/nbc_test.cpp
index 4fdec14..a3e43cd 100644
--- a/src/mlpack/tests/nbc_test.cpp
+++ b/src/mlpack/tests/nbc_test.cpp
@@ -122,4 +122,167 @@ BOOST_AUTO_TEST_CASE(NaiveBayesClassifierIncrementalTest)
     BOOST_REQUIRE_EQUAL(testRes(i), calcVec(i));
 }
 
+/**
+ * Ensure that training via Train() gives the same model as the constructor.
+ */
+BOOST_AUTO_TEST_CASE(SeparateTrainTest)
+{
+  const char* trainFilename = "trainSet.csv";
+  const char* trainResultFilename = "trainRes.csv";
+  size_t classes = 2;
+
+  arma::mat trainData, trainRes, calcMat;
+  data::Load(trainFilename, trainData, true);
+  data::Load(trainResultFilename, trainRes, true);
+
+  // Get the labels out.
+  arma::Row<size_t> labels(trainData.n_cols);
+  for (size_t i = 0; i < trainData.n_cols; ++i)
+    labels[i] = trainData(trainData.n_rows - 1, i);
+  trainData.shed_row(trainData.n_rows - 1);
+
+  NaiveBayesClassifier<> nbc(trainData, labels, classes, true);
+  NaiveBayesClassifier<> nbcTrain(trainData.n_rows, classes);
+  nbcTrain.Train(trainData, labels, false);
+
+  BOOST_REQUIRE_EQUAL(nbc.Means().n_rows, nbcTrain.Means().n_rows);
+  BOOST_REQUIRE_EQUAL(nbc.Means().n_cols, nbcTrain.Means().n_cols);
+  BOOST_REQUIRE_EQUAL(nbc.Variances().n_rows, nbcTrain.Variances().n_rows);
+  BOOST_REQUIRE_EQUAL(nbc.Variances().n_cols, nbcTrain.Variances().n_cols);
+  BOOST_REQUIRE_EQUAL(nbc.Probabilities().n_elem,
+                      nbcTrain.Probabilities().n_elem);
+
+  for (size_t i = 0; i < nbc.Means().n_elem; ++i)
+  {
+    if (std::abs(nbc.Means()[i]) < 1e-5)
+      BOOST_REQUIRE_SMALL(nbcTrain.Means()[i], 1e-5);
+    else
+      BOOST_REQUIRE_CLOSE(nbc.Means()[i], nbcTrain.Means()[i], 1e-5);
+  }
+
+  for (size_t i = 0; i < nbc.Variances().n_elem; ++i)
+  {
+    if (std::abs(nbc.Variances()[i]) < 1e-5)
+      BOOST_REQUIRE_SMALL(nbcTrain.Variances()[i], 1e-5);
+    else
+      BOOST_REQUIRE_CLOSE(nbc.Variances()[i], nbcTrain.Variances()[i], 1e-5);
+  }
+
+  for (size_t i = 0; i < nbc.Probabilities().n_elem; ++i)
+  {
+    if (std::abs(nbc.Probabilities()[i]) < 1e-5)
+      BOOST_REQUIRE_SMALL(nbcTrain.Probabilities()[i], 1e-5);
+    else
+      BOOST_REQUIRE_CLOSE(nbc.Probabilities()[i], nbcTrain.Probabilities()[i],
+          1e-5);
+  }
+}
+
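+/**
+ * Ensure that incremental training (Train() with incremental = true) gives the
+ * same model as training in the constructor.
+ */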
+BOOST_AUTO_TEST_CASE(SeparateTrainIncrementalTest)
+{
+  const char* trainFilename = "trainSet.csv";
+  const char* trainResultFilename = "trainRes.csv";
+  size_t classes = 2;
+
+  arma::mat trainData, trainRes, calcMat;
+  data::Load(trainFilename, trainData, true);
+  data::Load(trainResultFilename, trainRes, true);
+
+  // Get the labels out.
+  arma::Row<size_t> labels(trainData.n_cols);
+  for (size_t i = 0; i < trainData.n_cols; ++i)
+    labels[i] = trainData(trainData.n_rows - 1, i);
+  trainData.shed_row(trainData.n_rows - 1);
+
+  NaiveBayesClassifier<> nbc(trainData, labels, classes, true);
+  NaiveBayesClassifier<> nbcTrain(trainData.n_rows, classes);
+  nbcTrain.Train(trainData, labels, true);
+
+  BOOST_REQUIRE_EQUAL(nbc.Means().n_rows, nbcTrain.Means().n_rows);
+  BOOST_REQUIRE_EQUAL(nbc.Means().n_cols, nbcTrain.Means().n_cols);
+  BOOST_REQUIRE_EQUAL(nbc.Variances().n_rows, nbcTrain.Variances().n_rows);
+  BOOST_REQUIRE_EQUAL(nbc.Variances().n_cols, nbcTrain.Variances().n_cols);
+  BOOST_REQUIRE_EQUAL(nbc.Probabilities().n_elem,
+                      nbcTrain.Probabilities().n_elem);
+
+  for (size_t i = 0; i < nbc.Means().n_elem; ++i)
+  {
+    if (std::abs(nbc.Means()[i]) < 1e-5)
+      BOOST_REQUIRE_SMALL(nbcTrain.Means()[i], 1e-5);
+    else
+      BOOST_REQUIRE_CLOSE(nbc.Means()[i], nbcTrain.Means()[i], 1e-5);
+  }
+
+  for (size_t i = 0; i < nbc.Variances().n_elem; ++i)
+  {
+    if (std::abs(nbc.Variances()[i]) < 1e-5)
+      BOOST_REQUIRE_SMALL(nbcTrain.Variances()[i], 1e-5);
+    else
+      BOOST_REQUIRE_CLOSE(nbc.Variances()[i], nbcTrain.Variances()[i], 1e-5);
+  }
+
+  for (size_t i = 0; i < nbc.Probabilities().n_elem; ++i)
+  {
+    if (std::abs(nbc.Probabilities()[i]) < 1e-5)
+      BOOST_REQUIRE_SMALL(nbcTrain.Probabilities()[i], 1e-5);
+    else
+      BOOST_REQUIRE_CLOSE(nbc.Probabilities()[i], nbcTrain.Probabilities()[i],
+          1e-5);
+  }
+}
+
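+/**
+ * Ensure that training on individual points, one at a time, gives the same
+ * model as training in the constructor.
+ */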
+BOOST_AUTO_TEST_CASE(SeparateTrainIndividualIncrementalTest)
+{
+  const char* trainFilename = "trainSet.csv";
+  const char* trainResultFilename = "trainRes.csv";
+  size_t classes = 2;
+
+  arma::mat trainData, trainRes, calcMat;
+  data::Load(trainFilename, trainData, true);
+  data::Load(trainResultFilename, trainRes, true);
+
+  // Get the labels out.
+  arma::Row<size_t> labels(trainData.n_cols);
+  for (size_t i = 0; i < trainData.n_cols; ++i)
+    labels[i] = trainData(trainData.n_rows - 1, i);
+  trainData.shed_row(trainData.n_rows - 1);
+
+  NaiveBayesClassifier<> nbc(trainData, labels, classes, true);
+  NaiveBayesClassifier<> nbcTrain(trainData.n_rows, classes);
+  for (size_t i = 0; i < trainData.n_cols; ++i)
+    nbcTrain.Train(trainData.col(i), labels[i]);
+
+  BOOST_REQUIRE_EQUAL(nbc.Means().n_rows, nbcTrain.Means().n_rows);
+  BOOST_REQUIRE_EQUAL(nbc.Means().n_cols, nbcTrain.Means().n_cols);
+  BOOST_REQUIRE_EQUAL(nbc.Variances().n_rows, nbcTrain.Variances().n_rows);
+  BOOST_REQUIRE_EQUAL(nbc.Variances().n_cols, nbcTrain.Variances().n_cols);
+  BOOST_REQUIRE_EQUAL(nbc.Probabilities().n_elem,
+                      nbcTrain.Probabilities().n_elem);
+
+  for (size_t i = 0; i < nbc.Means().n_elem; ++i)
+  {
+    if (std::abs(nbc.Means()[i]) < 1e-5)
+      BOOST_REQUIRE_SMALL(nbcTrain.Means()[i], 1e-5);
+    else
+      BOOST_REQUIRE_CLOSE(nbc.Means()[i], nbcTrain.Means()[i], 1e-5);
+  }
+
+  for (size_t i = 0; i < nbc.Variances().n_elem; ++i)
+  {
+    if (std::abs(nbc.Variances()[i]) < 1e-5)
+      BOOST_REQUIRE_SMALL(nbcTrain.Variances()[i], 1e-5);
+    else
+      BOOST_REQUIRE_CLOSE(nbc.Variances()[i], nbcTrain.Variances()[i], 1e-5);
+  }
+
+  for (size_t i = 0; i < nbc.Probabilities().n_elem; ++i)
+  {
+    if (std::abs(nbc.Probabilities()[i]) < 1e-5)
+      BOOST_REQUIRE_SMALL(nbcTrain.Probabilities()[i], 1e-5);
+    else
+      BOOST_REQUIRE_CLOSE(nbc.Probabilities()[i], nbcTrain.Probabilities()[i],
+          1e-5);
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();


