[mlpack-git] master: Refactor main executable. (c8cdc77)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Oct 2 19:20:42 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/7a8b0e1292677b71888fad313772c63bcf0e7b80...de88672879a1893ebfc131538c64e7755251337c

>---------------------------------------------------------------

commit c8cdc77478746382a57f7fa8e3365f0176fdf616
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Sep 30 17:25:43 2015 +0000

    Refactor main executable.


>---------------------------------------------------------------

c8cdc77478746382a57f7fa8e3365f0176fdf616
 .../linear_regression/linear_regression_main.cpp   | 39 ++++++++++++----------
 1 file changed, 22 insertions(+), 17 deletions(-)

diff --git a/src/mlpack/methods/linear_regression/linear_regression_main.cpp b/src/mlpack/methods/linear_regression/linear_regression_main.cpp
index 4cdc9c8..9f00850 100644
--- a/src/mlpack/methods/linear_regression/linear_regression_main.cpp
+++ b/src/mlpack/methods/linear_regression/linear_regression_main.cpp
@@ -25,18 +25,17 @@ PROGRAM_INFO("Simple Linear Regression and Prediction",
     "(--output_predictions).  This type of regression is related to least-angle"
     " regression, which mlpack implements with the 'lars' executable.");
 
-PARAM_STRING("input_file", "File containing X (regressors).", "i", "");
-PARAM_STRING("input_responses", "Optional file containing y (responses). If "
+PARAM_STRING("training_file", "File containing training set X (regressors).",
+    "t", "");
+PARAM_STRING("training_responses", "Optional file containing y (responses). If "
     "not given, the responses are assumed to be the last row of the input "
     "file.", "r", "");
 
-PARAM_STRING("model_file", "File containing existing model (parameters).", "m",
-    "");
+PARAM_STRING("input_model_file", "File containing existing model (parameters).",
+    "m", "");
+PARAM_STRING("output_model_file", "File to save trained model to.", "M", "");
 
-PARAM_STRING("output_file", "File where parameters (b) will be saved.",
-    "o", "parameters.csv");
-
-PARAM_STRING("test_file", "File containing X' (test regressors).", "t", "");
+PARAM_STRING("test_file", "File containing X' (test regressors).", "T", "");
 PARAM_STRING("output_predictions", "If --test_file is specified, this file is "
     "where the predicted responses will be saved.", "p", "predictions.csv");
 
@@ -50,15 +49,15 @@ using namespace std;
 
 int main(int argc, char* argv[])
 {
-  // Handle parameters
+  // Handle parameters.
   CLI::ParseCommandLine(argc, argv);
 
-  const string modelName = CLI::GetParam<string>("model_file");
-  const string outputFile = CLI::GetParam<string>("output_file");
+  const string inputModelFile = CLI::GetParam<string>("input_model_file");
+  const string outputModelFile = CLI::GetParam<string>("output_model_file");
   const string outputPredictions = CLI::GetParam<string>("output_predictions");
-  const string responseName = CLI::GetParam<string>("input_responses");
+  const string responseName = CLI::GetParam<string>("training_responses");
   const string testName = CLI::GetParam<string>("test_file");
-  const string trainName = CLI::GetParam<string>("input_file");
+  const string trainName = CLI::GetParam<string>("training_file");
   const double lambda = CLI::GetParam<double>("lambda");
 
   mat regressors;
@@ -72,14 +71,14 @@ int main(int argc, char* argv[])
   // We want to determine if an input file XOR model file were given.
   if (trainName.empty()) // The user specified no input file.
   {
-    if (modelName.empty()) // The user specified no model file, error and exit.
+    if (inputModelFile.empty()) // The user specified no model file; error.
       Log::Fatal << "You must specify either --input_file or --model_file."
           << endl;
     else // The model file was specified, no problems.
       computeModel = false;
   }
   // The user specified an input file but no model file, no problems.
-  else if (modelName.empty())
+  else if (inputModelFile.empty())
     computeModel = true;
   // The user specified both an input file and model file.
   // This is ambiguous -- which model should we use? A generated one or given
@@ -98,6 +97,11 @@ int main(int argc, char* argv[])
         << "--test_file." << endl;
   }
 
+  if (!computeModel && CLI::HasParam("lambda"))
+  {
+    Log::Warn << "--lambda ignored because no model is being trained." << endl;
+  }
+
   // An input file was given and we need to generate the model.
   if (computeModel)
   {
@@ -135,7 +139,8 @@ int main(int argc, char* argv[])
     Timer::Stop("regression");
 
     // Save the parameters.
-    data::Save(outputFile, lr.Parameters(), true);
+    if (!outputModelFile.empty())
+      data::Save(outputModelFile, "linearRegressionModel", lr);
   }
 
   // Did we want to predict, too?
@@ -145,7 +150,7 @@ int main(int argc, char* argv[])
     if (!computeModel)
     {
       Timer::Start("load_model");
-      //lr = LinearRegression(modelName);
+      data::Load(inputModelFile, "linearRegressionModel", lr, true);
       Timer::Stop("load_model");
     }
 



More information about the mlpack-git mailing list