[mlpack-git] master: Add support for the BinaryNumericSplit. (1a19ef9)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:44:43 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 1a19ef9bfb5e927877acc50f43bbc132d94ba64b
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Oct 18 07:29:30 2015 -0400

    Add support for the BinaryNumericSplit.
    
    Currently it crashes on datasets with more than a couple thousand points, but I
    don't have a working gdb on this system, so I can't debug it efficiently; I'll
    investigate it later.


>---------------------------------------------------------------

1a19ef9bfb5e927877acc50f43bbc132d94ba64b
 .../streaming_decision_tree_main.cpp               | 54 +++++++++++++++++-----
 1 file changed, 43 insertions(+), 11 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
index 6d559c8..ef17cd7 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
@@ -7,7 +7,7 @@
 #include <mlpack/core.hpp>
 #include <mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp>
 #include <mlpack/methods/hoeffding_trees/hoeffding_split.hpp>
-#include <stack>
+#include <mlpack/methods/hoeffding_trees/binary_numeric_split.hpp>
 
 using namespace std;
 using namespace mlpack;
@@ -31,20 +31,26 @@ PARAM_STRING("predictions_file", "File to output label predictions for test "
 PARAM_STRING("probabilities_file", "In addition to predicting labels, provide "
     "prediction probabilities in this file.", "P", "");
 
+PARAM_STRING("numeric_split_strategy", "The splitting strategy to use for "
+    "numeric features: 'domingos' or 'binary'.", "N", "binary");
+
+// Helper function for once we have chosen a tree type.
+template<typename TreeType>
+void PerformActions();
 
 int main(int argc, char** argv)
 {
   CLI::ParseCommandLine(argc, argv);
 
+  // Check input parameters for validity.
   const string trainingFile = CLI::GetParam<string>("training_file");
   const string labelsFile = CLI::GetParam<string>("labels_file");
-  const double confidence = CLI::GetParam<double>("confidence");
-  const size_t maxSamples = (size_t) CLI::GetParam<int>("max_samples");
   const string inputModelFile = CLI::GetParam<string>("input_model_file");
-  const string outputModelFile = CLI::GetParam<string>("output_model_file");
   const string testFile = CLI::GetParam<string>("test_file");
   const string predictionsFile = CLI::GetParam<string>("predictions_file");
   const string probabilitiesFile = CLI::GetParam<string>("probabilities_file");
+  const string numericSplitStrategy =
+      CLI::GetParam<string>("numeric_split_strategy");
 
   if ((!predictionsFile.empty() || !probabilitiesFile.empty()) &&
       testFile.empty())
@@ -59,7 +65,33 @@ int main(int argc, char** argv)
     Log::Fatal << "If --training_file is specified, --labels_file must be "
         << "specified too!" << endl;
 
-  StreamingDecisionTree<HoeffdingSplit<>>* tree = NULL;
+  if (numericSplitStrategy == "domingos")
+    PerformActions<StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
+        HoeffdingDoubleNumericSplit, HoeffdingCategoricalSplit>>>();
+  else if (numericSplitStrategy == "binary")
+    PerformActions<StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
+        BinaryDoubleNumericSplit, HoeffdingCategoricalSplit>>>();
+  else
+    Log::Fatal << "Unrecognized numeric split strategy ("
+        << numericSplitStrategy << ")!  Must be 'domingos' or 'binary'."
+        << endl;
+}
+
+template<typename TreeType>
+void PerformActions()
+{
+  // Load necessary parameters.
+  const string trainingFile = CLI::GetParam<string>("training_file");
+  const string labelsFile = CLI::GetParam<string>("labels_file");
+  const double confidence = CLI::GetParam<double>("confidence");
+  const size_t maxSamples = (size_t) CLI::GetParam<int>("max_samples");
+  const string inputModelFile = CLI::GetParam<string>("input_model_file");
+  const string outputModelFile = CLI::GetParam<string>("output_model_file");
+  const string testFile = CLI::GetParam<string>("test_file");
+  const string predictionsFile = CLI::GetParam<string>("predictions_file");
+  const string probabilitiesFile = CLI::GetParam<string>("probabilities_file");
+
+  TreeType* tree = NULL;
   DatasetInfo datasetInfo;
   if (inputModelFile.empty())
   {
@@ -67,7 +99,7 @@ int main(int argc, char** argv)
     data::Load(trainingFile, trainingSet, datasetInfo, true);
     for (size_t i = 0; i < trainingSet.n_rows; ++i)
       Log::Info << datasetInfo.NumMappings(i) << " mappings in dimension "
-          << i << ".\n";
+          << i << "." << endl;
 
     arma::Col<size_t> labelsIn;
     data::Load(labelsFile, labelsIn, true, false);
@@ -75,13 +107,13 @@ int main(int argc, char** argv)
 
     // Now create the decision tree.
     Timer::Start("tree_training");
-    tree = new StreamingDecisionTree<HoeffdingSplit<>>(trainingSet, datasetInfo,
-        labels, max(labels) + 1, confidence, maxSamples);
+    tree = new TreeType(trainingSet, datasetInfo, labels, max(labels) + 1,
+        confidence, maxSamples);
     Timer::Stop("tree_training");
   }
   else
   {
-    tree = new StreamingDecisionTree<HoeffdingSplit<>>(datasetInfo, 1, 1);
+    tree = new TreeType(datasetInfo, 1, 1);
     data::Load(inputModelFile, "streamingDecisionTree", *tree, true);
 
     if (!trainingFile.empty())
@@ -90,7 +122,7 @@ int main(int argc, char** argv)
       data::Load(trainingFile, trainingSet, datasetInfo, true);
       for (size_t i = 0; i < trainingSet.n_rows; ++i)
         Log::Info << datasetInfo.NumMappings(i) << " mappings in dimension "
-            << i << ".\n";
+            << i << "." << endl;
 
       arma::Col<size_t> labelsIn;
       data::Load(labelsFile, labelsIn, true, false);
@@ -103,7 +135,7 @@ int main(int argc, char** argv)
     }
   }
 
-  // Great.  Good job team.  Now work on the test set if we have one.
+  // The tree is trained or loaded.  Now do any testing if we need.
   DatasetInfo testInfo;
   if (!testFile.empty())
   {



More information about the mlpack-git mailing list