[mlpack-git] master: Add Serialize() to LogisticRegression. (763f54c)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Sep 11 12:05:42 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/aef36197cd6ed335ab4a9b60e065a5044e192e5d...763f54cf37ffd71871f9de1f59078f0c9235aff6

>---------------------------------------------------------------

commit 763f54cf37ffd71871f9de1f59078f0c9235aff6
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Sep 11 16:05:26 2015 +0000

    Add Serialize() to LogisticRegression.


>---------------------------------------------------------------

763f54cf37ffd71871f9de1f59078f0c9235aff6
 .../logistic_regression/logistic_regression.hpp    |  6 +++++-
 .../logistic_regression_impl.hpp                   | 16 +++++++++++---
 src/mlpack/tests/serialization_test.cpp            | 25 ++++++++++++++++++++++
 3 files changed, 43 insertions(+), 4 deletions(-)

diff --git a/src/mlpack/methods/logistic_regression/logistic_regression.hpp b/src/mlpack/methods/logistic_regression/logistic_regression.hpp
index 2390af8..506678f 100644
--- a/src/mlpack/methods/logistic_regression/logistic_regression.hpp
+++ b/src/mlpack/methods/logistic_regression/logistic_regression.hpp
@@ -130,7 +130,11 @@ class LogisticRegression
   double ComputeError(const arma::mat& predictors,
                       const arma::vec& responses) const;
 
-  // Returns a string representation of this object.
+  //! Serialize the model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
+  //! Returns a string representation of this object.
   std::string ToString() const;
 
  private:
diff --git a/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp b/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
index 26b3991..11bfbb7 100644
--- a/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
+++ b/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
@@ -93,7 +93,7 @@ void LogisticRegression<OptimizerType>::Predict(const arma::mat& predictors,
       + (1.0 - decisionBoundary));
 }
 
-template <template<typename> class OptimizerType>
+template<template<typename> class OptimizerType>
 double LogisticRegression<OptimizerType>::ComputeError(
     const arma::mat& predictors,
     const arma::vec& responses) const
@@ -105,7 +105,7 @@ double LogisticRegression<OptimizerType>::ComputeError(
   return newErrorFunction.Evaluate(parameters);
 }
 
-template <template<typename> class OptimizerType>
+template<template<typename> class OptimizerType>
 double LogisticRegression<OptimizerType>::ComputeAccuracy(
     const arma::mat& predictors,
     const arma::vec& responses,
@@ -124,7 +124,17 @@ double LogisticRegression<OptimizerType>::ComputeAccuracy(
   return (double) (count * 100) / responses.n_rows;
 }
 
-template <template<typename> class OptimizerType>
+template<template<typename> class OptimizerType>
+template<typename Archive>
+void LogisticRegression<OptimizerType>::Serialize(
+    Archive& ar,
+    const unsigned int /* version */)
+{
+  ar & data::CreateNVP(parameters, "parameters");
+  ar & data::CreateNVP(lambda, "lambda");
+}
+
+template<template<typename> class OptimizerType>
 std::string LogisticRegression<OptimizerType>::ToString() const
 {
   std::ostringstream convert;
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index abdc4d4..df259be 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -23,6 +23,7 @@
 #include <mlpack/core/tree/binary_space_tree.hpp>
 
 #include <mlpack/methods/perceptron/perceptron.hpp>
+#include <mlpack/methods/logistic_regression/logistic_regression.hpp>
 
 using namespace mlpack;
 using namespace mlpack::distribution;
@@ -31,6 +32,7 @@ using namespace mlpack::bound;
 using namespace mlpack::metric;
 using namespace mlpack::tree;
 using namespace mlpack::perceptron;
+using namespace mlpack::regression;
 using namespace arma;
 using namespace boost;
 using namespace boost::archive;
@@ -720,4 +722,27 @@ BOOST_AUTO_TEST_CASE(PerceptronTest)
   BOOST_REQUIRE_EQUAL(p.MaxIterations(), pBinary.MaxIterations());
 }
 
+BOOST_AUTO_TEST_CASE(LogisticRegressionTest)
+{
+  arma::mat data;
+  data.randu(3, 100);
+  arma::vec responses;
+  responses.randu(100);
+
+  LogisticRegression<> lr(data, responses, 0.5);
+
+  LogisticRegression<> lrXml(arma::vec(), 0.3);
+  LogisticRegression<> lrText(data, responses + 1);
+  LogisticRegression<> lrBinary(arma::vec("1 2 3"), 0.0);
+
+  SerializeObjectAll(lr, lrXml, lrText, lrBinary);
+
+  CheckMatrices(lr.Parameters(), lrXml.Parameters(), lrText.Parameters(),
+      lrBinary.Parameters());
+
+  BOOST_REQUIRE_CLOSE(lr.Lambda(), lrXml.Lambda(), 1e-5);
+  BOOST_REQUIRE_CLOSE(lr.Lambda(), lrText.Lambda(), 1e-5);
+  BOOST_REQUIRE_CLOSE(lr.Lambda(), lrBinary.Lambda(), 1e-5);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list