[mlpack-git] master: Refactor main program, add documentation, slightly improve functionality. (a1dada8)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Nov 19 16:17:33 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/319205b2f3103187c584db302b1a3683aa2fbfdf...a1dada8ba0f88f14653f09ea8c3ec8b04d982434
>---------------------------------------------------------------
commit a1dada8ba0f88f14653f09ea8c3ec8b04d982434
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Nov 19 13:17:10 2015 -0800
Refactor main program, add documentation, slightly improve functionality.
>---------------------------------------------------------------
a1dada8ba0f88f14653f09ea8c3ec8b04d982434
.../softmax_regression/softmax_regression_main.cpp | 126 +++++++++++++++------
1 file changed, 90 insertions(+), 36 deletions(-)
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_main.cpp b/src/mlpack/methods/softmax_regression/softmax_regression_main.cpp
index 992a6ac..51c37ba 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression_main.cpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_main.cpp
@@ -6,8 +6,33 @@
#include <set>
// Define parameters for the executable.
-PROGRAM_INFO("Softmax Regression", "This program performs softmax regression "
- "on the given dataset and able to store the learned parameters.");
+PROGRAM_INFO("Softmax Regression", "This program performs softmax regression, "
+ "a generalization of logistic regression to the multiclass case, and has "
+ "support for L2 regularization. The program is able to train a model, load"
+ " an existing model, and give predictions (and optionally their accuracy) "
+ "for test data."
+ "\n\n"
+ "Training a softmax regression model is done by giving a file of training "
+ "points with --training_file (-t) and their corresponding labels with "
+ "--labels_file (-l). The number of classes can be manually specified with "
+ "the --number_of_classes (-n) option, and the maximum number of iterations "
+ "of the L-BFGS optimizer can be specified with the --max_iterations (-M) "
+ "option. The L2 regularization constant can be specified with --lambda "
+ "(-r), and if an intercept term is not desired in the model, the "
+ "--no_intercept (-N) can be specified."
+ "\n\n"
+ "The trained model can be saved to a file with the --output_model (-m) "
+ "option. If training is not desired, but only testing is, a model can be "
+ "loaded with the --input_model (-i) option. At the current time, a loaded "
+ "model cannot be trained further, so specifying both -i and -t is not "
+ "allowed."
+ "\n\n"
+ "The program is also able to evaluate a model on test data. A test dataset"
+ " can be specified with the --test_data (-T) option. Class predictions "
+ "will be saved in the file specified with the --predictions_file (-p) "
+ "option. If labels are specified for the test data, with the --test_labels"
+ " (-L) option, then the program will print the accuracy of the predictions "
+ "on the given test set and its corresponding labels.");
// Required options.
PARAM_STRING("training_file", "A file containing the training set (the matrix "
@@ -23,6 +48,8 @@ PARAM_STRING("output_model", "File to save trained logistic regression model "
// Testing.
PARAM_STRING("test_data", "File containing test dataset.", "T", "");
+PARAM_STRING("predictions_file", "File to save predictions for test dataset "
+ "into.", "p", "");
PARAM_STRING("test_labels", "File containing test labels.", "L", "");
// Softmax configuration options.
@@ -35,20 +62,23 @@ PARAM_INT("number_of_classes", "Number of classes for classification, "
PARAM_DOUBLE("lambda", "L2-regularization constant", "r", 0.0001);
-PARAM_FLAG("intercept", "Add intercept term, if not specify, "
- "the intercept term will not be added", "t");
+PARAM_FLAG("no_intercept", "Do not add the intercept term to the model.", "N");
using namespace std;
+// Count the number of classes in the given labels (if numClasses == 0).
size_t CalculateNumberOfClasses(const size_t numClasses,
const arma::Row<size_t>& trainLabels);
+// Test the accuracy of the model.
template<typename Model>
void TestPredictAcc(const string& testFile,
+ const string& predictionsFile,
const string& testLabels,
const size_t numClasses,
const Model& model);
+// Build the softmax model given the parameters.
template<typename Model>
std::unique_ptr<Model> TrainSoftmax(const string& trainingFile,
const string& labelFile,
@@ -81,10 +111,14 @@ int main(int argc, char** argv)
<< ")! Must be greater than or equal to 0." << endl;
const string outputModelFile = CLI::GetParam<string>("output_model");
+ const string testLabelsFile = CLI::GetParam<string>("test_labels");
+ const string predictionsFile = CLI::GetParam<string>("predictions_file");
- // Make sure we have an output file if we're not doing the work in-place.
- if (outputModelFile.empty())
- Log::Warn << "--output_model is not set; no results will be saved." << endl;
+ // Make sure we have an output file of some sort.
+ if (outputModelFile.empty() && testLabelsFile.empty() &&
+ predictionsFile.empty())
+ Log::Warn << "None of --output_model, --test_labels, or --predictions_file "
+ << "are set; no results from this program will be saved." << endl;
using SM = regression::SoftmaxRegression<>;
@@ -94,6 +128,7 @@ int main(int argc, char** argv)
maxIterations);
TestPredictAcc(CLI::GetParam<string>("test_data"),
+ CLI::GetParam<string>("predictions_file"),
CLI::GetParam<string>("test_labels"),
sm->NumClasses(), *sm);
@@ -121,6 +156,7 @@ size_t CalculateNumberOfClasses(const size_t numClasses,
template<typename Model>
void TestPredictAcc(const string& testFile,
+ const string& predictionsFile,
const string& testLabelsFile,
size_t numClasses,
const Model& model)
@@ -128,32 +164,49 @@ void TestPredictAcc(const string& testFile,
using namespace mlpack;
// If there is no test set, there is nothing to test on.
- if (testFile.empty() && testLabelsFile.empty())
+ if (testFile.empty() && predictionsFile.empty() && testLabelsFile.empty())
return;
- if ((!testFile.empty() && testLabelsFile.empty()) ||
- (testFile.empty() && !testLabelsFile.empty()))
+ if (!testLabelsFile.empty() && testFile.empty())
{
- Log::Fatal << "--test_file must be specified with --test_labels and vice"
- << " versa." << endl;
+ Log::Warn << "--test_labels specified, but --test_file is not specified."
+ << " The parameter will be ignored." << endl;
+ return;
}
- if (!testFile.empty() && !testLabelsFile.empty())
+ if (!predictionsFile.empty() && testFile.empty())
{
- arma::mat testData;
+ Log::Warn << "--predictions_file specified, but --test_file is not "
+ << "specified. The parameter will be ignored." << endl;
+ return;
+ }
+
+ // Get the test dataset, and get predictions.
+ arma::mat testData;
+ data::Load(testFile, testData, true);
+
+ arma::Row<size_t> predictLabels;
+ model.Predict(testData, predictLabels);
+
+ // Save predictions, if desired.
+ if (!predictionsFile.empty())
+ data::Save(predictionsFile, predictLabels);
+
+ // Calculate accuracy, if desired.
+ if (!testLabelsFile.empty())
+ {
+ arma::Mat<size_t> tmpTestLabels;
arma::Row<size_t> testLabels;
- testData.load(testFile, arma::auto_detect);
- testData = testData.t();
- testLabels.load(testLabelsFile, arma::auto_detect);
+ data::Load(testLabelsFile, tmpTestLabels, true);
+ testLabels = tmpTestLabels.row(0);
- if (testData.n_cols!= testLabels.n_elem)
+ if (testData.n_cols != testLabels.n_elem)
{
- Log::Fatal << "Labels of --test_labels should same as the samples size "
- << "of --test_data " << endl;
+ Log::Fatal << "Test data in --test_data has " << testData.n_cols
+ << " points, but labels in --test_labels have " << testLabels.n_elem
+ << " labels!" << endl;
}
- arma::vec predictLabels;
- model.Predict(testData, predictLabels);
std::vector<size_t> bingoLabels(numClasses, 0);
std::vector<size_t> labelSize(numClasses, 0);
for (arma::uword i = 0; i != predictLabels.n_elem; ++i)
@@ -164,16 +217,19 @@ void TestPredictAcc(const string& testFile,
}
++labelSize[testLabels(i)];
}
+
size_t totalBingo = 0;
- for(size_t i = 0; i != bingoLabels.size(); ++i)
+ for (size_t i = 0; i != bingoLabels.size(); ++i)
{
- Log::Info << "Accuracy of label " << i << " is "
- << (bingoLabels[i] / static_cast<double>(labelSize[i])) << endl;
+ Log::Info << "Accuracy for points with label " << i << " is "
+ << (bingoLabels[i] / static_cast<double>(labelSize[i])) << " ("
+ << bingoLabels[i] << " of " << labelSize[i] << ")." << endl;
totalBingo += bingoLabels[i];
}
- Log::Info << "Total accuracy is "
- << (totalBingo) / static_cast<double>(predictLabels.n_elem)
- << endl;
+
+ Log::Info << "Total accuracy for all points is "
+ << (totalBingo) / static_cast<double>(predictLabels.n_elem) << " ("
+ << totalBingo << " of " << predictLabels.n_elem << ")." << endl;
}
}
@@ -197,24 +253,22 @@ std::unique_ptr<Model> TrainSoftmax(const string& trainingFile,
{
arma::mat trainData;
arma::Row<size_t> trainLabels;
- trainData.load(trainingFile, arma::auto_detect);
- trainData = trainData.t();
- trainLabels.load(labelFile, arma::auto_detect);
+ arma::Mat<size_t> tmpTrainLabels;
//load functions of mlpack do not works on windows, it will complain
//"[FATAL] Unable to detect type of 'softmax_data.txt'; incorrect extension?"
- //data::Load(inputFile, trainData, true);
- //data::Load(labelFile, trainLabels, true);
+ data::Load(trainingFile, trainData, true);
+ data::Load(labelFile, tmpTrainLabels, true);
+ trainLabels = tmpTrainLabels.row(0);
- if(trainData.n_cols != trainLabels.n_elem)
+ if (trainData.n_cols != trainLabels.n_elem)
Log::Fatal << "Samples of input_data should same as the size of "
<< "input_label." << endl;
- //size_t numClasses = CLI::GetParam<int>("number_of_classes");
const size_t numClasses = CalculateNumberOfClasses(
(size_t) CLI::GetParam<int>("number_of_classes"), trainLabels);
- const bool intercept = !CLI::HasParam("intercept") ? false : true;
+ const bool intercept = CLI::HasParam("no_intercept") ? false : true;
SRF smFunction(trainData, trainLabels, numClasses, intercept,
CLI::GetParam<double>("lambda"));
More information about the mlpack-git mailing list