[mlpack-git] master: Refactor logistic_regression to allow persistent models. (e67787e)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Sep 16 14:29:15 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/bbe9cd161571c99aca88096b07de61341711c049...e67787e336136a9e46b2d502bd583b8aea2668a4

>---------------------------------------------------------------

commit e67787e336136a9e46b2d502bd583b8aea2668a4
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Sep 16 05:48:04 2015 +0000

    Refactor logistic_regression to allow persistent models.


>---------------------------------------------------------------

e67787e336136a9e46b2d502bd583b8aea2668a4
 .../logistic_regression_main.cpp                   | 172 ++++++++++-----------
 1 file changed, 78 insertions(+), 94 deletions(-)

diff --git a/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp b/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp
index 691e3c2..bf609b0 100644
--- a/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp
+++ b/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp
@@ -54,66 +54,62 @@ PROGRAM_INFO("L2-regularized Logistic Regression and Prediction",
     "multi-class case but instead only the two-class case.  Any responses must "
     "be either 0 or 1.");
 
-PARAM_STRING("input_file", "File containing X (predictors).", "i", "");
-PARAM_STRING("input_responses", "Optional file containing y (responses).  If "
-    "not given, the responses are assumed to be the last column of the input "
-    "file.", "r", "");
-
-PARAM_STRING("model_file", "File containing existing model (parameters).", "m",
-    "");
+// Training parameters.
+PARAM_STRING("training_file", "A file containing the training set (the matrix "
+    "of predictors, X).", "t", "");
+PARAM_STRING("labels_file", "A file containing labels (0 or 1) for the points "
+    "in the training set (y).", "l", "");
+
+// Optimizer parameters.  (The -T alias belongs to --test_file below.)
+PARAM_DOUBLE("lambda", "L2-regularization parameter for training.", "L", 0.0);
+PARAM_STRING("optimizer", "Optimizer to use for training ('lbfgs' or 'sgd').",
+    "O", "lbfgs");
+PARAM_DOUBLE("tolerance", "Convergence tolerance for optimizer.", "e", 1e-10);
+PARAM_INT("max_iterations", "Maximum iterations for optimizer (0 indicates no "
+    "limit).", "M", 10000);
+PARAM_DOUBLE("step_size", "Step size for SGD optimizer.", "s", 0.01);
 
-PARAM_STRING("output_file", "File where parameters (b) will be saved.", "o",
+// Model loading/saving.
+PARAM_STRING("input_model", "File containing existing model (parameters).", "i",
     "");
+PARAM_STRING("output_model", "File to save trained logistic regression model "
+    "to.", "m", "");
 
-PARAM_STRING("test_file", "File containing test dataset.", "t", "");
-PARAM_STRING("output_predictions", "If --test_file is specified, this file is "
-    "where the predicted responses will be saved.", "p", "predictions.csv");
+// Testing.
+PARAM_STRING("test_file", "File containing test dataset.", "T", "");
+PARAM_STRING("output_file", "If --test_file is specified, this file is "
+    "where the predicted responses will be saved.", "o", "output.csv");
 PARAM_DOUBLE("decision_boundary", "Decision boundary for prediction; if the "
     "logistic function for a point is less than the boundary, the class is "
     "taken to be 0; otherwise, the class is 1.", "d", 0.5);
 
-PARAM_DOUBLE("lambda", "L2-regularization parameter for training.", "l", 0.0);
-PARAM_STRING("optimizer", "Optimizer to use for training ('lbfgs' or 'sgd').",
-    "O", "lbfgs");
-PARAM_DOUBLE("tolerance", "Convergence tolerance for optimizer.", "T", 1e-10);
-PARAM_INT("max_iterations", "Maximum iterations for optimizer (0 indicates no "
-    "limit).", "M", 0);
-PARAM_DOUBLE("step_size", "Step size for SGD optimizer.", "s", 0.01);
-
 int main(int argc, char** argv)
 {
   CLI::ParseCommandLine(argc, argv);
 
   // Collect command-line options.
-  const string inputFile = CLI::GetParam<string>("input_file");
-  const string inputResponsesFile = CLI::GetParam<string>("input_responses");
-  const string modelFile = CLI::GetParam<string>("model_file");
-  const string outputFile = CLI::GetParam<string>("output_file");
-  const string testFile = CLI::GetParam<string>("test_file");
-  const string outputPredictionsFile =
-      CLI::GetParam<string>("output_predictions");
+  const string trainingFile = CLI::GetParam<string>("training_file");
+  const string labelsFile = CLI::GetParam<string>("labels_file");
   const double lambda = CLI::GetParam<double>("lambda");
   const string optimizerType = CLI::GetParam<string>("optimizer");
   const double tolerance = CLI::GetParam<double>("tolerance");
+  const double stepSize = CLI::GetParam<double>("step_size");
   const size_t maxIterations = (size_t) CLI::GetParam<int>("max_iterations");
+  const string inputModelFile = CLI::GetParam<string>("input_model");
+  const string outputModelFile = CLI::GetParam<string>("output_model");
+  const string testFile = CLI::GetParam<string>("test_file");
+  const string outputFile = CLI::GetParam<string>("output_file");
   const double decisionBoundary = CLI::GetParam<double>("decision_boundary");
-  const double stepSize = CLI::GetParam<double>("step_size");
 
-  // One of inputFile and modelFile must be specified.
-  if (inputFile.empty() && modelFile.empty())
-    Log::Fatal << "One of --model_file or --input_file must be specified."
+  // One of trainingFile and inputModelFile must be specified.
+  if (trainingFile.empty() && inputModelFile.empty())
+    Log::Fatal << "One of --input_model or --training_file must be specified."
         << endl;
 
-  // If they want predictions, they should supply a file to save them to.  This
-  // is only a warning because the program can still work.
-  if (!testFile.empty() && outputPredictionsFile.empty())
-    Log::Warn << "--output_predictions not specified; predictions will not be "
-        << "saved." << endl;
-
   // If no output file is given, the user should know that the model will not be
   // saved, but only if a model is being trained.
-  if (outputFile.empty() && !inputFile.empty())
-    Log::Warn << "--output_file not given; trained model will not be saved."
+  if (outputModelFile.empty() && !trainingFile.empty())
+    Log::Warn << "--output_model not given; trained model will not be saved."
         << endl;
 
   // Tolerance needs to be positive.
@@ -139,21 +135,39 @@ int main(int argc, char** argv)
     Log::Fatal << "Step size (--step_size) must be positive (received "
         << stepSize << ")." << endl;
 
+  if (CLI::HasParam("step_size") && optimizerType == "lbfgs")
+    Log::Warn << "Step size (--step_size) ignored because 'sgd' optimizer is "
+        << "not being used." << endl;
+
   // These are the matrices we might use.
   arma::mat regressors;
   arma::Mat<size_t> responses;
-  arma::mat model;
   arma::mat testSet;
   arma::Row<size_t> predictions;
 
-  // Load matrices.
-  if (!inputFile.empty())
-    data::Load(inputFile, regressors, true);
+  // Load the training set; without --labels_file its last row holds labels.
+  if (!trainingFile.empty())
+    data::Load(trainingFile, regressors, true);
+
+  // Load a saved model if given; otherwise zero-initialize the parameters.
+  LogisticRegression<> model(0, 0); // Empty model.
+  if (!inputModelFile.empty())
+  {
+    data::Load(inputModelFile, "logistic_regression_model", model);
+  }
+  else
+  {
+    // Features + 1 intercept term (a label row, if present, is not a feature).
+    if (labelsFile.empty())
+      model.Parameters() = arma::zeros<arma::vec>(regressors.n_rows);
+    else
+      model.Parameters() = arma::zeros<arma::vec>(regressors.n_rows + 1);
+  }
 
   // Check if the responses are in a separate file.
-  if (!inputResponsesFile.empty())
+  if (!labelsFile.empty())
   {
-    data::Load(inputResponsesFile, responses, true);
+    data::Load(labelsFile, responses, true);
     if (responses.n_rows == 1)
       responses = responses.t();
     if (responses.n_rows != regressors.n_cols)
@@ -168,37 +182,26 @@ int main(int argc, char** argv)
     regressors.shed_row(regressors.n_rows - 1);
   }
 
-  if (!testFile.empty())
-    data::Load(testFile, testSet, true);
-  if (!modelFile.empty())
+  // Now, do the training.
+  if (!trainingFile.empty())
   {
-    data::Load(modelFile, model, true);
-    if (model.n_rows == 1)
-      model = model.t();
-    if ((!regressors.empty()) && (model.n_rows != regressors.n_rows + 1))
-      Log::Fatal << "The model (--model) must have dimensionality of one more "
-          << "than the input dataset (the extra dimension is the intercept)."
-          << endl;
-    if ((!testSet.empty()) && (model.n_rows != testSet.n_rows + 1))
-      Log::Fatal << "The model (--model) must have dimensionality of one more "
-          << "than the test dataset (the extra dimension is the intercept)."
-          << endl;
-  }
-
-  if (!regressors.empty())
-  {
-    // We need to train the model.  Prepare the optimizers.
-    arma::Row<size_t> responsesVec = responses.unsafe_col(0).t();
-    LogisticRegressionFunction<> lrf(regressors, responsesVec, lambda);
-    // Set the initial point, if necessary.
-    if (!model.empty())
+    // Labels are held as a column (see above); the objective takes a row
+    // vector, so keep one alive for training.  Also pass --lambda through.
+    arma::Row<size_t> responsesVec = responses.unsafe_col(0).t();
+    LogisticRegressionFunction<> lrf(regressors, responsesVec, lambda);
+    lrf.InitialPoint() = model.Parameters(); // Start from the current model.
+    if (optimizerType == "sgd")
     {
-      lrf.InitialPoint() = model;
-      Log::Info << "Using model from '" << modelFile << "' as initial model "
-          << "for training." << endl;
-    }
+      SGD<LogisticRegressionFunction<>> sgdOpt(lrf);
+      sgdOpt.MaxIterations() = maxIterations;
+      sgdOpt.Tolerance() = tolerance;
+      sgdOpt.StepSize() = stepSize;
+      Log::Info << "Training model with SGD optimizer." << endl;
 
-    if (optimizerType == "lbfgs")
+      // This will train the model.
+      model.Train(sgdOpt);
+    }
+    else if (optimizerType == "lbfgs")
     {
       L_BFGS<LogisticRegressionFunction<>> lbfgsOpt(lrf);
       lbfgsOpt.MaxIterations() = maxIterations;
@@ -206,43 +209,28 @@ int main(int argc, char** argv)
       Log::Info << "Training model with L-BFGS optimizer." << endl;
 
       // This will train the model.
-      LogisticRegression<L_BFGS> lr(lbfgsOpt);
-      // Extract the newly trained model.
-      model = lr.Parameters();
-    }
-    else if (optimizerType == "sgd")
-    {
-      SGD<LogisticRegressionFunction<>> sgdOpt(lrf);
-      sgdOpt.MaxIterations() = maxIterations;
-      sgdOpt.Tolerance() = tolerance;
-      sgdOpt.StepSize() = stepSize;
-      Log::Info << "Training model with SGD optimizer." << endl;
-
-      // This will train the model.
-      LogisticRegression<SGD> lr(sgdOpt);
-      // Extract the newly trained model.
-      model = lr.Parameters();
+      model.Train(lbfgsOpt);
     }
   }
 
-  if (!testSet.empty())
+  if (!testFile.empty())
   {
+    data::Load(testFile, testSet, true);
+
     // We must perform predictions on the test set.  Training (and the
     // optimizer) are irrelevant here; we'll pass in the model we have.
-    LogisticRegression<> lr(model);
-
     Log::Info << "Predicting classes of points in '" << testFile << "'."
         << endl;
-    lr.Predict(testSet, predictions, decisionBoundary);
+    model.Predict(testSet, predictions, decisionBoundary);
 
     // Save the results, if necessary.
-    if (!outputPredictionsFile.empty())
-      data::Save(outputPredictionsFile, predictions, false);
+    if (!outputFile.empty())
+      data::Save(outputFile, predictions, false);
   }
 
-  if (!outputFile.empty())
+  if (!outputModelFile.empty())
   {
-    Log::Info << "Saving model to '" << outputFile << "'." << endl;
-    data::Save(outputFile, model, false);
+    Log::Info << "Saving model to '" << outputModelFile << "'." << endl;
+    data::Save(outputModelFile, "logistic_regression_model", model, false);
   }
 }



More information about the mlpack-git mailing list