[mlpack-svn] r16087 - mlpack/trunk/src/mlpack/methods/logistic_regression
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Dec 28 21:05:47 EST 2013
Author: rcurtin
Date: Sat Dec 28 21:05:47 2013
New Revision: 16087
Log:
Don't store the optimizer and error function as class members, since they are
only needed at construction time.
Modified:
mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp
mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp Sat Dec 28 21:05:47 2013
@@ -67,9 +67,14 @@
/**
* Construct a logistic regression model from the given parameters, without
- * performing any training.
+ * performing any training. The lambda parameter is used for the
+ * ComputeAccuracy() and ComputeError() functions; this constructor does not
+ * train the model itself.
+ *
+ * @param parameters Parameters making up the model.
+ * @param lambda L2-regularization penalty parameter.
*/
- LogisticRegression(const arma::vec& parameters);
+ LogisticRegression(const arma::vec& parameters, const double lambda = 0);
//! Return the parameters (the b vector).
const arma::vec& Parameters() const { return parameters; }
@@ -77,9 +82,9 @@
arma::vec& Parameters() { return parameters; }
//! Return the lambda value for L2-regularization.
- const double& Lambda() const { return errorFunction.Lambda(); }
+ const double& Lambda() const { return lambda; }
//! Modify the lambda value for L2-regularization.
- double& Lambda() { return errorFunction().Lambda(); }
+ double& Lambda() { return lambda; }
/**
* Predict the responses to a given set of predictors. The responses will be
@@ -128,17 +133,8 @@
private:
//! Vector of trained parameters.
arma::vec parameters;
-
- //! Instantiated error function that will be optimized.
- LogisticRegressionFunction errorFunction;
- //! Instantiated optimizer.
- OptimizerType<LogisticRegressionFunction> optimizer;
-
- /**
- * Learn the model by optimizing the logistic regression objective function.
- * Returns the objective function evaluated when the parameters are optimized.
- */
- double LearnModel(const arma::mat& predictors, const arma::vec& responses);
+ //! L2-regularization penalty parameter.
+ double lambda;
};
}; // namespace regression
Modified: mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp Sat Dec 28 21:05:47 2013
@@ -20,11 +20,18 @@
const arma::vec& responses,
const double lambda) :
parameters(arma::zeros<arma::vec>(predictors.n_rows + 1)),
- errorFunction(LogisticRegressionFunction(predictors, responses, lambda)),
- optimizer(OptimizerType<LogisticRegressionFunction>(errorFunction))
+ lambda(lambda)
{
+ LogisticRegressionFunction errorFunction(predictors, responses, lambda);
+ OptimizerType<LogisticRegressionFunction> optimizer(errorFunction);
+
// Train the model.
- LearnModel(predictors, responses);
+ Timer::Start("logistic_regression_optimization");
+ const double out = optimizer.Optimize(parameters);
+ Timer::Stop("logistic_regression_optimization");
+
+ Log::Info << "LogisticRegression::LogisticRegression(): final objective of "
+ << "trained model is " << out << "." << std::endl;
}
template<template<typename> class OptimizerType>
@@ -34,29 +41,41 @@
const arma::mat& initialPoint,
const double lambda) :
parameters(arma::zeros<arma::vec>(predictors.n_rows + 1)),
- errorFunction(LogisticRegressionFunction(predictors, responses)),
- optimizer(OptimizerType<LogisticRegressionFunction>(errorFunction))
+ lambda(lambda)
{
+ LogisticRegressionFunction errorFunction(predictors, responses, lambda);
+ errorFunction.InitialPoint() = initialPoint;
+ OptimizerType<LogisticRegressionFunction> optimizer(errorFunction);
+
// Train the model.
- LearnModel(predictors, responses);
+ Timer::Start("logistic_regression_optimization");
+ const double out = optimizer.Optimize(parameters);
+ Timer::Stop("logistic_regression_optimization");
+
+ Log::Info << "LogisticRegression::LogisticRegression(): final objective of "
+ << "trained model is " << out << "." << std::endl;
}
template<template<typename> class OptimizerType>
LogisticRegression<OptimizerType>::LogisticRegression(
OptimizerType<LogisticRegressionFunction>& optimizer) :
parameters(optimizer.Function().GetInitialPoint()),
- errorFunction(optimizer.Function()),
- optimizer(optimizer)
+ lambda(optimizer.Function().Lambda())
{
Timer::Start("logistic_regression_optimization");
const double out = optimizer.Optimize(parameters);
Timer::Stop("logistic_regression_optimization");
+
+ Log::Info << "LogisticRegression::LogisticRegression(): final objective of "
+ << "trained model is " << out << "." << std::endl;
}
template<template<typename> class OptimizerType>
LogisticRegression<OptimizerType>::LogisticRegression(
- const arma::vec& parameters) :
- parameters(parameters)
+ const arma::vec& parameters,
+ const double lambda) :
+ parameters(parameters),
+ lambda(lambda)
{
// Nothing to do.
}
@@ -81,7 +100,7 @@
{
// Construct a new error function.
LogisticRegressionFunction newErrorFunction(predictors, responses,
- errorFunction.Lambda());
+ lambda);
return newErrorFunction.Evaluate(parameters);
}
@@ -105,18 +124,6 @@
return (double) (count * 100) / responses.n_rows;
}
-template <template<typename> class OptimizerType>
-double LogisticRegression<OptimizerType>::LearnModel(
- const arma::mat& predictors,
- const arma::vec& responses)
-{
- Timer::Start("logistic_regression_optimization");
- const double out = optimizer.Optimize(parameters);
- Timer::Stop("logistic_regression_optimization");
-
- return out;
-}
-
}; // namespace regression
}; // namespace mlpack
More information about the mlpack-svn
mailing list