[mlpack-git] master: Add option to predict values on test points. (2849148)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Aug 19 12:46:51 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/7d64cd61aade77d4607aacf00b36f94369fd7cf9...284914879d42af40b028a7f5a002f4cd82c5fe05

>---------------------------------------------------------------

commit 284914879d42af40b028a7f5a002f4cd82c5fe05
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Aug 19 12:46:36 2015 -0400

    Add option to predict values on test points.


>---------------------------------------------------------------

284914879d42af40b028a7f5a002f4cd82c5fe05
 src/mlpack/methods/lars/lars_main.cpp | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/src/mlpack/methods/lars/lars_main.cpp b/src/mlpack/methods/lars/lars_main.cpp
index f473c9e..5875cdf 100644
--- a/src/mlpack/methods/lars/lars_main.cpp
+++ b/src/mlpack/methods/lars/lars_main.cpp
@@ -37,6 +37,11 @@ PARAM_STRING_REQ("responses_file", "File containing y "
 PARAM_STRING("output_file", "File to save beta (linear estimator) to.", "o",
     "output.csv");
 
+PARAM_STRING("test_file", "File containing points to regress on (test points).",
+    "t", "");
+PARAM_STRING("output_predictions", "If --test_file is specified, this file is "
+    "where the predicted responses will be saved.", "p", "predictions.csv");
+
 PARAM_DOUBLE("lambda1", "Regularization parameter for l1-norm penalty.", "l",
     0);
 PARAM_DOUBLE("lambda2", "Regularization parameter for l2-norm penalty.", "L",
@@ -88,4 +93,23 @@ int main(int argc, char* argv[])
 
   const string betaFilename = CLI::GetParam<string>("output_file");
   beta.save(betaFilename, raw_ascii);
+
+  if (CLI::HasParam("test_file"))
+  {
+    Log::Info << "Regressing on test points." << endl;
+    const string testFile = CLI::GetParam<string>("test_file");
+    const string outputPredictionsFile =
+        CLI::GetParam<string>("output_predictions");
+
+    // Load test points.
+    mat testPoints;
+    data::Load(testFile, testPoints, true, false);
+
+    arma::vec predictions;
+    lars.Predict(testPoints.t(), predictions, false);
+
+    // Save test predictions.  One per line, so, we need a rowvec.
+    arma::rowvec predToSave = predictions.t();
+    data::Save(outputPredictionsFile, predToSave);
+  }
 }



More information about the mlpack-git mailing list