[mlpack-svn] r16058 - mlpack/trunk/src/mlpack/methods/logistic_regression

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Nov 21 13:34:37 EST 2013


Author: rcurtin
Date: Thu Nov 21 13:34:37 2013
New Revision: 16058

Log:
Add a constructor that allows passing an instantiated optimizer.


Modified:
   mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp
   mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp	Thu Nov 21 13:34:37 2013
@@ -52,6 +52,19 @@
                      const arma::mat& initialPoint,
                      const double lambda = 0);
 
+  /**
+   * Construct the LogisticRegression object with an already instantiated
+   * optimizer, and train the model.  The optimizer must hold an instantiated
+   * LogisticRegressionFunction error function; taking the optimizer directly
+   * allows it to be configured before this constructor runs the training.  The
+   * predictors, responses, and initial point are all taken from the error
+   * function contained in the optimizer.
+   *
+   * @param optimizer Instantiated optimizer with instantiated error function.
+   */
+  explicit LogisticRegression(
+      OptimizerType<LogisticRegressionFunction>& optimizer);
+
   //! Return the parameters (the b vector).
   const arma::vec& Parameters() const { return parameters; }
   //! Modify the parameters (the b vector).

Modified: mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp	Thu Nov 21 13:34:37 2013
@@ -45,7 +45,24 @@
   LearnModel();
 }
 
-template <template<typename> class OptimizerType>
+template<template<typename> class OptimizerType>
+LogisticRegression<OptimizerType>::LogisticRegression(
+    OptimizerType<LogisticRegressionFunction>& optimizer) :
+    predictors(optimizer.Function().Predictors()),
+    responses(optimizer.Function().Responses()),
+    parameters(optimizer.Function().GetInitialPoint()),
+    errorFunction(optimizer.Function()),
+    optimizer(optimizer)
+{
+  // In this body `optimizer` is the constructor parameter (it shadows the
+  // member initialized from it above), so the caller's configured optimizer
+  // runs.  Optimize()'s return value is unused here, so it is not stored.
+  Timer::Start("logistic_regression_optimization");
+  optimizer.Optimize(parameters);
+  Timer::Stop("logistic_regression_optimization");
+}
+
+template<template<typename> class OptimizerType>
 void LogisticRegression<OptimizerType>::Predict(const arma::mat& predictors,
                                                 arma::vec& responses,
                                                 const double decisionBoundary)



More information about the mlpack-svn mailing list