[mlpack-git] master: Add --output_probabilities_file option. (dca52fd)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Jun 1 14:49:55 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/5546ebcf02598c9da06e19ed447e73ddcd0d3347...dca52fd2ed7a7f44c4fbd7b0f89e4c5bf2337b92
>---------------------------------------------------------------
commit dca52fd2ed7a7f44c4fbd7b0f89e4c5bf2337b92
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Jun 1 11:49:55 2016 -0700
Add --output_probabilities_file option.
>---------------------------------------------------------------
dca52fd2ed7a7f44c4fbd7b0f89e4c5bf2337b92
.../logistic_regression_main.cpp | 28 +++++++++++++++++-----
1 file changed, 22 insertions(+), 6 deletions(-)
diff --git a/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp b/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp
index 4f4de07..d2ffd9d 100644
--- a/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp
+++ b/src/mlpack/methods/logistic_regression/logistic_regression_main.cpp
@@ -94,7 +94,10 @@ PARAM_STRING("output_model_file", "File to save trained logistic regression "
// Testing.
PARAM_STRING("test_file", "File containing test dataset.", "T", "");
PARAM_STRING("output_file", "If --test_file is specified, this file is "
- "where the predicted responses will be saved.", "o", "");
+ "where the predictions for the test set will be saved.", "o", "");
+PARAM_STRING("output_probabilities_file", "If --test_file is specified, this "
+ "file is where the class probabilities for the test set will be saved.",
+ "p", "");
PARAM_DOUBLE("decision_boundary", "Decision boundary for prediction; if the "
"logistic function for a point is less than the boundary, the class is "
"taken to be 0; otherwise, the class is 1.", "d", 0.5);
@@ -116,6 +119,8 @@ int main(int argc, char** argv)
const string outputModelFile = CLI::GetParam<string>("output_model_file");
const string testFile = CLI::GetParam<string>("test_file");
const string outputFile = CLI::GetParam<string>("output_file");
+ const string outputProbabilitiesFile =
+ CLI::GetParam<string>("output_probabilities_file");
const double decisionBoundary = CLI::GetParam<double>("decision_boundary");
// One of inputFile and modelFile must be specified.
@@ -260,13 +265,24 @@ int main(int argc, char** argv)
// We must perform predictions on the test set. Training (and the
// optimizer) are irrelevant here; we'll pass in the model we have.
- Log::Info << "Predicting classes of points in '" << testFile << "'."
- << endl;
- model.Predict(testSet, predictions, decisionBoundary);
-
- // Save the results, if necessary.
if (!outputFile.empty())
+ {
+ Log::Info << "Predicting classes of points in '" << testFile << "'."
+ << endl;
+ model.Classify(testSet, predictions, decisionBoundary);
+
data::Save(outputFile, predictions, false);
+ }
+
+ if (!outputProbabilitiesFile.empty())
+ {
+ Log::Info << "Calculating class probabilities of points in '" << testFile
+ << "'." << endl;
+ arma::mat probabilities;
+ model.Classify(testSet, probabilities);
+
+ data::Save(outputProbabilitiesFile, probabilities, false);
+ }
}
if (!outputModelFile.empty())
More information about the mlpack-git mailing list