[mlpack-git] master: Refactor decision_stump to allow saving/loading of models. (07888b3)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Nov 30 17:24:29 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/10b9d45b806a3e879b0564d78ccb183ebc7051ba...31c557d9cc7e4da57fd8a246085c19e076d12271

>---------------------------------------------------------------

commit 07888b3c0bf8c872ee329e5e3a2a9782e23e00e1
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sat Nov 21 02:39:39 2015 +0000

    Refactor decision_stump to allow saving/loading of models.


>---------------------------------------------------------------

07888b3c0bf8c872ee329e5e3a2a9782e23e00e1
 .../methods/decision_stump/decision_stump_main.cpp | 224 ++++++++++++++-------
 1 file changed, 156 insertions(+), 68 deletions(-)

diff --git a/src/mlpack/methods/decision_stump/decision_stump_main.cpp b/src/mlpack/methods/decision_stump/decision_stump_main.cpp
index 48ad4e3..c01394d 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_main.cpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_main.cpp
@@ -18,96 +18,184 @@ PROGRAM_INFO("Decision Stump",
     "and will split into multiple buckets.  The dimension and bins are selected"
     " by maximizing the information gain of the split.  Optionally, the minimum"
     " number of training points in each bin can be specified with the "
-    "--bin_size (-b) parameter.\n"
+    "--bucket_size (-b) parameter.\n"
     "\n"
     "The decision stump is parameterized by a splitting dimension and a vector "
     "of values that denote the splitting values of each bin.\n"
     "\n"
-    "This program allows training of a decision stump, and then application of "
-    "the learned decision stump to a test dataset.  To train a decision stump, "
-    "a training dataset must be passed to --train_file (-t).  Labels can either"
-    " be present as the last dimension of the training dataset, or given "
-    "explicitly with the --labels_file (-l) parameter.\n"
+    "This program enables several applications: a decision tree may be trained "
+    "or loaded, and then that decision tree may be used to classify a given set"
+    " of test points.  The decision tree may also be saved to a file for later "
+    "usage.\n"
     "\n"
-    "A test file is given through the --test_file (-T) parameter.  The "
-    "predicted labels for the test set will be stored in the file specified by "
-    "the --output_file (-o) parameter.");
-
-// Necessary parameters.
-PARAM_STRING_REQ("train_file", "A file containing the training set.", "t");
-PARAM_STRING_REQ("test_file", "A file containing the test set.", "T");
+    "To train a decision stump, training data should be passed with the "
+    "--training_file (-t) option, and their corresponding labels should be "
+    "passed with the --labels_file (-l) option.  Optionally, if --labels_file "
+    "is not specified, the labels are assumed to be the last dimension of the "
+    "training dataset.  The --bucket_size (-b) parameter controls the minimum "
+    "number of training points in each decision stump bucket.\n"
+    "\n"
+    "For classifying a test set, a decision stump may be loaded with the "
+    "--input_model_file (-m) parameter (useful for the situation where a "
+    "stump has not just been trained), and a test set may be specified with the"
+    " --test_file (-T) parameter.  The predicted labels will be saved to the "
+    "file specified with the --predictions_file (-p) parameter.\n"
+    "\n"
+    "Because decision stumps are trained in batch, retraining does not make "
+    "sense and thus it is not possible to pass both --training_file and "
+    "--input_model_file; instead, simply build a new decision stump with the "
+    "training data.\n"
+    "\n"
+    "A trained decision stump can be saved with the --output_model_file (-M) "
+    "option.  That stump may later be re-used in subsequent calls to this "
+    "program (or others).");
 
-// Output parameters (optional).
+// Datasets we might load.
+PARAM_STRING("training_file", "A file containing the training set.", "t", "");
 PARAM_STRING("labels_file", "A file containing labels for the training set. If "
     "not specified, the labels are assumed to be the last row of the training "
     "data.", "l", "");
-PARAM_STRING("output_file", "The file in which the predicted labels for the "
-    "test set will be written.", "o", "output.csv");
+PARAM_STRING("test_file", "A file containing the test set.", "T", "");
+
+// Output.
+PARAM_STRING("predictions_file", "The file in which the predicted labels for "
+    "the test set will be written.", "p", "predictions.csv");
+
+// We may load or save a model.
+PARAM_STRING("input_model_file", "File containing decision stump model to "
+    "load.", "m", "");
+PARAM_STRING("output_model_file", "File to save trained decision stump model "
+    "to.", "M", "");
+
+PARAM_INT("bucket_size", "The minimum number of training points in each "
+    "decision stump bucket.", "b", 6);
 
-PARAM_INT("bin_size", "The minimum number of training points in each "
-    "decision stump bin.", "b", 6);
+/**
+ * This is the structure that actually saves to disk.  We have to save the
+ * label mappings, too, otherwise everything we load at test time in a future
+ * run will end up being borked.
+ */
+struct DSModel
+{
+  //! The mappings used to convert normalized labels back to the originals.
+  arma::Col<size_t> mappings;
+  //! The stump.
+  DecisionStump<> stump;
+
+  //! Serialize the model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    ar & data::CreateNVP(mappings, "mappings");
+    ar & data::CreateNVP(stump, "stump");
+  }
+};
 
 int main(int argc, char *argv[])
 {
   CLI::ParseCommandLine(argc, argv);
 
-  const string trainingDataFilename = CLI::GetParam<string>("train_file");
-  mat trainingData;
-  data::Load(trainingDataFilename, trainingData, true);
+  // Check that the parameters are reasonable.
+  if (CLI::HasParam("training_file") && CLI::HasParam("input_model_file"))
+  {
+    Log::Fatal << "Both --training_file and --input_model_file are specified, "
+        << "but a trained model cannot be retrained.  Only one of these options"
+        << " may be specified." << endl;
+  }
+
+  if (!CLI::HasParam("training_file") && !CLI::HasParam("input_model_file"))
+  {
+    Log::Fatal << "Neither --training_file nor --input_model_file are given; "
+        << "one must be specified." << endl;
+  }
 
-  // Load labels, if necessary.
-  mat labelsIn;
-  if (CLI::HasParam("labels_file"))
+  if (!CLI::HasParam("output_model_file") && !CLI::HasParam("predictions_file"))
   {
-    const string labelsFilename = CLI::GetParam<string>("labels_file");
-    // Load labels.
-    data::Load(labelsFilename, labelsIn, true);
+    Log::Warn << "Neither --output_model_file nor --predictions_file are "
+        << "specified; no results will be saved!" << endl;
+  }
 
-    // Do the labels need to be transposed?
-    if (labelsIn.n_rows == 1)
-      labelsIn = labelsIn.t();
+  // We must either load a model, or train a new stump.
+  DSModel model;
+  if (CLI::HasParam("training_file"))
+  {
+    const string trainingDataFilename = CLI::GetParam<string>("training_file");
+    mat trainingData;
+    data::Load(trainingDataFilename, trainingData, true);
+
+    // Load labels, if necessary.
+    Mat<size_t> labelsIn;
+    if (CLI::HasParam("labels_file"))
+    {
+      const string labelsFilename = CLI::GetParam<string>("labels_file");
+      // Load labels.
+      data::Load(labelsFilename, labelsIn, true);
+
+      // Do the labels need to be transposed?
+      if (labelsIn.n_rows == 1)
+        labelsIn = labelsIn.t();
+    }
+    else
+    {
+      // Extract the labels as the last row of the training data.
+      Log::Info << "Using the last dimension of training set as labels."
+          << endl;
+
+      labelsIn = arma::conv_to<arma::Mat<size_t>>::from(
+          trainingData.row(trainingData.n_rows - 1).t());
+      trainingData.shed_row(trainingData.n_rows - 1);
+    }
+
+    // Normalize the labels.
+    Col<size_t> labels;
+    data::NormalizeLabels(labelsIn.unsafe_col(0), labels, model.mappings);
+
+    const size_t bucketSize = CLI::GetParam<int>("bucket_size");
+    const size_t classes = labels.max() + 1;
+
+    Timer::Start("training");
+    model.stump.Train(trainingData, labels.t(), classes, bucketSize);
+    Timer::Stop("training");
   }
   else
   {
-    // Extract the labels as the last
-    Log::Info << "Using the last dimension of training set as labels." << endl;
-
-    labelsIn = trainingData.row(trainingData.n_rows - 1).t();
-    trainingData.shed_row(trainingData.n_rows - 1);
+    const string inputModelFile = CLI::GetParam<string>("input_model_file");
+    data::Load(inputModelFile, "decision_stump_model", model, true);
   }
 
-  // Normalize the labels.
-  Col<size_t> labels;
-  vec mappings;
-  data::NormalizeLabels(labelsIn.unsafe_col(0), labels, mappings);
-
-  const size_t inpBucketSize = CLI::GetParam<int>("bucket_size");
-  const size_t numClasses = labels.max() + 1;
-
-  // Load the test file.
-  const string testingDataFilename = CLI::GetParam<std::string>("test_file");
-  mat testingData;
-  data::Load(testingDataFilename, testingData, true);
-
-  if (testingData.n_rows != trainingData.n_rows)
-    Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
-        << "must be the same as training data (" << trainingData.n_rows - 1
-        << ")!" << std::endl;
-
-  Timer::Start("training");
-  DecisionStump<> ds(trainingData, labels.t(), numClasses,
-                     inpBucketSize);
-  Timer::Stop("training");
-
-  Row<size_t> predictedLabels(testingData.n_cols);
-  Timer::Start("testing");
-  ds.Classify(testingData, predictedLabels);
-  Timer::Stop("testing");
-
-  vec results;
-  data::RevertLabels(predictedLabels.t(), mappings, results);
+  // Now, do we need to do any testing?
+  if (CLI::HasParam("test_file"))
+  {
+    // Load the test file.
+    const string testingDataFilename = CLI::GetParam<std::string>("test_file");
+    mat testingData;
+    data::Load(testingDataFilename, testingData, true);
+
+    if (testingData.n_rows <= model.stump.SplitDimension())
+      Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
+          << "is too low; the trained stump requires at least "
+          << model.stump.SplitDimension() << " dimensions!" << endl;
+
+    Row<size_t> predictedLabels(testingData.n_cols);
+    Timer::Start("testing");
+    model.stump.Classify(testingData, predictedLabels);
+    Timer::Stop("testing");
+
+    // Denormalize predicted labels, if we want to save them.
+    if (CLI::HasParam("predictions_file"))
+    {
+      Col<size_t> labelsTmp = predictedLabels.t();
+      Col<size_t> actualLabels;
+      data::RevertLabels(labelsTmp, model.mappings, actualLabels);
+
+      // Save the predicted labels in a transposed form as output.
+      const string predictionsFile = CLI::GetParam<string>("predictions_file");
+      data::Save(predictionsFile, actualLabels, true, false);
+    }
+  }
 
-  // Save the predicted labels in a transposed form as output.
-  const string outputFilename = CLI::GetParam<string>("output_file");
-  data::Save(outputFilename, results, true, false);
+  // Save the model, if desired.
+  if (CLI::HasParam("output_model_file"))
+    data::Save(CLI::GetParam<string>("output_model_file"),
+        "decision_stump_model", model);
 }



More information about the mlpack-git mailing list