[mlpack-svn] r16087 - mlpack/trunk/src/mlpack/methods/logistic_regression

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Dec 28 21:05:47 EST 2013


Author: rcurtin
Date: Sat Dec 28 21:05:47 2013
New Revision: 16087

Log:
Don't hold the optimizer and error function since they are only needed at
construction time.


Modified:
   mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp
   mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp	Sat Dec 28 21:05:47 2013
@@ -67,9 +67,14 @@
 
   /**
    * Construct a logistic regression model from the given parameters, without
-   * performing any training.
+   * performing any training.  The lambda parameter is used for the
+   * ComputeAccuracy() and ComputeError() functions; this constructor does not
+   * train the model itself.
+   *
+   * @param parameters Parameters making up the model.
+   * @param lambda L2-regularization penalty parameter.
    */
-  LogisticRegression(const arma::vec& parameters);
+  LogisticRegression(const arma::vec& parameters, const double lambda = 0);
 
   //! Return the parameters (the b vector).
   const arma::vec& Parameters() const { return parameters; }
@@ -77,9 +82,9 @@
   arma::vec& Parameters() { return parameters; }
 
   //! Return the lambda value for L2-regularization.
-  const double& Lambda() const { return errorFunction.Lambda(); }
+  const double& Lambda() const { return lambda; }
   //! Modify the lambda value for L2-regularization.
-  double& Lambda() { return errorFunction().Lambda(); }
+  double& Lambda() { return lambda; }
 
   /**
    * Predict the responses to a given set of predictors.  The responses will be
@@ -128,17 +133,8 @@
  private:
   //! Vector of trained parameters.
   arma::vec parameters;
-
-  //! Instantiated error function that will be optimized.
-  LogisticRegressionFunction errorFunction;
-  //! Instantiated optimizer.
-  OptimizerType<LogisticRegressionFunction> optimizer;
-
-  /**
-   * Learn the model by optimizing the logistic regression objective function.
-   * Returns the objective function evaluated when the parameters are optimized.
-   */
-  double LearnModel(const arma::mat& predictors, const arma::vec& responses);
+  //! L2-regularization penalty parameter.
+  double lambda;
 };
 
 }; // namespace regression

Modified: mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp	Sat Dec 28 21:05:47 2013
@@ -20,11 +20,18 @@
     const arma::vec& responses,
     const double lambda) :
     parameters(arma::zeros<arma::vec>(predictors.n_rows + 1)),
-    errorFunction(LogisticRegressionFunction(predictors, responses, lambda)),
-    optimizer(OptimizerType<LogisticRegressionFunction>(errorFunction))
+    lambda(lambda)
 {
+  LogisticRegressionFunction errorFunction(predictors, responses, lambda);
+  OptimizerType<LogisticRegressionFunction> optimizer(errorFunction);
+
   // Train the model.
-  LearnModel(predictors, responses);
+  Timer::Start("logistic_regression_optimization");
+  const double out = optimizer.Optimize(parameters);
+  Timer::Stop("logistic_regression_optimization");
+
+  Log::Info << "LogisticRegression::LogisticRegression(): final objective of "
+      << "trained model is " << out << "." << std::endl;
 }
 
 template<template<typename> class OptimizerType>
@@ -34,29 +41,41 @@
     const arma::mat& initialPoint,
     const double lambda) :
     parameters(arma::zeros<arma::vec>(predictors.n_rows + 1)),
-    errorFunction(LogisticRegressionFunction(predictors, responses)),
-    optimizer(OptimizerType<LogisticRegressionFunction>(errorFunction))
+    lambda(lambda)
 {
+  LogisticRegressionFunction errorFunction(predictors, responses, lambda);
+  errorFunction.InitialPoint() = initialPoint;
+  OptimizerType<LogisticRegressionFunction> optimizer(errorFunction);
+
   // Train the model.
-  LearnModel(predictors, responses);
+  Timer::Start("logistic_regression_optimization");
+  const double out = optimizer.Optimize(parameters);
+  Timer::Stop("logistic_regression_optimization");
+
+  Log::Info << "LogisticRegression::LogisticRegression(): final objective of "
+      << "trained model is " << out << "." << std::endl;
 }
 
 template<template<typename> class OptimizerType>
 LogisticRegression<OptimizerType>::LogisticRegression(
     OptimizerType<LogisticRegressionFunction>& optimizer) :
     parameters(optimizer.Function().GetInitialPoint()),
-    errorFunction(optimizer.Function()),
-    optimizer(optimizer)
+    lambda(optimizer.Function().Lambda())
 {
   Timer::Start("logistic_regression_optimization");
   const double out = optimizer.Optimize(parameters);
   Timer::Stop("logistic_regression_optimization");
+
+  Log::Info << "LogisticRegression::LogisticRegression(): final objective of "
+      << "trained model is " << out << "." << std::endl;
 }
 
 template<template<typename> class OptimizerType>
 LogisticRegression<OptimizerType>::LogisticRegression(
-    const arma::vec& parameters) :
-    parameters(parameters)
+    const arma::vec& parameters,
+    const double lambda) :
+    parameters(parameters),
+    lambda(lambda)
 {
   // Nothing to do.
 }
@@ -81,7 +100,7 @@
 {
   // Construct a new error function.
   LogisticRegressionFunction newErrorFunction(predictors, responses,
-      errorFunction.Lambda());
+      lambda);
 
   return newErrorFunction.Evaluate(parameters);
 }
@@ -105,18 +124,6 @@
   return (double) (count * 100) / responses.n_rows;
 }
 
-template <template<typename> class OptimizerType>
-double LogisticRegression<OptimizerType>::LearnModel(
-    const arma::mat& predictors,
-    const arma::vec& responses)
-{
-  Timer::Start("logistic_regression_optimization");
-  const double out = optimizer.Optimize(parameters);
-  Timer::Stop("logistic_regression_optimization");
-
-  return out;
-}
-
 }; // namespace regression
 }; // namespace mlpack
 



More information about the mlpack-svn mailing list