[mlpack-git] master: Refactor so that OptimizerType is only needed sometimes. (600b2f5)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Sep 16 14:29:11 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/bbe9cd161571c99aca88096b07de61341711c049...e67787e336136a9e46b2d502bd583b8aea2668a4
>---------------------------------------------------------------
commit 600b2f5d524e87616a6d72198e8bfc3603f7a013
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Sep 16 02:50:12 2015 +0000
Refactor so that OptimizerType is only needed sometimes.
OptimizerType isn't relevant to an already-trained model, so it should only be specified during training.
>---------------------------------------------------------------
600b2f5d524e87616a6d72198e8bfc3603f7a013
.../logistic_regression/logistic_regression.hpp | 67 ++++++++++++++-----
.../logistic_regression_impl.hpp | 78 +++++++++++-----------
src/mlpack/tests/logistic_regression_test.cpp | 10 +--
src/mlpack/tests/serialization_test.cpp | 4 +-
4 files changed, 98 insertions(+), 61 deletions(-)
diff --git a/src/mlpack/methods/logistic_regression/logistic_regression.hpp b/src/mlpack/methods/logistic_regression/logistic_regression.hpp
index 417dedf..8ca67f5 100644
--- a/src/mlpack/methods/logistic_regression/logistic_regression.hpp
+++ b/src/mlpack/methods/logistic_regression/logistic_regression.hpp
@@ -16,9 +16,7 @@
namespace mlpack {
namespace regression {
-template<
- template<typename> class OptimizerType = mlpack::optimization::L_BFGS
->
+template<typename MatType = arma::mat>
class LogisticRegression
{
public:
@@ -32,7 +30,10 @@ class LogisticRegression
* @param responses Outputs resulting from input training variables.
* @param lambda L2-regularization parameter.
*/
- LogisticRegression(const arma::mat& predictors,
+ template<
+ template<typename> class OptimizerType = mlpack::optimization::L_BFGS
+ >
+ LogisticRegression(const MatType& predictors,
const arma::Row<size_t>& responses,
const double lambda = 0);
@@ -47,12 +48,31 @@ class LogisticRegression
* @param initialPoint Initial model to train with.
* @param lambda L2-regularization parameter.
*/
- LogisticRegression(const arma::mat& predictors,
+ template<
+ template<typename> class OptimizerType = mlpack::optimization::L_BFGS
+ >
+ LogisticRegression(const MatType& predictors,
const arma::Row<size_t>& responses,
const arma::vec& initialPoint,
const double lambda = 0);
/**
+ * Construct the LogisticRegression class without performing any training.
+ * The dimensionality of the data (which will be used to set the size of the
+ * parameters vector) must be specified, and all of the parameters in the
+ * model will be set to 0. Note that the dimensionality may be changed later
+ * by directly modifying the parameters vector (using Parameters()).
+ *
+ * @param dimensionality Dimensionality of the data.
+ * @param lambda L2-regularization parameter.
+ */
+ template<
+ template<typename> class OptimizerType = mlpack::optimization::L_BFGS
+ >
+ LogisticRegression(const size_t dimensionality,
+ const double lambda = 0);
+
+ /**
* Construct the LogisticRegression class with the given labeled training
* data. This will train the model. This overload takes an already
* instantiated optimizer (which holds the LogisticRegressionFunction error
@@ -63,18 +83,33 @@ class LogisticRegression
*
* @param optimizer Instantiated optimizer with instantiated error function.
*/
- LogisticRegression(OptimizerType<LogisticRegressionFunction<>>& optimizer);
+ template<
+ template<typename> class OptimizerType = mlpack::optimization::L_BFGS
+ >
+ LogisticRegression(
+ OptimizerType<LogisticRegressionFunction<MatType>>& optimizer);
/**
- * Construct a logistic regression model from the given parameters, without
- * performing any training. The lambda parameter is used for the
- * ComputeAccuracy() and ComputeError() functions; this constructor does not
- * train the model itself.
+ * Train the LogisticRegression model on the given input data.
+ */
+ template<
+ template<typename> class OptimizerType = mlpack::optimization::L_BFGS
+ >
+ void Train(const MatType& predictors,
+ const arma::Row<size_t>& responses,
+ const MatType& initialPoint);
+
+ /**
+ * Train the LogisticRegression model with the given instantiated optimizer.
+ * Using this overload allows configuring the instantiated optimizer before
+ * training is performed.
*
- * @param parameters Parameters making up the model.
- * @param lambda L2-regularization penalty parameter.
+ * @param optimizer Instantiated optimizer with instantiated error function.
*/
- LogisticRegression(const arma::vec& parameters, const double lambda = 0);
+ template<
+ template<typename> class OptimizerType = mlpack::optimization::L_BFGS
+ >
+ void Train(OptimizerType<LogisticRegressionFunction<MatType>>& optimizer);
//! Return the parameters (the b vector).
const arma::vec& Parameters() const { return parameters; }
@@ -97,7 +132,7 @@ class LogisticRegression
* @param responses Vector to put output predictions of responses into.
* @param decisionBoundary Decision boundary (default 0.5).
*/
- void Predict(const arma::mat& predictors,
+ void Predict(const MatType& predictors,
arma::Row<size_t>& responses,
const double decisionBoundary = 0.5) const;
@@ -115,7 +150,7 @@ class LogisticRegression
* @param decisionBoundary Decision boundary (default 0.5).
* @return Percentage of responses that are predicted correctly.
*/
- double ComputeAccuracy(const arma::mat& predictors,
+ double ComputeAccuracy(const MatType& predictors,
const arma::Row<size_t>& responses,
const double decisionBoundary = 0.5) const;
@@ -127,7 +162,7 @@ class LogisticRegression
* @param predictors Input predictors.
* @param responses Vector of responses.
*/
- double ComputeError(const arma::mat& predictors,
+ double ComputeError(const MatType& predictors,
const arma::Row<size_t>& responses) const;
//! Serialize the model.
diff --git a/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp b/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
index 9d36647..608dfde 100644
--- a/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
+++ b/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
@@ -14,16 +14,18 @@
namespace mlpack {
namespace regression {
+template<typename MatType>
template<template<typename> class OptimizerType>
-LogisticRegression<OptimizerType>::LogisticRegression(
- const arma::mat& predictors,
+LogisticRegression<MatType>::LogisticRegression(
+ const MatType& predictors,
const arma::Row<size_t>& responses,
const double lambda) :
parameters(arma::zeros<arma::vec>(predictors.n_rows + 1)),
lambda(lambda)
{
- LogisticRegressionFunction<> errorFunction(predictors, responses, lambda);
- OptimizerType<LogisticRegressionFunction<>> optimizer(errorFunction);
+ LogisticRegressionFunction<MatType> errorFunction(predictors, responses,
+ lambda);
+ OptimizerType<LogisticRegressionFunction<MatType>> optimizer(errorFunction);
// Train the model.
Timer::Start("logistic_regression_optimization");
@@ -34,18 +36,20 @@ LogisticRegression<OptimizerType>::LogisticRegression(
<< "trained model is " << out << "." << std::endl;
}
+template<typename MatType>
template<template<typename> class OptimizerType>
-LogisticRegression<OptimizerType>::LogisticRegression(
- const arma::mat& predictors,
+LogisticRegression<MatType>::LogisticRegression(
+ const MatType& predictors,
const arma::Row<size_t>& responses,
const arma::vec& initialPoint,
const double lambda) :
parameters(arma::zeros<arma::vec>(predictors.n_rows + 1)),
lambda(lambda)
{
- LogisticRegressionFunction<> errorFunction(predictors, responses, lambda);
+ LogisticRegressionFunction<MatType> errorFunction(predictors, responses,
+ lambda);
errorFunction.InitialPoint() = initialPoint;
- OptimizerType<LogisticRegressionFunction<>> optimizer(errorFunction);
+ OptimizerType<LogisticRegressionFunction<MatType>> optimizer(errorFunction);
// Train the model.
Timer::Start("logistic_regression_optimization");
@@ -56,9 +60,21 @@ LogisticRegression<OptimizerType>::LogisticRegression(
<< "trained model is " << out << "." << std::endl;
}
+template<typename MatType>
+template<template<typename> class OptimizerType>
+LogisticRegression<MatType>::LogisticRegression(
+ const size_t dimensionality,
+ const double lambda) :
+ parameters(dimensionality),
+ lambda(lambda)
+{
+ // No training to do here.
+}
+
+template<typename MatType>
template<template<typename> class OptimizerType>
-LogisticRegression<OptimizerType>::LogisticRegression(
- OptimizerType<LogisticRegressionFunction<>>& optimizer) :
+LogisticRegression<MatType>::LogisticRegression(
+ OptimizerType<LogisticRegressionFunction<MatType>>& optimizer) :
parameters(optimizer.Function().GetInitialPoint()),
lambda(optimizer.Function().Lambda())
{
@@ -70,21 +86,10 @@ LogisticRegression<OptimizerType>::LogisticRegression(
<< "trained model is " << out << "." << std::endl;
}
-template<template<typename> class OptimizerType>
-LogisticRegression<OptimizerType>::LogisticRegression(
- const arma::vec& parameters,
- const double lambda) :
- parameters(parameters),
- lambda(lambda)
-{
- // Nothing to do.
-}
-
-template<template<typename> class OptimizerType>
-void LogisticRegression<OptimizerType>::Predict(const arma::mat& predictors,
- arma::Row<size_t>& responses,
- const double decisionBoundary)
- const
+template<typename MatType>
+void LogisticRegression<MatType>::Predict(const MatType& predictors,
+ arma::Row<size_t>& responses,
+ const double decisionBoundary) const
{
// Calculate sigmoid function for each point. The (1.0 - decisionBoundary)
// term correctly sets an offset so that floor() returns 0 or 1 correctly.
@@ -94,9 +99,9 @@ void LogisticRegression<OptimizerType>::Predict(const arma::mat& predictors,
(1.0 - decisionBoundary));
}
-template<template<typename> class OptimizerType>
-double LogisticRegression<OptimizerType>::ComputeError(
- const arma::mat& predictors,
+template<typename MatType>
+double LogisticRegression<MatType>::ComputeError(
+ const MatType& predictors,
const arma::Row<size_t>& responses) const
{
// Construct a new error function.
@@ -106,9 +111,9 @@ double LogisticRegression<OptimizerType>::ComputeError(
return newErrorFunction.Evaluate(parameters);
}
-template<template<typename> class OptimizerType>
-double LogisticRegression<OptimizerType>::ComputeAccuracy(
- const arma::mat& predictors,
+template<typename MatType>
+double LogisticRegression<MatType>::ComputeAccuracy(
+ const MatType& predictors,
const arma::Row<size_t>& responses,
const double decisionBoundary) const
{
@@ -122,17 +127,14 @@ double LogisticRegression<OptimizerType>::ComputeAccuracy(
{
if (responses(i) == tempResponses(i))
count++;
- else
- std::cout << "i " << i << ": " << responses[i] << " vs. predicted " <<
-tempResponses(i) << ".\n";
}
return (double) (count * 100) / responses.n_elem;
}
-template<template<typename> class OptimizerType>
+template<typename MatType>
template<typename Archive>
-void LogisticRegression<OptimizerType>::Serialize(
+void LogisticRegression<MatType>::Serialize(
Archive& ar,
const unsigned int /* version */)
{
@@ -140,8 +142,8 @@ void LogisticRegression<OptimizerType>::Serialize(
ar & data::CreateNVP(lambda, "lambda");
}
-template<template<typename> class OptimizerType>
-std::string LogisticRegression<OptimizerType>::ToString() const
+template<typename MatType>
+std::string LogisticRegression<MatType>::ToString() const
{
std::ostringstream convert;
convert << "Logistic Regression [" << this << "]" << std::endl;
diff --git a/src/mlpack/tests/logistic_regression_test.cpp b/src/mlpack/tests/logistic_regression_test.cpp
index 2e633ca..1ddf9aa 100644
--- a/src/mlpack/tests/logistic_regression_test.cpp
+++ b/src/mlpack/tests/logistic_regression_test.cpp
@@ -504,7 +504,7 @@ BOOST_AUTO_TEST_CASE(LogisticRegressionSGDSimpleTest)
// smaller tolerance.
LogisticRegressionFunction<> lrf(data, responses, 0.001);
SGD<LogisticRegressionFunction<>> sgd(lrf, 0.005, 500000, 1e-10);
- LogisticRegression<SGD> lr(sgd);
+ LogisticRegression<> lr(sgd);
// Test sigmoid function.
arma::vec sigmoids = 1 / (1 + arma::exp(-lr.Parameters()[0]
@@ -553,7 +553,7 @@ BOOST_AUTO_TEST_CASE(LogisticRegressionSGDRegularizationSimpleTest)
// tolerance.
LogisticRegressionFunction<> lrf(data, responses, 0.001);
SGD<LogisticRegressionFunction<>> sgd(lrf, 0.005, 500000, 1e-10);
- LogisticRegression<SGD> lr(sgd);
+ LogisticRegression<> lr(sgd);
// Test sigmoid function.
arma::vec sigmoids = 1 / (1 + arma::exp(-lr.Parameters()[0]
@@ -635,7 +635,7 @@ BOOST_AUTO_TEST_CASE(LogisticRegressionSGDGaussianTest)
}
// Now train a logistic regression object on it.
- LogisticRegression<SGD> lr(data, responses, 0.5);
+ LogisticRegression<> lr(data, responses, 0.5);
// Ensure that the error is close to zero.
const double acc = lr.ComputeAccuracy(data, responses);
@@ -674,7 +674,7 @@ BOOST_AUTO_TEST_CASE(LogisticRegressionInstantiatedOptimizer)
LogisticRegressionFunction<> lrf(data, responses, 0.0005);
L_BFGS<LogisticRegressionFunction<>> lbfgsOpt(lrf);
lbfgsOpt.MinGradientNorm() = 1e-50;
- LogisticRegression<L_BFGS> lr(lbfgsOpt);
+ LogisticRegression<> lr(lbfgsOpt);
// Test sigmoid function.
arma::vec sigmoids = 1 / (1 + arma::exp(-lr.Parameters()[0]
@@ -689,7 +689,7 @@ BOOST_AUTO_TEST_CASE(LogisticRegressionInstantiatedOptimizer)
SGD<LogisticRegressionFunction<>> sgdOpt(lrf);
sgdOpt.StepSize() = 0.15;
sgdOpt.Tolerance() = 1e-75;
- LogisticRegression<SGD> lr2(sgdOpt);
+ LogisticRegression<> lr2(sgdOpt);
// Test sigmoid function.
sigmoids = 1 / (1 + arma::exp(-lr2.Parameters()[0]
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 961f0fa..3f2d70e 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -731,9 +731,9 @@ BOOST_AUTO_TEST_CASE(LogisticRegressionTest)
LogisticRegression<> lr(data, responses, 0.5);
- LogisticRegression<> lrXml(arma::vec(), 0.3);
+ LogisticRegression<> lrXml(data, responses + 3, 0.3);
LogisticRegression<> lrText(data, responses + 1);
- LogisticRegression<> lrBinary(arma::vec("1 2 3"), 0.0);
+ LogisticRegression<> lrBinary(3, 0.0);
SerializeObjectAll(lr, lrXml, lrText, lrBinary);
More information about the mlpack-git
mailing list