[mlpack-svn] r16083 - mlpack/trunk/src/mlpack/methods/logistic_regression
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Dec 11 19:17:19 EST 2013
Author: rcurtin
Date: Wed Dec 11 19:17:19 2013
New Revision: 16083
Log:
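Remove the predictors and responses members, because the LogisticRegression
class does not need to store them; LearnModel() now takes them as arguments
instead. Also add a constructor that builds a model directly from a given set
of parameters, without performing any training.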
Modified:
mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp
mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression.hpp Wed Dec 11 19:17:19 2013
@@ -65,6 +65,12 @@
*/
LogisticRegression(OptimizerType<LogisticRegressionFunction>& optimizer);
+ /**
+ * Construct a logistic regression model from the given parameters, without
+ * performing any training.
+ */
+ LogisticRegression(const arma::vec& parameters);
+
//! Return the parameters (the b vector).
const arma::vec& Parameters() const { return parameters; }
//! Modify the parameters (the b vector).
@@ -120,10 +126,6 @@
const arma::vec& responses) const;
private:
- //! Matrix of predictor points (X).
- const arma::mat& predictors;
- //! Vector of responses (y).
- const arma::vec& responses;
//! Vector of trained parameters.
arma::vec parameters;
@@ -136,7 +138,7 @@
* Learn the model by optimizing the logistic regression objective function.
* Returns the objective function evaluated when the parameters are optimized.
*/
- double LearnModel();
+ double LearnModel(const arma::mat& predictors, const arma::vec& responses);
};
}; // namespace regression
Modified: mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp Wed Dec 11 19:17:19 2013
@@ -19,14 +19,12 @@
const arma::mat& predictors,
const arma::vec& responses,
const double lambda) :
- predictors(predictors),
- responses(responses),
parameters(arma::zeros<arma::vec>(predictors.n_rows + 1)),
errorFunction(LogisticRegressionFunction(predictors, responses, lambda)),
optimizer(OptimizerType<LogisticRegressionFunction>(errorFunction))
{
// Train the model.
- LearnModel();
+ LearnModel(predictors, responses);
}
template<template<typename> class OptimizerType>
@@ -35,21 +33,17 @@
const arma::vec& responses,
const arma::mat& initialPoint,
const double lambda) :
- predictors(predictors),
- responses(responses),
parameters(arma::zeros<arma::vec>(predictors.n_rows + 1)),
errorFunction(LogisticRegressionFunction(predictors, responses)),
optimizer(OptimizerType<LogisticRegressionFunction>(errorFunction))
{
// Train the model.
- LearnModel();
+ LearnModel(predictors, responses);
}
template<template<typename> class OptimizerType>
LogisticRegression<OptimizerType>::LogisticRegression(
OptimizerType<LogisticRegressionFunction>& optimizer) :
- predictors(optimizer.Function().Predictors()),
- responses(optimizer.Function().Responses()),
parameters(optimizer.Function().GetInitialPoint()),
errorFunction(optimizer.Function()),
optimizer(optimizer)
@@ -60,6 +54,14 @@
}
template<template<typename> class OptimizerType>
+LogisticRegression<OptimizerType>::LogisticRegression(
+ const arma::vec& parameters) :
+ parameters(parameters)
+{
+ // Nothing to do.
+}
+
+template<template<typename> class OptimizerType>
void LogisticRegression<OptimizerType>::Predict(const arma::mat& predictors,
arma::vec& responses,
const double decisionBoundary)
@@ -104,7 +106,9 @@
}
template <template<typename> class OptimizerType>
-double LogisticRegression<OptimizerType>::LearnModel()
+double LogisticRegression<OptimizerType>::LearnModel(
+ const arma::mat& predictors,
+ const arma::vec& responses)
{
Timer::Start("logistic_regression_optimization");
const double out = optimizer.Optimize(parameters);
More information about the mlpack-svn mailing list