[mlpack-git] master: Add Serialize() to LogisticRegression. (763f54c)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Sep 11 12:05:42 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/aef36197cd6ed335ab4a9b60e065a5044e192e5d...763f54cf37ffd71871f9de1f59078f0c9235aff6
>---------------------------------------------------------------
commit 763f54cf37ffd71871f9de1f59078f0c9235aff6
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Sep 11 16:05:26 2015 +0000
Add Serialize() to LogisticRegression.
>---------------------------------------------------------------
763f54cf37ffd71871f9de1f59078f0c9235aff6
.../logistic_regression/logistic_regression.hpp | 6 +++++-
.../logistic_regression_impl.hpp | 16 +++++++++++---
src/mlpack/tests/serialization_test.cpp | 25 ++++++++++++++++++++++
3 files changed, 43 insertions(+), 4 deletions(-)
diff --git a/src/mlpack/methods/logistic_regression/logistic_regression.hpp b/src/mlpack/methods/logistic_regression/logistic_regression.hpp
index 2390af8..506678f 100644
--- a/src/mlpack/methods/logistic_regression/logistic_regression.hpp
+++ b/src/mlpack/methods/logistic_regression/logistic_regression.hpp
@@ -130,7 +130,11 @@ class LogisticRegression
double ComputeError(const arma::mat& predictors,
const arma::vec& responses) const;
- // Returns a string representation of this object.
+ //! Serialize the model.
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+
+ //! Returns a string representation of this object.
std::string ToString() const;
private:
diff --git a/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp b/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
index 26b3991..11bfbb7 100644
--- a/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
+++ b/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
@@ -93,7 +93,7 @@ void LogisticRegression<OptimizerType>::Predict(const arma::mat& predictors,
+ (1.0 - decisionBoundary));
}
-template <template<typename> class OptimizerType>
+template<template<typename> class OptimizerType>
double LogisticRegression<OptimizerType>::ComputeError(
const arma::mat& predictors,
const arma::vec& responses) const
@@ -105,7 +105,7 @@ double LogisticRegression<OptimizerType>::ComputeError(
return newErrorFunction.Evaluate(parameters);
}
-template <template<typename> class OptimizerType>
+template<template<typename> class OptimizerType>
double LogisticRegression<OptimizerType>::ComputeAccuracy(
const arma::mat& predictors,
const arma::vec& responses,
@@ -124,7 +124,17 @@ double LogisticRegression<OptimizerType>::ComputeAccuracy(
return (double) (count * 100) / responses.n_rows;
}
-template <template<typename> class OptimizerType>
+template<template<typename> class OptimizerType>
+template<typename Archive>
+void LogisticRegression<OptimizerType>::Serialize(
+ Archive& ar,
+ const unsigned int /* version */)
+{
+ ar & data::CreateNVP(parameters, "parameters");
+ ar & data::CreateNVP(lambda, "lambda");
+}
+
+template<template<typename> class OptimizerType>
std::string LogisticRegression<OptimizerType>::ToString() const
{
std::ostringstream convert;
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index abdc4d4..df259be 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -23,6 +23,7 @@
#include <mlpack/core/tree/binary_space_tree.hpp>
#include <mlpack/methods/perceptron/perceptron.hpp>
+#include <mlpack/methods/logistic_regression/logistic_regression.hpp>
using namespace mlpack;
using namespace mlpack::distribution;
@@ -31,6 +32,7 @@ using namespace mlpack::bound;
using namespace mlpack::metric;
using namespace mlpack::tree;
using namespace mlpack::perceptron;
+using namespace mlpack::regression;
using namespace arma;
using namespace boost;
using namespace boost::archive;
@@ -720,4 +722,27 @@ BOOST_AUTO_TEST_CASE(PerceptronTest)
BOOST_REQUIRE_EQUAL(p.MaxIterations(), pBinary.MaxIterations());
}
+BOOST_AUTO_TEST_CASE(LogisticRegressionTest)
+{
+ arma::mat data;
+ data.randu(3, 100);
+ arma::vec responses;
+ responses.randu(100);
+
+ LogisticRegression<> lr(data, responses, 0.5);
+
+ LogisticRegression<> lrXml(arma::vec(), 0.3);
+ LogisticRegression<> lrText(data, responses + 1);
+ LogisticRegression<> lrBinary(arma::vec("1 2 3"), 0.0);
+
+ SerializeObjectAll(lr, lrXml, lrText, lrBinary);
+
+ CheckMatrices(lr.Parameters(), lrXml.Parameters(), lrText.Parameters(),
+ lrBinary.Parameters());
+
+ BOOST_REQUIRE_CLOSE(lr.Lambda(), lrXml.Lambda(), 1e-5);
+ BOOST_REQUIRE_CLOSE(lr.Lambda(), lrText.Lambda(), 1e-5);
+ BOOST_REQUIRE_CLOSE(lr.Lambda(), lrBinary.Lambda(), 1e-5);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list