[mlpack-svn] r11499 - mlpack/trunk/src/mlpack/methods/linear_regression

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Feb 13 17:43:36 EST 2012


Author: jcline3
Date: 2012-02-13 17:43:36 -0500 (Mon, 13 Feb 2012)
New Revision: 11499

Modified:
   mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
   mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
Log:
Various improvements to LinearRegression and our executable.

Now supports --model_file (-m) to make predictions with a previously computed
model. Also validates user input to make sure a valid combination of options
is provided.


Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp	2012-02-13 22:40:14 UTC (rev 11498)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp	2012-02-13 22:43:36 UTC (rev 11499)
@@ -66,6 +66,11 @@
 
   // We want to be sure we have the correct number of dimensions in the dataset.
   Log::Assert(nRows == parameters.n_rows - 1);
+  if (nRows != parameters.n_rows -1)
+  {
+    Log::Fatal << "The test data must have the same number of columns as the "
+        "training file.\n";
+  }
 
   predictions.zeros(nCols);
   // We set all the predictions to the intercept value initially.

Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp	2012-02-13 22:40:14 UTC (rev 11498)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp	2012-02-13 22:43:36 UTC (rev 11499)
@@ -21,11 +21,14 @@
     "and these predicted responses, y', are saved to a file "
     "(--output_predictions).");
 
-PARAM_STRING_REQ("input_file", "File containing X (regressors).", "i");
+PARAM_STRING("input_file", "File containing X (regressors).", "i", "");
 PARAM_STRING("input_responses", "Optional file containing y (responses). If "
     "not given, the responses are assumed to be the last row of the input "
     "file.", "r", "");
 
+PARAM_STRING("model_file", "File containing existing model (parameters).", "m",
+    "");
+
 PARAM_STRING("output_file", "File where parameters (b) will be saved.",
     "o", "parameters.csv");
 
@@ -43,58 +46,110 @@
   // Handle parameters
   CLI::ParseCommandLine(argc, argv);
 
-  const string trainName = CLI::GetParam<string>("input_file");
-  const string testName = CLI::GetParam<string>("test_file");
-  const string responseName = CLI::GetParam<string>("input_responses");
+  const string modelName = CLI::GetParam<string>("model_file");
   const string outputFile = CLI::GetParam<string>("output_file");
   const string outputPredictions = CLI::GetParam<string>("output_predictions");
+  const string responseName = CLI::GetParam<string>("input_responses");
+  const string testName = CLI::GetParam<string>("test_file");
+  const string trainName = CLI::GetParam<string>("input_file");
 
   mat regressors;
   mat responses;
-  data::Load(trainName.c_str(), regressors, true);
 
-  // Are the responses in a separate file?
-  if (responseName == "")
+  LinearRegression lr;
+
+  bool computeModel;
+
+  if (trainName.empty())
   {
-    // The initial predictors for y, Nx1
-    responses = trans(regressors.row(regressors.n_rows - 1));
-    regressors.shed_row(regressors.n_rows - 1);
+    if (modelName.empty())
+    {
+      Log::Fatal << "You must specify either --input_file or --model_file." << std::endl;
+      exit(1);
+    }
+    else
+    {
+      computeModel = false;
+    }
   }
+  else if (modelName.empty())
+  {
+    computeModel = true;
+  }
   else
   {
-    // The initial predictors for y, Nx1
-    data::Load(responseName.c_str(), responses, true);
+      Log::Fatal << "You must specify either --input_file or --model_file, not both." << std::endl;
+      exit(1);
+  }
 
-    if (responses.n_rows == 1)
-      responses = trans(responses); // Probably loaded backwards, but that's ok.
+  if(!computeModel && testName.empty())
+  {
+    Log::Fatal << "When specifying --model_file, you must also specify --test_file." << std::endl;
+    exit(1);
+  }
 
-    if (responses.n_cols > 1)
-      Log::Fatal << "The responses must have one column.\n";
+  if (computeModel)
+  {
+    Timer::Start("load_regressors");
+    data::Load(trainName.c_str(), regressors, true);
+    Timer::Stop("load_regressors");
 
-    if (responses.n_rows != regressors.n_cols)
-      Log::Fatal << "The responses must have the same number of rows as the "
-          "training file.\n";
-  }
+    // Are the responses in a separate file?
+    if (responseName.empty())
+    {
+      // The initial predictors for y, Nx1
+      responses = trans(regressors.row(regressors.n_rows - 1));
+      regressors.shed_row(regressors.n_rows - 1);
+    }
+    else
+    {
+      // The initial predictors for y, Nx1
+      Timer::Start("load_responses");
+      data::Load(responseName.c_str(), responses, true);
+      Timer::Stop("load_responses");
 
-  LinearRegression lr(regressors, responses.unsafe_col(0));
+      if (responses.n_rows == 1)
+        responses = trans(responses); // Probably loaded backwards, but that's ok.
 
-  // Save the parameters.
-  data::Save(outputFile.c_str(), lr.Parameters(), false);
+      if (responses.n_cols > 1)
+        Log::Fatal << "The responses must have one column.\n";
 
+      if (responses.n_rows != regressors.n_cols)
+        Log::Fatal << "The responses must have the same number of rows as the "
+            "training file.\n";
+    }
+
+    Timer::Start("regression");
+    lr = LinearRegression(regressors, responses.unsafe_col(0));
+    Timer::Stop("regression");
+
+    // Save the parameters.
+    data::Save(outputFile.c_str(), lr.Parameters(), true);
+  }
+
   // Did we want to predict, too?
-  if (testName != "")
+  if (!testName.empty() )
   {
+
+    if (!computeModel)
+    {
+      Timer::Start("load_model");
+      lr = LinearRegression(modelName);
+      Timer::Stop("load_model");
+    }
+
     arma::mat points;
+    Timer::Stop("load_test_points");
     data::Load(testName.c_str(), points, true);
+    Timer::Stop("load_test_points");
 
-    if (points.n_rows != regressors.n_rows)
-      Log::Fatal << "The test data must have the same number of columns as the "
-          "training file.\n";
-
     arma::vec predictions;
+    Timer::Start("prediction");
     lr.Predict(points, predictions);
+    Timer::Stop("prediction");
 
     // Save predictions.
-    data::Save(outputPredictions.c_str(), predictions, false);
+    predictions = arma::trans(predictions);
+    data::Save(outputPredictions.c_str(), predictions, true);
   }
 }




More information about the mlpack-svn mailing list