[mlpack-git] master: Refactor logistic_regression to allow persistent models. (e67787e)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Sep 16 14:29:15 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/bbe9cd161571c99aca88096b07de61341711c049...e67787e336136a9e46b2d502bd583b8aea2668a4
>---------------------------------------------------------------
commit e67787e336136a9e46b2d502bd583b8aea2668a4
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Sep 16 05:48:04 2015 +0000
Refactor logistic_regression to allow persistent models.
>---------------------------------------------------------------
e67787e336136a9e46b2d502bd583b8aea2668a4
.../logistic_regression_main.cpp | 172 ++++++++++-----------
1 file changed, 78 insertions(+), 94 deletions(-)
diff --git a/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp b/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp
index 691e3c2..bf609b0 100644
--- a/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp
+++ b/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp
@@ -54,66 +54,62 @@ PROGRAM_INFO("L2-regularized Logistic Regression and Prediction",
"multi-class case but instead only the two-class case. Any responses must "
"be either 0 or 1.");
-PARAM_STRING("input_file", "File containing X (predictors).", "i", "");
-PARAM_STRING("input_responses", "Optional file containing y (responses). If "
- "not given, the responses are assumed to be the last column of the input "
- "file.", "r", "");
-
-PARAM_STRING("model_file", "File containing existing model (parameters).", "m",
- "");
+// Training parameters.
+PARAM_STRING("training_file", "A file containing the training set (the matrix "
+ "of predictors, X).", "t", "");
+PARAM_STRING("labels_file", "A file containing labels (0 or 1) for the points "
+ "in the training set (y).", "l", "");
+
+// Optimizer parameters.
+PARAM_DOUBLE("lambda", "L2-regularization parameter for training.", "L", 0.0);
+PARAM_STRING("optimizer", "Optimizer to use for training ('lbfgs' or 'sgd').",
+ "O", "lbfgs");
+PARAM_DOUBLE("tolerance", "Convergence tolerance for optimizer.", "T", 1e-10);
+PARAM_INT("max_iterations", "Maximum iterations for optimizer (0 indicates no "
+ "limit).", "M", 10000);
+PARAM_DOUBLE("step_size", "Step size for SGD optimizer.", "s", 0.01);
-PARAM_STRING("output_file", "File where parameters (b) will be saved.", "o",
+// Model loading/saving.
+PARAM_STRING("input_model", "File containing existing model (parameters).", "i",
"");
+PARAM_STRING("output_model", "File to save trained logistic regression model "
+ "to.", "m", "");
-PARAM_STRING("test_file", "File containing test dataset.", "t", "");
-PARAM_STRING("output_predictions", "If --test_file is specified, this file is "
- "where the predicted responses will be saved.", "p", "predictions.csv");
+// Testing.
+PARAM_STRING("test_file", "File containing test dataset.", "T", "");
+PARAM_STRING("output_file", "If --test_file is specified, this file is "
+ "where the predicted responses will be saved.", "o", "output.csv");
PARAM_DOUBLE("decision_boundary", "Decision boundary for prediction; if the "
"logistic function for a point is less than the boundary, the class is "
"taken to be 0; otherwise, the class is 1.", "d", 0.5);
-PARAM_DOUBLE("lambda", "L2-regularization parameter for training.", "l", 0.0);
-PARAM_STRING("optimizer", "Optimizer to use for training ('lbfgs' or 'sgd').",
- "O", "lbfgs");
-PARAM_DOUBLE("tolerance", "Convergence tolerance for optimizer.", "T", 1e-10);
-PARAM_INT("max_iterations", "Maximum iterations for optimizer (0 indicates no "
- "limit).", "M", 0);
-PARAM_DOUBLE("step_size", "Step size for SGD optimizer.", "s", 0.01);
-
int main(int argc, char** argv)
{
CLI::ParseCommandLine(argc, argv);
// Collect command-line options.
- const string inputFile = CLI::GetParam<string>("input_file");
- const string inputResponsesFile = CLI::GetParam<string>("input_responses");
- const string modelFile = CLI::GetParam<string>("model_file");
- const string outputFile = CLI::GetParam<string>("output_file");
- const string testFile = CLI::GetParam<string>("test_file");
- const string outputPredictionsFile =
- CLI::GetParam<string>("output_predictions");
+ const string trainingFile = CLI::GetParam<string>("training_file");
+ const string labelsFile = CLI::GetParam<string>("labels_file");
const double lambda = CLI::GetParam<double>("lambda");
const string optimizerType = CLI::GetParam<string>("optimizer");
const double tolerance = CLI::GetParam<double>("tolerance");
+ const double stepSize = CLI::GetParam<double>("step_size");
const size_t maxIterations = (size_t) CLI::GetParam<int>("max_iterations");
+ const string inputModelFile = CLI::GetParam<string>("input_model");
+ const string outputModelFile = CLI::GetParam<string>("output_model");
+ const string testFile = CLI::GetParam<string>("test_file");
+ const string outputFile = CLI::GetParam<string>("output_file");
const double decisionBoundary = CLI::GetParam<double>("decision_boundary");
- const double stepSize = CLI::GetParam<double>("step_size");
// One of inputFile and modelFile must be specified.
- if (inputFile.empty() && modelFile.empty())
+ if (trainingFile.empty() && inputModelFile.empty())
Log::Fatal << "One of --model_file or --input_file must be specified."
<< endl;
- // If they want predictions, they should supply a file to save them to. This
- // is only a warning because the program can still work.
- if (!testFile.empty() && outputPredictionsFile.empty())
- Log::Warn << "--output_predictions not specified; predictions will not be "
- << "saved." << endl;
-
// If no output file is given, the user should know that the model will not be
// saved, but only if a model is being trained.
- if (outputFile.empty() && !inputFile.empty())
- Log::Warn << "--output_file not given; trained model will not be saved."
+ if (outputFile.empty() && !trainingFile.empty())
+ Log::Warn << "--output_model not given; trained model will not be saved."
<< endl;
// Tolerance needs to be positive.
@@ -139,21 +135,39 @@ int main(int argc, char** argv)
Log::Fatal << "Step size (--step_size) must be positive (received "
<< stepSize << ")." << endl;
+ if (CLI::HasParam("step_size") && optimizerType == "lbfgs")
+ Log::Warn << "Step size (--step_size) ignored because 'sgd' optimizer is "
+ << "not being used." << endl;
+
// These are the matrices we might use.
arma::mat regressors;
arma::Mat<size_t> responses;
- arma::mat model;
arma::mat testSet;
arma::Row<size_t> predictions;
- // Load matrices.
- if (!inputFile.empty())
- data::Load(inputFile, regressors, true);
+ // Load data matrix.
+ if (!trainingFile.empty())
+ data::Load(trainingFile, regressors, true);
+
+ // Load the model, if necessary.
+ LogisticRegression<> model(0, 0); // Empty model.
+ if (!inputModelFile.empty())
+ {
+ data::Load(inputModelFile, "logistic_regression_model", model);
+ }
+ else
+ {
+ // Set the size of the parameters vector, if necessary.
+ if (labelsFile.empty())
+ model.Parameters() = arma::zeros<arma::vec>(regressors.n_rows);
+ else
+ model.Parameters() = arma::zeros<arma::vec>(regressors.n_rows + 1);
+ }
// Check if the responses are in a separate file.
- if (!inputResponsesFile.empty())
+ if (!labelsFile.empty())
{
- data::Load(inputResponsesFile, responses, true);
+ data::Load(labelsFile, responses, true);
if (responses.n_rows == 1)
responses = responses.t();
if (responses.n_rows != regressors.n_cols)
@@ -168,37 +182,22 @@ int main(int argc, char** argv)
regressors.shed_row(regressors.n_rows - 1);
}
- if (!testFile.empty())
- data::Load(testFile, testSet, true);
- if (!modelFile.empty())
+ // Now, do the training.
+ if (!trainingFile.empty())
{
- data::Load(modelFile, model, true);
- if (model.n_rows == 1)
- model = model.t();
- if ((!regressors.empty()) && (model.n_rows != regressors.n_rows + 1))
- Log::Fatal << "The model (--model) must have dimensionality of one more "
- << "than the input dataset (the extra dimension is the intercept)."
- << endl;
- if ((!testSet.empty()) && (model.n_rows != testSet.n_rows + 1))
- Log::Fatal << "The model (--model) must have dimensionality of one more "
- << "than the test dataset (the extra dimension is the intercept)."
- << endl;
- }
-
- if (!regressors.empty())
- {
- // We need to train the model. Prepare the optimizers.
- arma::Row<size_t> responsesVec = responses.unsafe_col(0).t();
- LogisticRegressionFunction<> lrf(regressors, responsesVec, lambda);
- // Set the initial point, if necessary.
- if (!model.empty())
+ LogisticRegressionFunction<> lrf(regressors, responses, model.Parameters());
+ if (optimizerType == "sgd")
{
- lrf.InitialPoint() = model;
- Log::Info << "Using model from '" << modelFile << "' as initial model "
- << "for training." << endl;
- }
+ SGD<LogisticRegressionFunction<>> sgdOpt(lrf);
+ sgdOpt.MaxIterations() = maxIterations;
+ sgdOpt.Tolerance() = tolerance;
+ sgdOpt.StepSize() = stepSize;
+ Log::Info << "Training model with SGD optimizer." << endl;
- if (optimizerType == "lbfgs")
+ // This will train the model.
+ model.Train(sgdOpt);
+ }
+ else if (optimizerType == "lbfgs")
{
L_BFGS<LogisticRegressionFunction<>> lbfgsOpt(lrf);
lbfgsOpt.MaxIterations() = maxIterations;
@@ -206,43 +205,28 @@ int main(int argc, char** argv)
Log::Info << "Training model with L-BFGS optimizer." << endl;
// This will train the model.
- LogisticRegression<L_BFGS> lr(lbfgsOpt);
- // Extract the newly trained model.
- model = lr.Parameters();
- }
- else if (optimizerType == "sgd")
- {
- SGD<LogisticRegressionFunction<>> sgdOpt(lrf);
- sgdOpt.MaxIterations() = maxIterations;
- sgdOpt.Tolerance() = tolerance;
- sgdOpt.StepSize() = stepSize;
- Log::Info << "Training model with SGD optimizer." << endl;
-
- // This will train the model.
- LogisticRegression<SGD> lr(sgdOpt);
- // Extract the newly trained model.
- model = lr.Parameters();
+ model.Train(lbfgsOpt);
}
}
- if (!testSet.empty())
+ if (!testFile.empty())
{
+ data::Load(testFile, testSet, true);
+
// We must perform predictions on the test set. Training (and the
// optimizer) are irrelevant here; we'll pass in the model we have.
- LogisticRegression<> lr(model);
-
Log::Info << "Predicting classes of points in '" << testFile << "'."
<< endl;
- lr.Predict(testSet, predictions, decisionBoundary);
+ model.Predict(testSet, predictions, decisionBoundary);
// Save the results, if necessary.
- if (!outputPredictionsFile.empty())
- data::Save(outputPredictionsFile, predictions, false);
+ if (!outputFile.empty())
+ data::Save(outputFile, predictions, false);
}
- if (!outputFile.empty())
+ if (!outputModelFile.empty())
{
Log::Info << "Saving model to '" << outputFile << "'." << endl;
- data::Save(outputFile, model, false);
+ data::Save(outputFile, "logistic_regression_model", model, false);
}
}
More information about the mlpack-git mailing list