[mlpack-git] master: Add Serialization() to NaiveBayesClassifier. (1eb721c)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Nov 3 10:29:49 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/bd3fd9b46140555b3be741c1f50491b629fe9212...1eb721c663e640d571d8374c67c40ad8a5ea6fb3

>---------------------------------------------------------------

commit 1eb721c663e640d571d8374c67c40ad8a5ea6fb3
Author: ryan <ryan at ratml.org>
Date:   Tue Nov 3 10:25:05 2015 -0500

    Add Serialization() to NaiveBayesClassifier.


>---------------------------------------------------------------

1eb721c663e640d571d8374c67c40ad8a5ea6fb3
 .../methods/naive_bayes/naive_bayes_classifier.hpp |  4 ++
 .../naive_bayes/naive_bayes_classifier_impl.hpp    | 16 ++++--
 src/mlpack/tests/serialization_test.cpp            | 61 ++++++++++++++++++++++
 3 files changed, 78 insertions(+), 3 deletions(-)
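For context, the new Serialize() method plugs NaiveBayesClassifier into
mlpack's boost::serialization-based model persistence.  As a rough usage
sketch (assuming the data::Save()/data::Load() model helpers are available
on this branch; the filename and object name below are only examples), a
trained classifier could be written to disk and restored like this:

    #include <mlpack/core.hpp>
    #include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>

    using namespace mlpack;
    using namespace mlpack::naive_bayes;

    int main()
    {
      // Train on random two-class data (purely illustrative).
      arma::mat dataset = arma::randu<arma::mat>(10, 500);
      arma::Row<size_t> labels(500);
      for (size_t i = 0; i < 500; ++i)
        labels[i] = (dataset(0, i) > 0.5) ? 0 : 1;

      NaiveBayesClassifier<> nbc(dataset, labels, 2);

      // Saving invokes Serialize() through boost::serialization.
      data::Save("nbc.xml", "nbc_model", nbc);

      // Load into an empty model (0 dimensions, 0 classes).
      NaiveBayesClassifier<> nbc2(0, 0);
      data::Load("nbc.xml", "nbc_model", nbc2);
    }
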

diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
index 51fdf4f..4b6e40b 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
@@ -135,6 +135,10 @@ class NaiveBayesClassifier
   //! Modify the prior probabilities for each class.
   arma::vec& Probabilities() { return probabilities; }
 
+  //! Serialize the classifier.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
  private:
   //! Sample mean for each class.
   MatType means;
diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
index 21829e3..2b0b3c7 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
@@ -189,7 +189,7 @@ void NaiveBayesClassifier<MatType>::Classify(const MatType& data,
       exponents(j) = std::exp(arma::accu(diffs.col(j) % rhs.unsafe_col(j)));
 
     testProbs.col(i) += log(pow(2 * M_PI, (double) data.n_rows / -2.0) *
-        pow(det(arma::diagmat(invVar.col(i))), -0.5) * exponents);
+        std::pow(arma::det(arma::diagmat(invVar.col(i))), -0.5) * exponents);
   }
 
   // Now calculate the label.
@@ -206,7 +206,17 @@ void NaiveBayesClassifier<MatType>::Classify(const MatType& data,
   return;
 }
 
-}; // namespace naive_bayes
-}; // namespace mlpack
+template<typename MatType>
+template<typename Archive>
+void NaiveBayesClassifier<MatType>::Serialize(Archive& ar,
+                                              const unsigned int /* version */)
+{
+  ar & data::CreateNVP(means, "means");
+  ar & data::CreateNVP(variances, "variances");
+  ar & data::CreateNVP(probabilities, "probabilities");
+}
+
+} // namespace naive_bayes
+} // namespace mlpack
 
 #endif
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 3baf714..fddb368 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -29,6 +29,7 @@
 #include <mlpack/methods/neighbor_search/neighbor_search.hpp>
 #include <mlpack/methods/softmax_regression/softmax_regression.hpp>
 #include <mlpack/methods/det/dtree.hpp>
+#include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>
 
 using namespace mlpack;
 using namespace mlpack::distribution;
@@ -38,6 +39,8 @@ using namespace mlpack::metric;
 using namespace mlpack::tree;
 using namespace mlpack::perceptron;
 using namespace mlpack::regression;
+using namespace mlpack::naive_bayes;
+
 using namespace arma;
 using namespace boost;
 using namespace boost::archive;
@@ -1257,4 +1260,62 @@ BOOST_AUTO_TEST_CASE(DETTest)
   }
 }
 
+BOOST_AUTO_TEST_CASE(NaiveBayesSerializationTest)
+{
+  // Train NBC randomly.  Make sure the model is the same after serializing and
+  // re-loading.
+  arma::mat dataset;
+  dataset.randu(10, 500);
+  arma::Row<size_t> labels(500);
+  for (size_t i = 0; i < 500; ++i)
+  {
+    if (dataset(0, i) > 0.5)
+      labels[i] = 0;
+    else
+      labels[i] = 1;
+  }
+
+  NaiveBayesClassifier<> nbc(dataset, labels, 2);
+
+  // Initialize some empty Naive Bayes classifiers.
+  NaiveBayesClassifier<> xmlNbc(0, 0), textNbc(0, 0), binaryNbc(0, 0);
+  SerializeObjectAll(nbc, xmlNbc, textNbc, binaryNbc);
+
+  BOOST_REQUIRE_EQUAL(nbc.Means().n_elem, xmlNbc.Means().n_elem);
+  BOOST_REQUIRE_EQUAL(nbc.Means().n_elem, textNbc.Means().n_elem);
+  BOOST_REQUIRE_EQUAL(nbc.Means().n_elem, binaryNbc.Means().n_elem);
+  for (size_t i = 0; i < nbc.Means().n_elem; ++i)
+  {
+    BOOST_REQUIRE_CLOSE(nbc.Means()[i], xmlNbc.Means()[i], 1e-5);
+    BOOST_REQUIRE_CLOSE(nbc.Means()[i], textNbc.Means()[i], 1e-5);
+    BOOST_REQUIRE_CLOSE(nbc.Means()[i], binaryNbc.Means()[i], 1e-5);
+  }
+
+  BOOST_REQUIRE_EQUAL(nbc.Variances().n_elem, xmlNbc.Variances().n_elem);
+  BOOST_REQUIRE_EQUAL(nbc.Variances().n_elem, textNbc.Variances().n_elem);
+  BOOST_REQUIRE_EQUAL(nbc.Variances().n_elem, binaryNbc.Variances().n_elem);
+  for (size_t i = 0; i < nbc.Variances().n_elem; ++i)
+  {
+    BOOST_REQUIRE_CLOSE(nbc.Variances()[i], xmlNbc.Variances()[i], 1e-5);
+    BOOST_REQUIRE_CLOSE(nbc.Variances()[i], textNbc.Variances()[i], 1e-5);
+    BOOST_REQUIRE_CLOSE(nbc.Variances()[i], binaryNbc.Variances()[i], 1e-5);
+  }
+
+  BOOST_REQUIRE_EQUAL(nbc.Probabilities().n_elem,
+      xmlNbc.Probabilities().n_elem);
+  BOOST_REQUIRE_EQUAL(nbc.Probabilities().n_elem,
+      textNbc.Probabilities().n_elem);
+  BOOST_REQUIRE_EQUAL(nbc.Probabilities().n_elem,
+      binaryNbc.Probabilities().n_elem);
+  for (size_t i = 0; i < nbc.Probabilities().n_elem; ++i)
+  {
+    BOOST_REQUIRE_CLOSE(nbc.Probabilities()[i], xmlNbc.Probabilities()[i],
+        1e-5);
+    BOOST_REQUIRE_CLOSE(nbc.Probabilities()[i], textNbc.Probabilities()[i],
+        1e-5);
+    BOOST_REQUIRE_CLOSE(nbc.Probabilities()[i], binaryNbc.Probabilities()[i],
+        1e-5);
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();
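
For reference, SerializeObjectAll() is a helper from the test suite's
serialization utilities: it round-trips the trained model through XML,
text, and binary boost archives into the three target classifiers, which
the test then compares element by element.  Conceptually, a single round
trip looks roughly like the following sketch (illustrative only, not the
exact helper implementation):

    #include <mlpack/core.hpp>
    #include <boost/archive/xml_oarchive.hpp>
    #include <boost/archive/xml_iarchive.hpp>
    #include <fstream>

    template<typename T>
    void XmlRoundTrip(T& in, T& out)
    {
      // Write the object to an XML archive on disk...
      {
        std::ofstream ofs("object.xml");
        boost::archive::xml_oarchive oar(ofs);
        oar << mlpack::data::CreateNVP(in, "object");
      }
      // ...then read it back into the second object.
      {
        std::ifstream ifs("object.xml");
        boost::archive::xml_iarchive iar(ifs);
        iar >> mlpack::data::CreateNVP(out, "object");
      }
    }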


