[mlpack-git] master: Add Serialization() to NaiveBayesClassifier. (1eb721c)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Nov 3 10:29:49 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/bd3fd9b46140555b3be741c1f50491b629fe9212...1eb721c663e640d571d8374c67c40ad8a5ea6fb3
>---------------------------------------------------------------
commit 1eb721c663e640d571d8374c67c40ad8a5ea6fb3
Author: ryan <ryan at ratml.org>
Date: Tue Nov 3 10:25:05 2015 -0500
Add Serialization() to NaiveBayesClassifier.
>---------------------------------------------------------------
1eb721c663e640d571d8374c67c40ad8a5ea6fb3
.../methods/naive_bayes/naive_bayes_classifier.hpp | 4 ++
.../naive_bayes/naive_bayes_classifier_impl.hpp | 16 ++++--
src/mlpack/tests/serialization_test.cpp | 61 ++++++++++++++++++++++
3 files changed, 78 insertions(+), 3 deletions(-)
diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
index 51fdf4f..4b6e40b 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
@@ -135,6 +135,10 @@ class NaiveBayesClassifier
//! Modify the prior probabilities for each class.
arma::vec& Probabilities() { return probabilities; }
+ //! Serialize the classifier.
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+
private:
//! Sample mean for each class.
MatType means;
diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
index 21829e3..2b0b3c7 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
@@ -189,7 +189,7 @@ void NaiveBayesClassifier<MatType>::Classify(const MatType& data,
exponents(j) = std::exp(arma::accu(diffs.col(j) % rhs.unsafe_col(j)));
testProbs.col(i) += log(pow(2 * M_PI, (double) data.n_rows / -2.0) *
- pow(det(arma::diagmat(invVar.col(i))), -0.5) * exponents);
+ std::pow(arma::det(arma::diagmat(invVar.col(i))), -0.5) * exponents);
}
// Now calculate the label.
@@ -206,7 +206,17 @@ void NaiveBayesClassifier<MatType>::Classify(const MatType& data,
return;
}
-}; // namespace naive_bayes
-}; // namespace mlpack
+template<typename MatType>
+template<typename Archive>
+void NaiveBayesClassifier<MatType>::Serialize(Archive& ar,
+ const unsigned int /* version */)
+{
+ ar & data::CreateNVP(means, "means");
+ ar & data::CreateNVP(variances, "variances");
+ ar & data::CreateNVP(probabilities, "probabilities");
+}
+
+} // namespace naive_bayes
+} // namespace mlpack
#endif
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 3baf714..fddb368 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -29,6 +29,7 @@
#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
#include <mlpack/methods/softmax_regression/softmax_regression.hpp>
#include <mlpack/methods/det/dtree.hpp>
+#include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>
using namespace mlpack;
using namespace mlpack::distribution;
@@ -38,6 +39,8 @@ using namespace mlpack::metric;
using namespace mlpack::tree;
using namespace mlpack::perceptron;
using namespace mlpack::regression;
+using namespace mlpack::naive_bayes;
+
using namespace arma;
using namespace boost;
using namespace boost::archive;
@@ -1257,4 +1260,62 @@ BOOST_AUTO_TEST_CASE(DETTest)
}
}
+BOOST_AUTO_TEST_CASE(NaiveBayesSerializationTest)
+{
+ // Train NBC randomly. Make sure the model is the same after serializing and
+ // re-loading.
+ arma::mat dataset;
+ dataset.randu(10, 500);
+ arma::Row<size_t> labels(500);
+ for (size_t i = 0; i < 500; ++i)
+ {
+ if (dataset(0, i) > 0.5)
+ labels[i] = 0;
+ else
+ labels[i] = 1;
+ }
+
+ NaiveBayesClassifier<> nbc(dataset, labels, 2);
+
+ // Initialize some empty Naive Bayes classifiers.
+ NaiveBayesClassifier<> xmlNbc(0, 0), textNbc(0, 0), binaryNbc(0, 0);
+ SerializeObjectAll(nbc, xmlNbc, textNbc, binaryNbc);
+
+ BOOST_REQUIRE_EQUAL(nbc.Means().n_elem, xmlNbc.Means().n_elem);
+ BOOST_REQUIRE_EQUAL(nbc.Means().n_elem, textNbc.Means().n_elem);
+ BOOST_REQUIRE_EQUAL(nbc.Means().n_elem, binaryNbc.Means().n_elem);
+ for (size_t i = 0; i < nbc.Means().n_elem; ++i)
+ {
+ BOOST_REQUIRE_CLOSE(nbc.Means()[i], xmlNbc.Means()[i], 1e-5);
+ BOOST_REQUIRE_CLOSE(nbc.Means()[i], textNbc.Means()[i], 1e-5);
+ BOOST_REQUIRE_CLOSE(nbc.Means()[i], binaryNbc.Means()[i], 1e-5);
+ }
+
+ BOOST_REQUIRE_EQUAL(nbc.Variances().n_elem, xmlNbc.Variances().n_elem);
+ BOOST_REQUIRE_EQUAL(nbc.Variances().n_elem, textNbc.Variances().n_elem);
+ BOOST_REQUIRE_EQUAL(nbc.Variances().n_elem, binaryNbc.Variances().n_elem);
+ for (size_t i = 0; i < nbc.Variances().n_elem; ++i)
+ {
+ BOOST_REQUIRE_CLOSE(nbc.Variances()[i], xmlNbc.Variances()[i], 1e-5);
+ BOOST_REQUIRE_CLOSE(nbc.Variances()[i], textNbc.Variances()[i], 1e-5);
+ BOOST_REQUIRE_CLOSE(nbc.Variances()[i], binaryNbc.Variances()[i], 1e-5);
+ }
+
+ BOOST_REQUIRE_EQUAL(nbc.Probabilities().n_elem,
+ xmlNbc.Probabilities().n_elem);
+ BOOST_REQUIRE_EQUAL(nbc.Probabilities().n_elem,
+ textNbc.Probabilities().n_elem);
+ BOOST_REQUIRE_EQUAL(nbc.Probabilities().n_elem,
+ binaryNbc.Probabilities().n_elem);
+ for (size_t i = 0; i < nbc.Probabilities().n_elem; ++i)
+ {
+ BOOST_REQUIRE_CLOSE(nbc.Probabilities()[i], xmlNbc.Probabilities()[i],
+ 1e-5);
+ BOOST_REQUIRE_CLOSE(nbc.Probabilities()[i], textNbc.Probabilities()[i],
+ 1e-5);
+ BOOST_REQUIRE_CLOSE(nbc.Probabilities()[i], binaryNbc.Probabilities()[i],
+ 1e-5);
+ }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list