[mlpack-git] master: Add and implement Train() methods. (b34bb0f)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Sep 16 14:29:13 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/bbe9cd161571c99aca88096b07de61341711c049...e67787e336136a9e46b2d502bd583b8aea2668a4
>---------------------------------------------------------------
commit b34bb0f90d5664e73db082612372722ebc021524
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Sep 16 03:30:45 2015 +0000
Add and implement Train() methods.
>---------------------------------------------------------------
b34bb0f90d5664e73db082612372722ebc021524
.../logistic_regression/logistic_regression.hpp | 10 +++-
.../logistic_regression_impl.hpp | 57 +++++++++++++---------
2 files changed, 41 insertions(+), 26 deletions(-)
diff --git a/src/mlpack/methods/logistic_regression/logistic_regression.hpp b/src/mlpack/methods/logistic_regression/logistic_regression.hpp
index 8ca67f5..a0cdf8f 100644
--- a/src/mlpack/methods/logistic_regression/logistic_regression.hpp
+++ b/src/mlpack/methods/logistic_regression/logistic_regression.hpp
@@ -96,14 +96,20 @@ class LogisticRegression
template<typename> class OptimizerType = mlpack::optimization::L_BFGS
>
void Train(const MatType& predictors,
- const arma::Row<size_t>& responses,
- const MatType& initialPoint);
+ const arma::Row<size_t>& responses);
/**
* Train the LogisticRegression model with the given instantiated optimizer.
* Using this overload allows configuring the instantiated optimizer before
* training is performed.
*
+ * Note that the initial point of the optimizer
+ * (optimizer.Function().GetInitialPoint()) will be used as the initial point
+ * of the optimization, overwriting any existing trained model. If you don't
+ * want to overwrite the existing model, set
+ * optimizer.Function().InitialPoint() to the current parameters vector,
+ * accessible via Parameters().
+ *
* @param optimizer Instantiated optimizer with instantiated error function.
*/
template<
diff --git a/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp b/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
index 608dfde..704d95d 100644
--- a/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
+++ b/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
@@ -23,17 +23,7 @@ LogisticRegression<MatType>::LogisticRegression(
parameters(arma::zeros<arma::vec>(predictors.n_rows + 1)),
lambda(lambda)
{
- LogisticRegressionFunction<MatType> errorFunction(predictors, responses,
- lambda);
- OptimizerType<LogisticRegressionFunction<MatType>> optimizer(errorFunction);
-
- // Train the model.
- Timer::Start("logistic_regression_optimization");
- const double out = optimizer.Optimize(parameters);
- Timer::Stop("logistic_regression_optimization");
-
- Log::Info << "LogisticRegression::LogisticRegression(): final objective of "
- << "trained model is " << out << "." << std::endl;
+ Train<OptimizerType>(predictors, responses);
}
template<typename MatType>
@@ -43,21 +33,10 @@ LogisticRegression<MatType>::LogisticRegression(
const arma::Row<size_t>& responses,
const arma::vec& initialPoint,
const double lambda) :
- parameters(arma::zeros<arma::vec>(predictors.n_rows + 1)),
+ parameters(initialPoint),
lambda(lambda)
{
- LogisticRegressionFunction<MatType> errorFunction(predictors, responses,
- lambda);
- errorFunction.InitialPoint() = initialPoint;
- OptimizerType<LogisticRegressionFunction<MatType>> optimizer(errorFunction);
-
- // Train the model.
- Timer::Start("logistic_regression_optimization");
- const double out = optimizer.Optimize(parameters);
- Timer::Stop("logistic_regression_optimization");
-
- Log::Info << "LogisticRegression::LogisticRegression(): final objective of "
- << "trained model is " << out << "." << std::endl;
+ Train<OptimizerType>(predictors, responses);
}
template<typename MatType>
@@ -78,6 +57,36 @@ LogisticRegression<MatType>::LogisticRegression(
parameters(optimizer.Function().GetInitialPoint()),
lambda(optimizer.Function().Lambda())
{
+ Train(optimizer);
+}
+
+template<typename MatType>
+template<template<typename> class OptimizerType>
+void LogisticRegression<MatType>::Train(const MatType& predictors,
+ const arma::Row<size_t>& responses)
+{
+ LogisticRegressionFunction<MatType> errorFunction(predictors, responses,
+ lambda);
+ errorFunction.InitialPoint() = parameters;
+ OptimizerType<LogisticRegressionFunction<MatType>> optimizer(errorFunction);
+
+ // Train the model, starting from (and overwriting) 'parameters'.
+ Timer::Start("logistic_regression_optimization");
+ const double out = optimizer.Optimize(parameters);
+ Timer::Stop("logistic_regression_optimization");
+
+ Log::Info << "LogisticRegression::Train(): final objective of "
+ << "trained model is " << out << "." << std::endl;
+}
+
+template<typename MatType>
+template<template<typename> class OptimizerType>
+void LogisticRegression<MatType>::Train(
+ OptimizerType<LogisticRegressionFunction<MatType>>& optimizer)
+{
+ // Everything is good. Just train the model.
+ parameters = optimizer.Function().GetInitialPoint();
+
Timer::Start("logistic_regression_optimization");
const double out = optimizer.Optimize(parameters);
Timer::Stop("logistic_regression_optimization");
More information about the mlpack-git mailing list