[mlpack-git] master: Refactor main program, add documentation, slightly improve functionality. (a1dada8)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Nov 19 16:17:33 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/319205b2f3103187c584db302b1a3683aa2fbfdf...a1dada8ba0f88f14653f09ea8c3ec8b04d982434

>---------------------------------------------------------------

commit a1dada8ba0f88f14653f09ea8c3ec8b04d982434
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Nov 19 13:17:10 2015 -0800

    Refactor main program, add documentation, slightly improve functionality.


>---------------------------------------------------------------

a1dada8ba0f88f14653f09ea8c3ec8b04d982434
 .../softmax_regression/softmax_regression_main.cpp | 126 +++++++++++++++------
 1 file changed, 90 insertions(+), 36 deletions(-)

diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_main.cpp b/src/mlpack/methods/softmax_regression/softmax_regression_main.cpp
index 992a6ac..51c37ba 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression_main.cpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_main.cpp
@@ -6,8 +6,33 @@
 #include <set>
 
 // Define parameters for the executable.
-PROGRAM_INFO("Softmax Regression", "This program performs softmax regression "
-    "on the given dataset and able to store the learned parameters.");
+PROGRAM_INFO("Softmax Regression", "This program performs softmax regression, "
+    "a generalization of logistic regression to the multiclass case, and has "
+    "support for L2 regularization.  The program is able to train a model, load"
+    " an existing model, and give predictions (and optionally their accuracy) "
+    "for test data."
+    "\n\n"
+    "Training a softmax regression model is done by giving a file of training "
+    "points with --training_file (-t) and their corresponding labels with "
+    "--labels_file (-l).  The number of classes can be manually specified with "
+    "the --number_of_classes (-n) option, and the maximum number of iterations "
+    "of the L-BFGS optimizer can be specified with the --max_iterations (-M) "
+    "option.  The L2 regularization constant can be specified with --lambda "
+    "(-r), and if an intercept term is not desired in the model, the "
+    "--no_intercept (-N) option can be specified."
+    "\n\n"
+    "The trained model can be saved to a file with the --output_model (-m) "
+    "option.  If training is not desired, but only testing is, a model can be "
+    "loaded with the --input_model (-i) option.  At the current time, a loaded "
+    "model cannot be trained further, so specifying both -i and -t is not "
+    "allowed."
+    "\n\n"
+    "The program is also able to evaluate a model on test data.  A test dataset"
+    " can be specified with the --test_data (-T) option.  Class predictions "
+    "will be saved in the file specified with the --predictions_file (-p) "
+    "option.  If labels are specified for the test data, with the --test_labels"
+    " (-L) option, then the program will print the accuracy of the predictions "
+    "on the given test set and its corresponding labels.");
 
 // Required options.
 PARAM_STRING("training_file", "A file containing the training set (the matrix "
@@ -23,6 +48,8 @@ PARAM_STRING("output_model", "File to save trained logistic regression model "
 
 // Testing.
 PARAM_STRING("test_data", "File containing test dataset.", "T", "");
+PARAM_STRING("predictions_file", "File to save predictions for test dataset "
+    "into.", "p", "");
 PARAM_STRING("test_labels", "File containing test labels.", "L", "");
 
 // Softmax configuration options.
@@ -35,20 +62,23 @@ PARAM_INT("number_of_classes", "Number of classes for classification, "
 
 PARAM_DOUBLE("lambda", "L2-regularization constant", "r", 0.0001);
 
-PARAM_FLAG("intercept", "Add intercept term, if not specify, "
-           "the intercept term will not be added", "t");
+PARAM_FLAG("no_intercept", "Do not add the intercept term to the model.", "N");
 
 using namespace std;
 
+// Count the number of classes in the given labels (if numClasses == 0).
 size_t CalculateNumberOfClasses(const size_t numClasses,
                                 const arma::Row<size_t>& trainLabels);
 
+// Test the accuracy of the model.
 template<typename Model>
 void TestPredictAcc(const string& testFile,
+                    const string& predictionsFile,
                     const string& testLabels,
                     const size_t numClasses,
                     const Model& model);
 
+// Build the softmax model given the parameters.
 template<typename Model>
 std::unique_ptr<Model> TrainSoftmax(const string& trainingFile,
                                     const string& labelFile,
@@ -81,10 +111,14 @@ int main(int argc, char** argv)
         << ")! Must be greater than or equal to 0." << endl;
 
   const string outputModelFile = CLI::GetParam<string>("output_model");
+  const string testLabelsFile = CLI::GetParam<string>("test_labels");
+  const string predictionsFile = CLI::GetParam<string>("predictions_file");
 
-  // Make sure we have an output file if we're not doing the work in-place.
-  if (outputModelFile.empty())
-    Log::Warn << "--output_model is not set; no results will be saved." << endl;
+  // Make sure we have an output file of some sort.
+  if (outputModelFile.empty() && testLabelsFile.empty() &&
+      predictionsFile.empty())
+    Log::Warn << "None of --output_model, --test_labels, or --predictions_file "
+        << "are set; no results from this program will be saved." << endl;
 
 
   using SM = regression::SoftmaxRegression<>;
@@ -94,6 +128,7 @@ int main(int argc, char** argv)
                                             maxIterations);
 
   TestPredictAcc(CLI::GetParam<string>("test_data"),
-                 CLI::GetParam<string>("test_labels"),
+                 predictionsFile,
+                 testLabelsFile,
                  sm->NumClasses(), *sm);
 
@@ -121,6 +156,7 @@ size_t CalculateNumberOfClasses(const size_t numClasses,
 
 template<typename Model>
 void TestPredictAcc(const string& testFile,
+                    const string& predictionsFile,
                     const string& testLabelsFile,
                     size_t numClasses,
                     const Model& model)
@@ -128,32 +164,49 @@ void TestPredictAcc(const string& testFile,
   using namespace mlpack;
 
   // If there is no test set, there is nothing to test on.
-  if (testFile.empty() && testLabelsFile.empty())
+  if (testFile.empty() && predictionsFile.empty() && testLabelsFile.empty())
     return;
 
-  if ((!testFile.empty() && testLabelsFile.empty()) ||
-       (testFile.empty() && !testLabelsFile.empty()))
+  if (!testLabelsFile.empty() && testFile.empty())
   {
-    Log::Fatal << "--test_file must be specified with --test_labels and vice"
-        << " versa." << endl;
+    Log::Warn << "--test_labels specified, but --test_data is not specified."
+        << "  The parameter will be ignored." << endl;
+    return;
   }
 
-  if (!testFile.empty() && !testLabelsFile.empty())
+  if (!predictionsFile.empty() && testFile.empty())
   {
-    arma::mat testData;
+    Log::Warn << "--predictions_file specified, but --test_data is not "
+        << "specified.  The parameter will be ignored." << endl;
+    return;
+  }
+
+  // Get the test dataset, and get predictions.
+  arma::mat testData;
+  data::Load(testFile, testData, true);
+
+  arma::Row<size_t> predictLabels;
+  model.Predict(testData, predictLabels);
+
+  // Save predictions, if desired.
+  if (!predictionsFile.empty())
+    data::Save(predictionsFile, predictLabels);
+
+  // Calculate accuracy, if desired.
+  if (!testLabelsFile.empty())
+  {
+    arma::Mat<size_t> tmpTestLabels;
     arma::Row<size_t> testLabels;
-    testData.load(testFile, arma::auto_detect);
-    testData = testData.t();
-    testLabels.load(testLabelsFile, arma::auto_detect);
+    data::Load(testLabelsFile, tmpTestLabels, true);
+    testLabels = tmpTestLabels.row(0);
 
-    if (testData.n_cols!= testLabels.n_elem)
+    if (testData.n_cols != testLabels.n_elem)
     {
-      Log::Fatal << "Labels of --test_labels should same as the samples size "
-          << "of --test_data " << endl;
+      Log::Fatal << "Test data in --test_data has " << testData.n_cols
+          << " points, but labels in --test_labels have " << testLabels.n_elem
+          << " labels!" << endl;
     }
 
-    arma::vec predictLabels;
-    model.Predict(testData, predictLabels);
     std::vector<size_t> bingoLabels(numClasses, 0);
     std::vector<size_t> labelSize(numClasses, 0);
     for (arma::uword i = 0; i != predictLabels.n_elem; ++i)
@@ -164,16 +217,19 @@ void TestPredictAcc(const string& testFile,
       }
       ++labelSize[testLabels(i)];
     }
+
     size_t totalBingo = 0;
-    for(size_t i = 0; i != bingoLabels.size(); ++i)
+    for (size_t i = 0; i != bingoLabels.size(); ++i)
     {
-      Log::Info << "Accuracy of label " << i << " is "
-          << (bingoLabels[i] / static_cast<double>(labelSize[i])) << endl;
+      Log::Info << "Accuracy for points with label " << i << " is "
+          << (bingoLabels[i] / static_cast<double>(labelSize[i])) << " ("
+          << bingoLabels[i] << " of " << labelSize[i] << ")." << endl;
       totalBingo += bingoLabels[i];
     }
-    Log::Info << "Total accuracy is "
-        << (totalBingo) / static_cast<double>(predictLabels.n_elem)
-        << endl;
+
+    Log::Info << "Total accuracy for all points is "
+        << (totalBingo) / static_cast<double>(predictLabels.n_elem) << " ("
+        << totalBingo << " of " << predictLabels.n_elem << ")." << endl;
   }
 }
 
@@ -197,24 +253,22 @@ std::unique_ptr<Model> TrainSoftmax(const string& trainingFile,
   {
     arma::mat trainData;
     arma::Row<size_t> trainLabels;
-    trainData.load(trainingFile, arma::auto_detect);
-    trainData = trainData.t();
-    trainLabels.load(labelFile, arma::auto_detect);
+    arma::Mat<size_t> tmpTrainLabels;
 
-    //load functions of mlpack do not works on windows, it will complain
-    //"[FATAL] Unable to detect type of 'softmax_data.txt'; incorrect extension?"
-    //data::Load(inputFile, trainData, true);
-    //data::Load(labelFile, trainLabels, true);
+    // Load the training points and their labels.  The labels are read into a
+    // matrix, and only its first row is kept.
+    data::Load(trainingFile, trainData, true);
+    data::Load(labelFile, tmpTrainLabels, true);
+    trainLabels = tmpTrainLabels.row(0);
 
-    if(trainData.n_cols != trainLabels.n_elem)
-      Log::Fatal << "Samples of input_data should same as the size of "
-          << "input_label." << endl;
+    if (trainData.n_cols != trainLabels.n_elem)
+      Log::Fatal << "--training_file has " << trainData.n_cols << " points but "
+          << "--labels_file has " << trainLabels.n_elem << " labels!" << endl;
 
-    //size_t numClasses = CLI::GetParam<int>("number_of_classes");
     const size_t numClasses = CalculateNumberOfClasses(
         (size_t) CLI::GetParam<int>("number_of_classes"), trainLabels);
 
-    const bool intercept = !CLI::HasParam("intercept") ? false : true;
+    const bool intercept = !CLI::HasParam("no_intercept");
 
     SRF smFunction(trainData, trainLabels, numClasses, intercept,
         CLI::GetParam<double>("lambda"));



More information about the mlpack-git mailing list