[mlpack-git] master: Allow specification of the number of bins and observations before binning. (00c77fc)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:46:34 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 00c77fc789f53ba3579c43950be9726b19fd73b7
Author: ryan <ryan at ratml.org>
Date:   Mon Nov 23 16:54:36 2015 -0500

    Allow specification of the number of bins and observations before binning.


>---------------------------------------------------------------

00c77fc789f53ba3579c43950be9726b19fd73b7
 .../methods/hoeffding_trees/hoeffding_tree.hpp     |  5 +++
 .../hoeffding_trees/hoeffding_tree_main.cpp        | 40 +++++++++++++++++++---
 2 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
index 1d0f81c..ddf3d1b 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
@@ -55,6 +55,11 @@ template<typename FitnessFunction = GiniImpurity,
 class HoeffdingTree
 {
  public:
+  //! Allow access to the numeric split type.
+  typedef NumericSplitType<FitnessFunction> NumericSplit;
+  //! Allow access to the categorical split type.
+  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
+
   /**
    * Construct the Hoeffding tree with the given parameters and given training
    * data.  The tree may be trained either in batch mode (which looks at all
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
index dd6caef..e2113e7 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
@@ -40,9 +40,16 @@ PARAM_FLAG("info_gain", "If set, information gain is used instead of Gini "
     "impurity for calculating Hoeffding bounds.", "i");
 PARAM_INT("passes", "Number of passes to take over the dataset.", "s", 1);
 
+PARAM_INT("bins", "If the 'domingos' split strategy is used, this specifies "
+    "the number of bins for each numeric split.", "B", 10);
+PARAM_INT("observations_before_binning", "If the 'domingos' split strategy is "
+    "used, this specifies the number of samples observed before binning is "
+    "performed.", "o", 100);
+
 // Helper function for once we have chosen a tree type.
 template<typename TreeType>
-void PerformActions();
+void PerformActions(const typename TreeType::NumericSplit& numericSplit =
+    typename TreeType::NumericSplit(0));
 
 int main(int argc, char** argv)
 {
@@ -81,33 +88,55 @@ int main(int argc, char** argv)
   if (CLI::HasParam("info_gain"))
   {
     if (numericSplitStrategy == "domingos")
+    {
+      const size_t bins = (size_t) CLI::GetParam<int>("bins");
+      const size_t observationsBeforeBinning = (size_t)
+          CLI::GetParam<int>("observations_before_binning");
+      HoeffdingDoubleNumericSplit<InformationGain> ns(0, bins,
+          observationsBeforeBinning);
       PerformActions<HoeffdingTree<InformationGain, HoeffdingDoubleNumericSplit,
-          HoeffdingCategoricalSplit>>();
+          HoeffdingCategoricalSplit>>(ns);
+    }
     else if (numericSplitStrategy == "binary")
+    {
       PerformActions<HoeffdingTree<InformationGain, BinaryDoubleNumericSplit,
           HoeffdingCategoricalSplit>>();
+    }
     else
+    {
       Log::Fatal << "Unrecognized numeric split strategy ("
           << numericSplitStrategy << ")!  Must be 'domingos' or 'binary'."
           << endl;
+    }
   }
   else
   {
     if (numericSplitStrategy == "domingos")
+    {
+      const size_t bins = (size_t) CLI::GetParam<int>("bins");
+      const size_t observationsBeforeBinning = (size_t)
+          CLI::GetParam<int>("observations_before_binning");
+      HoeffdingDoubleNumericSplit<GiniImpurity> ns(0, bins,
+          observationsBeforeBinning);
       PerformActions<HoeffdingTree<GiniImpurity, HoeffdingDoubleNumericSplit,
-          HoeffdingCategoricalSplit>>();
+          HoeffdingCategoricalSplit>>(ns);
+    }
     else if (numericSplitStrategy == "binary")
+    {
       PerformActions<HoeffdingTree<GiniImpurity, BinaryDoubleNumericSplit,
           HoeffdingCategoricalSplit>>();
+    }
     else
+    {
       Log::Fatal << "Unrecognized numeric split strategy ("
           << numericSplitStrategy << ")!  Must be 'domingos' or 'binary'."
           << endl;
+    }
   }
 }
 
 template<typename TreeType>
-void PerformActions()
+void PerformActions(const typename TreeType::NumericSplit& numericSplit)
 {
   // Load necessary parameters.
   const string trainingFile = CLI::GetParam<string>("training_file");
@@ -145,7 +174,8 @@ void PerformActions()
       Log::Info << "Taking " << passes << " passes over the dataset." << endl;
 
     tree = new TreeType(trainingSet, datasetInfo, labels, max(labels) + 1,
-        batchTraining, confidence, maxSamples, 100, minSamples);
+        batchTraining, confidence, maxSamples, 100, minSamples,
+        typename TreeType::CategoricalSplit(0, 0), numericSplit);
 
     for (size_t i = 1; i < passes; ++i)
       tree->Train(trainingSet, labels, false);



More information about the mlpack-git mailing list