[mlpack-svn] r16058 - mlpack/trunk/src/mlpack/methods/logistic_regression
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Nov 21 13:34:37 EST 2013
Author: rcurtin
Date: Thu Nov 21 13:34:37 2013
New Revision: 16058
Log:
Add a constructor that allows passing an instantiated optimizer.
Modified:
mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp
mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp Thu Nov 21 13:34:37 2013
@@ -52,6 +52,19 @@
const arma::mat& initialPoint,
const double lambda = 0);
+ /**
+ * Construct the LogisticRegression class with the given labeled training
+ * data. This will train the model. This overload takes an already
+ * instantiated optimizer (which holds the LogisticRegressionFunction error
+ * function, which must also be instantiated), so that the optimizer can be
+ * configured before the training is run by this constructor. The predictors,
+ * responses, and initial point are all taken from the error function
+ * contained in the optimizer.
+ *
+ * NOTE(review): this single-argument constructor is not marked `explicit`,
+ * so an optimizer object can implicitly convert to a LogisticRegression
+ * (which trains a model as a side effect). Consider marking it `explicit`.
+ *
+ * @param optimizer Instantiated optimizer with instantiated error function.
+ *     Optimize() is called on this object to train the model, so it is
+ *     taken by non-const reference.
+ */
+ LogisticRegression(OptimizerType<LogisticRegressionFunction>& optimizer);
+
//! Return the parameters (the b vector) fit during training.
const arma::vec& Parameters() const { return parameters; }
//! Modify the parameters (the b vector).
Modified: mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp Thu Nov 21 13:34:37 2013
@@ -45,7 +45,21 @@
LearnModel();
}
-template <template<typename> class OptimizerType>
+/**
+ * Train a LogisticRegression model using an already-configured optimizer.
+ * The predictors, responses, starting parameters, and error function are all
+ * pulled from the error function held by the optimizer, and the model is
+ * trained immediately by calling Optimize() on the given optimizer.
+ *
+ * The parameter is named `opt` (rather than `optimizer`) so that it does not
+ * shadow the `optimizer` data member; in the body, `opt` unambiguously refers
+ * to the caller's optimizer, which is the object that gets run.
+ *
+ * NOTE(review): the other constructors take a `lambda` argument; if the class
+ * stores it (or any other state) as a member, it is not initialized here --
+ * confirm against the class definition.
+ *
+ * @param opt Instantiated optimizer with instantiated error function.
+ */
+template<template<typename> class OptimizerType>
+LogisticRegression<OptimizerType>::LogisticRegression(
+ OptimizerType<LogisticRegressionFunction>& opt) :
+ predictors(opt.Function().Predictors()),
+ responses(opt.Function().Responses()),
+ parameters(opt.Function().GetInitialPoint()),
+ errorFunction(opt.Function()),
+ optimizer(opt)
+{
+ // Optimize() updates `parameters` in place, starting from the error
+ // function's initial point. Its return value (the final objective) was
+ // previously stored in an unused local; it is deliberately discarded here.
+ Timer::Start("logistic_regression_optimization");
+ opt.Optimize(parameters);
+ Timer::Stop("logistic_regression_optimization");
+}
+
+template<template<typename> class OptimizerType>
void LogisticRegression<OptimizerType>::Predict(const arma::mat& predictors,
arma::vec& responses,
const double decisionBoundary)
More information about the mlpack-svn
mailing list