[mlpack-git] master: Better main executable. (6cdc7c5)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:44:06 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 6cdc7c5d230a5f546bd428c687ce715c22bea53a
Author: ryan <ryan at ratml.org>
Date:   Fri Oct 2 01:07:21 2015 -0400

    Better main executable.


>---------------------------------------------------------------

6cdc7c5d230a5f546bd428c687ce715c22bea53a
 .../streaming_decision_tree_main.cpp               | 138 ++++++++++++++-------
 1 file changed, 91 insertions(+), 47 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
index 93659a7..6d559c8 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
@@ -14,7 +14,7 @@ using namespace mlpack;
 using namespace mlpack::tree;
 using namespace mlpack::data;
 
-PARAM_STRING_REQ("training_file", "Training dataset file.", "t");
+PARAM_STRING("training_file", "Training dataset file.", "t", "");
 PARAM_STRING("labels_file", "Labels for training dataset.", "l", "");
 
 PARAM_DOUBLE("confidence", "Confidence before splitting (between 0 and 1).",
@@ -22,9 +22,14 @@ PARAM_DOUBLE("confidence", "Confidence before splitting (between 0 and 1).",
 PARAM_INT("max_samples", "Maximum number of samples before splitting.", "n",
     5000);
 
-PARAM_STRING("model_file", "File to save trained tree to.", "m", "");
+PARAM_STRING("input_model_file", "File to load trained tree from.", "m", "");
+PARAM_STRING("output_model_file", "File to save trained tree to.", "M", "");
 
 PARAM_STRING("test_file", "File of testing data.", "T", "");
+PARAM_STRING("predictions_file", "File to output label predictions for test "
+    "data into.", "p", "");
+PARAM_STRING("probabilities_file", "In addition to predicting labels, provide "
+    "prediction probabilities in this file.", "P", "");
 
 
 int main(int argc, char** argv)
@@ -35,55 +40,94 @@ int main(int argc, char** argv)
   const string labelsFile = CLI::GetParam<string>("labels_file");
   const double confidence = CLI::GetParam<double>("confidence");
   const size_t maxSamples = (size_t) CLI::GetParam<int>("max_samples");
-
-  arma::mat trainingSet;
+  const string inputModelFile = CLI::GetParam<string>("input_model_file");
+  const string outputModelFile = CLI::GetParam<string>("output_model_file");
+  const string testFile = CLI::GetParam<string>("test_file");
+  const string predictionsFile = CLI::GetParam<string>("predictions_file");
+  const string probabilitiesFile = CLI::GetParam<string>("probabilities_file");
+
+  if ((!predictionsFile.empty() || !probabilitiesFile.empty()) &&
+      testFile.empty())
+    Log::Fatal << "--test_file must be specified if --predictions_file or "
+        << "--probabilities_file is specified." << endl;
+
+  if (trainingFile.empty() && inputModelFile.empty())
+    Log::Fatal << "One of --training_file or --input_model_file must be "
+        << "specified!" << endl;
+
+  if (!trainingFile.empty() && labelsFile.empty())
+    Log::Fatal << "If --training_file is specified, --labels_file must be "
+        << "specified too!" << endl;
+
+  StreamingDecisionTree<HoeffdingSplit<>>* tree = NULL;
   DatasetInfo datasetInfo;
-  data::Load(trainingFile, trainingSet, datasetInfo, true);
-  for (size_t i = 0; i < trainingSet.n_rows; ++i)
-    Log::Info << datasetInfo.NumMappings(i) << " mappings in dimension " << i <<
-".\n";
-
-  arma::Col<size_t> labelsIn;
-  data::Load(labelsFile, labelsIn, true, false);
-  arma::Row<size_t> labels = labelsIn.t();
-
-  // Now create the decision tree.
-  Timer::Start("tree_training");
-  StreamingDecisionTree<HoeffdingSplit<>> tree(trainingSet, datasetInfo, labels,
-      max(labels) + 1, confidence, maxSamples);
-  Timer::Stop("tree_training");
-
-  // Great.  Good job team.
-  std::stack<StreamingDecisionTree<HoeffdingSplit<>>*> stack;
-  stack.push(&tree);
-  size_t nodes = 0;
-  while (!stack.empty())
+  if (inputModelFile.empty())
+  {
+    arma::mat trainingSet;
+    data::Load(trainingFile, trainingSet, datasetInfo, true);
+    for (size_t i = 0; i < trainingSet.n_rows; ++i)
+      Log::Info << datasetInfo.NumMappings(i) << " mappings in dimension "
+          << i << ".\n";
+
+    arma::Col<size_t> labelsIn;
+    data::Load(labelsFile, labelsIn, true, false);
+    arma::Row<size_t> labels = labelsIn.t();
+
+    // Now create the decision tree.
+    Timer::Start("tree_training");
+    tree = new StreamingDecisionTree<HoeffdingSplit<>>(trainingSet, datasetInfo,
+        labels, max(labels) + 1, confidence, maxSamples);
+    Timer::Stop("tree_training");
+  }
+  else
+  {
+    tree = new StreamingDecisionTree<HoeffdingSplit<>>(datasetInfo, 1, 1);
+    data::Load(inputModelFile, "streamingDecisionTree", *tree, true);
+
+    if (!trainingFile.empty())
+    {
+      arma::mat trainingSet;
+      data::Load(trainingFile, trainingSet, datasetInfo, true);
+      for (size_t i = 0; i < trainingSet.n_rows; ++i)
+        Log::Info << datasetInfo.NumMappings(i) << " mappings in dimension "
+            << i << ".\n";
+
+      arma::Col<size_t> labelsIn;
+      data::Load(labelsFile, labelsIn, true, false);
+      arma::Row<size_t> labels = labelsIn.t();
+
+      // Now create the decision tree.
+      Timer::Start("tree_training");
+      tree->Train(trainingSet, labels);
+      Timer::Stop("tree_training");
+    }
+  }
+
+  // Great.  Good job team.  Now work on the test set if we have one.
+  DatasetInfo testInfo;
+  if (!testFile.empty())
   {
-    StreamingDecisionTree<HoeffdingSplit<>>* node = stack.top();
-    stack.pop();
-    ++nodes;
+    arma::mat testSet;
+    data::Load(testFile, testSet, testInfo, true);
+
+    arma::Row<size_t> predictions;
+    arma::rowvec probabilities;
+
+    Timer::Start("tree_testing");
+    tree->Classify(testSet, predictions, probabilities);
+    Timer::Stop("tree_testing");
 
-    for (size_t i = 0; i < node->NumChildren(); ++i)
-      stack.push(&node->Child(i));
+    if (!predictionsFile.empty())
+      data::Save(predictionsFile, predictions);
+
+    if (!probabilitiesFile.empty())
+      data::Save(probabilitiesFile, probabilities);
   }
-  Log::Info << nodes << " nodes in tree.\n";
 
   // Check the accuracy on the training set.
-  Timer::Start("tree_testing");
-  arma::Row<size_t> predictedLabels;
-  tree.Classify(trainingSet, predictedLabels);
-  Timer::Stop("tree_testing");
-
-  size_t correct = 0;
-  for (size_t i = 0; i < predictedLabels.n_elem; ++i)
-    if (labels[i] == predictedLabels[i])
-      ++correct;
-    else if (predictedLabels[i] > 10)
-      Log::Warn << "Invalid label " << predictedLabels[i] << " for point " << i
-          << "!\n";
-
-  Log::Info << correct << " correct out of " << predictedLabels.n_elem << ".\n";
-
-  const string modelFile = CLI::GetParam<string>("model_file");
-  data::Save(modelFile, "streamingDecisionTree", tree, true);
+  if (!outputModelFile.empty())
+    data::Save(outputModelFile, "streamingDecisionTree", tree, true);
+
+  // Clean up memory.
+  delete tree;
 }



More information about the mlpack-git mailing list