[mlpack-git] master: Add option to predict values on test points. (2849148)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Aug 19 12:46:51 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/7d64cd61aade77d4607aacf00b36f94369fd7cf9...284914879d42af40b028a7f5a002f4cd82c5fe05
>---------------------------------------------------------------
commit 284914879d42af40b028a7f5a002f4cd82c5fe05
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Aug 19 12:46:36 2015 -0400
Add option to predict values on test points.
>---------------------------------------------------------------
284914879d42af40b028a7f5a002f4cd82c5fe05
src/mlpack/methods/lars/lars_main.cpp | 24 ++++++++++++++++++++++++
1 file changed, 24 insertions(+)
diff --git a/src/mlpack/methods/lars/lars_main.cpp b/src/mlpack/methods/lars/lars_main.cpp
index f473c9e..5875cdf 100644
--- a/src/mlpack/methods/lars/lars_main.cpp
+++ b/src/mlpack/methods/lars/lars_main.cpp
@@ -37,6 +37,11 @@ PARAM_STRING_REQ("responses_file", "File containing y "
PARAM_STRING("output_file", "File to save beta (linear estimator) to.", "o",
"output.csv");
+PARAM_STRING("test_file", "File containing points to regress on (test points).",
+ "t", "");
+PARAM_STRING("output_predictions", "If --test_file is specified, this file is "
+ "where the predicted responses will be saved.", "p", "predictions.csv");
+
PARAM_DOUBLE("lambda1", "Regularization parameter for l1-norm penalty.", "l",
0);
PARAM_DOUBLE("lambda2", "Regularization parameter for l2-norm penalty.", "L",
@@ -88,4 +93,23 @@ int main(int argc, char* argv[])
const string betaFilename = CLI::GetParam<string>("output_file");
beta.save(betaFilename, raw_ascii);
+
+ if (CLI::HasParam("test_file"))
+ {
+ Log::Info << "Regressing on test points." << endl;
+ const string testFile = CLI::GetParam<string>("test_file");
+ const string outputPredictionsFile =
+ CLI::GetParam<string>("output_predictions");
+
+ // Load test points.
+ mat testPoints;
+ data::Load(testFile, testPoints, true, false);
+
+ arma::vec predictions;
+ lars.Predict(testPoints.t(), predictions, false);
+
+ // Save test predictions. One per line, so, we need a rowvec.
+ arma::rowvec predToSave = predictions.t();
+ data::Save(outputPredictionsFile, predToSave);
+ }
}
More information about the mlpack-git mailing list