[mlpack-svn] r16083 - mlpack/trunk/src/mlpack/methods/logistic_regression

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Dec 11 19:17:19 EST 2013


Author: rcurtin
Date: Wed Dec 11 19:17:19 2013
New Revision: 16083

Log:
Remove the predictors and responses members because they don't need to be
stored by the LogisticRegression class (they are now passed directly to
LearnModel()), and add a constructor that builds a model from a given set of
parameters without performing any training.


Modified:
   mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp
   mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp	Wed Dec 11 19:17:19 2013
@@ -65,6 +65,12 @@
    */
   LogisticRegression(OptimizerType<LogisticRegressionFunction>& optimizer);
 
+  /**
+   * Construct a logistic regression model from the given parameters, without
+   * performing any training.
+   */
+  LogisticRegression(const arma::vec& parameters);
+
   //! Return the parameters (the b vector).
   const arma::vec& Parameters() const { return parameters; }
   //! Modify the parameters (the b vector).
@@ -120,10 +126,6 @@
                       const arma::vec& responses) const;
 
  private:
-  //! Matrix of predictor points (X).
-  const arma::mat& predictors;
-  //! Vector of responses (y).
-  const arma::vec& responses;
   //! Vector of trained parameters.
   arma::vec parameters;
 
@@ -136,7 +138,7 @@
    * Learn the model by optimizing the logistic regression objective function.
    * Returns the objective function evaluated when the parameters are optimized.
    */
-  double LearnModel();
+  double LearnModel(const arma::mat& predictors, const arma::vec& responses);
 };
 
 }; // namespace regression

Modified: mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp	Wed Dec 11 19:17:19 2013
@@ -19,14 +19,12 @@
     const arma::mat& predictors,
     const arma::vec& responses,
     const double lambda) :
-    predictors(predictors),
-    responses(responses),
     parameters(arma::zeros<arma::vec>(predictors.n_rows + 1)),
     errorFunction(LogisticRegressionFunction(predictors, responses, lambda)),
     optimizer(OptimizerType<LogisticRegressionFunction>(errorFunction))
 {
   // Train the model.
-  LearnModel();
+  LearnModel(predictors, responses);
 }
 
 template<template<typename> class OptimizerType>
@@ -35,21 +33,17 @@
     const arma::vec& responses,
     const arma::mat& initialPoint,
     const double lambda) :
-    predictors(predictors),
-    responses(responses),
     parameters(arma::zeros<arma::vec>(predictors.n_rows + 1)),
     errorFunction(LogisticRegressionFunction(predictors, responses)),
     optimizer(OptimizerType<LogisticRegressionFunction>(errorFunction))
 {
   // Train the model.
-  LearnModel();
+  LearnModel(predictors, responses);
 }
 
 template<template<typename> class OptimizerType>
 LogisticRegression<OptimizerType>::LogisticRegression(
     OptimizerType<LogisticRegressionFunction>& optimizer) :
-    predictors(optimizer.Function().Predictors()),
-    responses(optimizer.Function().Responses()),
     parameters(optimizer.Function().GetInitialPoint()),
     errorFunction(optimizer.Function()),
     optimizer(optimizer)
@@ -60,6 +54,14 @@
 }
 
 template<template<typename> class OptimizerType>
+LogisticRegression<OptimizerType>::LogisticRegression(
+    const arma::vec& parameters) :
+    parameters(parameters)
+{
+  // Nothing to do.
+}
+
+template<template<typename> class OptimizerType>
 void LogisticRegression<OptimizerType>::Predict(const arma::mat& predictors,
                                                 arma::vec& responses,
                                                 const double decisionBoundary)
@@ -104,7 +106,9 @@
 }
 
 template <template<typename> class OptimizerType>
-double LogisticRegression<OptimizerType>::LearnModel()
+double LogisticRegression<OptimizerType>::LearnModel(
+    const arma::mat& predictors,
+    const arma::vec& responses)
 {
   Timer::Start("logistic_regression_optimization");
   const double out = optimizer.Optimize(parameters);



More information about the mlpack-svn mailing list