[mlpack-git] master: Add and implement Train() methods. (b34bb0f)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Sep 16 14:29:13 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/bbe9cd161571c99aca88096b07de61341711c049...e67787e336136a9e46b2d502bd583b8aea2668a4

>---------------------------------------------------------------

commit b34bb0f90d5664e73db082612372722ebc021524
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Sep 16 03:30:45 2015 +0000

    Add and implement Train() methods.


>---------------------------------------------------------------

b34bb0f90d5664e73db082612372722ebc021524
 .../logistic_regression/logistic_regression.hpp    | 10 +++-
 .../logistic_regression_impl.hpp                   | 57 +++++++++++++---------
 2 files changed, 41 insertions(+), 26 deletions(-)

diff --git a/src/mlpack/methods/logistic_regression/logistic_regression.hpp b/src/mlpack/methods/logistic_regression/logistic_regression.hpp
index 8ca67f5..a0cdf8f 100644
--- a/src/mlpack/methods/logistic_regression/logistic_regression.hpp
+++ b/src/mlpack/methods/logistic_regression/logistic_regression.hpp
@@ -96,14 +96,20 @@ class LogisticRegression
       template<typename> class OptimizerType = mlpack::optimization::L_BFGS
   >
   void Train(const MatType& predictors,
-             const arma::Row<size_t>& responses,
-             const MatType& initialPoint);
+             const arma::Row<size_t>& responses);
 
   /**
    * Train the LogisticRegression model with the given instantiated optimizer.
    * Using this overload allows configuring the instantiated optimizer before
    * training is performed.
    *
+   * Note that the initial point of the optimizer
+   * (optimizer.Function().GetInitialPoint()) will be used as the initial point
+   * of the optimization, overwriting any existing trained model.  If you don't
+   * want to overwrite the existing model, first set
+   * optimizer.Function().InitialPoint() to the current parameters vector,
+   * accessible via Parameters().
+   *
    * @param optimizer Instantiated optimizer with instantiated error function.
    */
   template<
diff --git a/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp b/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
index 608dfde..704d95d 100644
--- a/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
+++ b/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
@@ -23,17 +23,7 @@ LogisticRegression<MatType>::LogisticRegression(
     parameters(arma::zeros<arma::vec>(predictors.n_rows + 1)),
     lambda(lambda)
 {
-  LogisticRegressionFunction<MatType> errorFunction(predictors, responses,
-      lambda);
-  OptimizerType<LogisticRegressionFunction<MatType>> optimizer(errorFunction);
-
-  // Train the model.
-  Timer::Start("logistic_regression_optimization");
-  const double out = optimizer.Optimize(parameters);
-  Timer::Stop("logistic_regression_optimization");
-
-  Log::Info << "LogisticRegression::LogisticRegression(): final objective of "
-      << "trained model is " << out << "." << std::endl;
+  Train<OptimizerType>(predictors, responses);
 }
 
 template<typename MatType>
@@ -43,21 +33,10 @@ LogisticRegression<MatType>::LogisticRegression(
     const arma::Row<size_t>& responses,
     const arma::vec& initialPoint,
     const double lambda) :
-    parameters(arma::zeros<arma::vec>(predictors.n_rows + 1)),
+    parameters(initialPoint),
     lambda(lambda)
 {
-  LogisticRegressionFunction<MatType> errorFunction(predictors, responses,
-      lambda);
-  errorFunction.InitialPoint() = initialPoint;
-  OptimizerType<LogisticRegressionFunction<MatType>> optimizer(errorFunction);
-
-  // Train the model.
-  Timer::Start("logistic_regression_optimization");
-  const double out = optimizer.Optimize(parameters);
-  Timer::Stop("logistic_regression_optimization");
-
-  Log::Info << "LogisticRegression::LogisticRegression(): final objective of "
-      << "trained model is " << out << "." << std::endl;
+  Train<OptimizerType>(predictors, responses);
 }
 
 template<typename MatType>
@@ -78,6 +57,36 @@ LogisticRegression<MatType>::LogisticRegression(
     parameters(optimizer.Function().GetInitialPoint()),
     lambda(optimizer.Function().Lambda())
 {
+  Train(optimizer);
+}
+
+template<typename MatType>
+template<template<typename> class OptimizerType>
+void LogisticRegression<MatType>::Train(const MatType& predictors,
+                                        const arma::Row<size_t>& responses)
+{
+  LogisticRegressionFunction<MatType> errorFunction(predictors, responses,
+      lambda);
+  // Start from the current model (constructor's initial point or prior fit).
+  errorFunction.InitialPoint() = parameters;
+  OptimizerType<LogisticRegressionFunction<MatType>> optimizer(errorFunction);
+  // `parameters` is the optimizer's iterate; presumably updated in place.
+  Timer::Start("logistic_regression_optimization");
+  const double out = optimizer.Optimize(parameters);
+  Timer::Stop("logistic_regression_optimization");
+
+  Log::Info << "LogisticRegression::Train(): final objective of "
+      << "trained model is " << out << "." << std::endl;
+}
+
+template<typename MatType>
+template<template<typename> class OptimizerType>
+void LogisticRegression<MatType>::Train(
+    OptimizerType<LogisticRegressionFunction<MatType>>& optimizer)
+{
+  // Everything is good.  Just train the model.
+  parameters = optimizer.Function().GetInitialPoint();
+
   Timer::Start("logistic_regression_optimization");
   const double out = optimizer.Optimize(parameters);
   Timer::Stop("logistic_regression_optimization");



More information about the mlpack-git mailing list