[mlpack-git] master: Add --output_probabilities_file option. (dca52fd)

gitdub at mlpack.org gitdub at mlpack.org
Wed Jun 1 14:49:55 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/5546ebcf02598c9da06e19ed447e73ddcd0d3347...dca52fd2ed7a7f44c4fbd7b0f89e4c5bf2337b92

>---------------------------------------------------------------

commit dca52fd2ed7a7f44c4fbd7b0f89e4c5bf2337b92
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Jun 1 11:49:55 2016 -0700

    Add --output_probabilities_file option.


>---------------------------------------------------------------

dca52fd2ed7a7f44c4fbd7b0f89e4c5bf2337b92
 .../logistic_regression_main.cpp                   | 28 +++++++++++++++++-----
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp b/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp
index 4f4de07..d2ffd9d 100644
--- a/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp
+++ b/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp
@@ -94,7 +94,10 @@ PARAM_STRING("output_model_file", "File to save trained logistic regression "
 // Testing.
 PARAM_STRING("test_file", "File containing test dataset.", "T", "");
 PARAM_STRING("output_file", "If --test_file is specified, this file is "
-    "where the predicted responses will be saved.", "o", "");
+    "where the predictions for the test set will be saved.", "o", "");
+PARAM_STRING("output_probabilities_file", "If --test_file is specified, this "
+    "file is where the class probabilities for the test set will be saved.",
+    "p", "");
 PARAM_DOUBLE("decision_boundary", "Decision boundary for prediction; if the "
     "logistic function for a point is less than the boundary, the class is "
     "taken to be 0; otherwise, the class is 1.", "d", 0.5);
@@ -116,6 +119,8 @@ int main(int argc, char** argv)
   const string outputModelFile = CLI::GetParam<string>("output_model_file");
   const string testFile = CLI::GetParam<string>("test_file");
   const string outputFile = CLI::GetParam<string>("output_file");
+  const string outputProbabilitiesFile =
+      CLI::GetParam<string>("output_probabilities_file");
   const double decisionBoundary = CLI::GetParam<double>("decision_boundary");
 
   // One of inputFile and modelFile must be specified.
@@ -260,13 +265,24 @@ int main(int argc, char** argv)
 
     // We must perform predictions on the test set.  Training (and the
     // optimizer) are irrelevant here; we'll pass in the model we have.
-    Log::Info << "Predicting classes of points in '" << testFile << "'."
-        << endl;
-    model.Predict(testSet, predictions, decisionBoundary);
-
-    // Save the results, if necessary.
     if (!outputFile.empty())
+    {
+      Log::Info << "Predicting classes of points in '" << testFile << "'."
+          << endl;
+      model.Classify(testSet, predictions, decisionBoundary);
+
       data::Save(outputFile, predictions, false);
+    }
+
+    if (!outputProbabilitiesFile.empty())
+    {
+      Log::Info << "Calculating class probabilities of points in '" << testFile
+          << "'." << endl;
+      arma::mat probabilities;
+      model.Classify(testSet, probabilities);
+
+      data::Save(outputProbabilitiesFile, probabilities, false);
+    }
   }
 
   if (!outputModelFile.empty())




More information about the mlpack-git mailing list