[mlpack-git] master: Refactor to allow passing parameters of splits to children. (b69fc07)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:46:32 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit b69fc07d75ff34fde9ab344200b226de260f48d2
Author: ryan <ryan at ratml.org>
Date:   Mon Nov 23 13:40:48 2015 -0500

    Refactor to allow passing parameters of splits to children.


>---------------------------------------------------------------

b69fc07d75ff34fde9ab344200b226de260f48d2
 .../hoeffding_trees/binary_numeric_split.hpp       |  8 ++++
 .../hoeffding_trees/binary_numeric_split_impl.hpp  | 12 +++++
 .../hoeffding_categorical_split.hpp                | 10 ++++
 .../hoeffding_categorical_split_impl.hpp           | 10 ++++
 .../hoeffding_trees/hoeffding_numeric_split.hpp    |  7 +++
 .../hoeffding_numeric_split_impl.hpp               | 15 ++++++
 .../methods/hoeffding_trees/hoeffding_tree.hpp     | 10 +++-
 .../hoeffding_trees/hoeffding_tree_impl.hpp        | 55 ++++++++++++++++++----
 .../hoeffding_trees/hoeffding_tree_main.cpp        |  1 +
 9 files changed, 117 insertions(+), 11 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
index faa0f77..5552686 100644
--- a/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
@@ -53,6 +53,14 @@ class BinaryNumericSplit
   BinaryNumericSplit(const size_t numClasses);
 
   /**
+   * Create the BinaryNumericSplit object with the given number of classes,
+   * using information from the given other split for other parameters.  In this
+   * case, there are no other parameters, but this function is required by the
+   * HoeffdingTree class.
+   */
+  BinaryNumericSplit(const size_t numClasses, const BinaryNumericSplit& other);
+
+  /**
    * Train on the given value with the given label.
    *
    * @param value The value to train on.
diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
index a0a174a..eca9222 100644
--- a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
@@ -25,6 +25,18 @@ BinaryNumericSplit<FitnessFunction, ObservationType>::BinaryNumericSplit(
 }
 
 template<typename FitnessFunction, typename ObservationType>
+BinaryNumericSplit<FitnessFunction, ObservationType>::BinaryNumericSplit(
+    const size_t numClasses,
+    const BinaryNumericSplit& /* other */) :
+    classCounts(numClasses),
+    bestSplit(std::numeric_limits<ObservationType>::min()),
+    isAccurate(true)
+{
+  // Zero out class counts.
+  classCounts.zeros();
+}
+
+template<typename FitnessFunction, typename ObservationType>
 void BinaryNumericSplit<FitnessFunction, ObservationType>::Train(
     ObservationType value,
     const size_t label)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
index 291fb75..f72ffe0 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
@@ -53,6 +53,16 @@ class HoeffdingCategoricalSplit
                             const size_t numClasses);
 
   /**
+   * Create the HoeffdingCategoricalSplit given a number of categories for this
+   * dimension and a number of classes and another HoeffdingCategoricalSplit to
+   * take parameters from.  In this particular case, there are no parameters to
+   * take, but this constructor is required by the HoeffdingTree class.
+   */
+  HoeffdingCategoricalSplit(const size_t numCategories,
+                            const size_t numClasses,
+                            const HoeffdingCategoricalSplit& other);
+
+  /**
    * Train on the given value with the given label.
    *
    * @param value Value to train on.
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
index c008f17..c6d6521 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
@@ -23,6 +23,16 @@ HoeffdingCategoricalSplit<FitnessFunction>::HoeffdingCategoricalSplit(
 }
 
 template<typename FitnessFunction>
+HoeffdingCategoricalSplit<FitnessFunction>::HoeffdingCategoricalSplit(
+    const size_t numCategories,
+    const size_t numClasses,
+    const HoeffdingCategoricalSplit& /* other */) :
+    sufficientStatistics(numClasses, numCategories)
+{
+  sufficientStatistics.zeros();
+}
+
+template<typename FitnessFunction>
 template<typename eT>
 void HoeffdingCategoricalSplit<FitnessFunction>::Train(eT value,
                                                        const size_t label)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
index fb6c4c1..4a5e390 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
@@ -65,6 +65,13 @@ class HoeffdingNumericSplit
                         const size_t observationsBeforeBinning = 100);
 
   /**
+   * Create the HoeffdingNumericSplit class, using the parameters from the given
+   * other split object.
+   */
+  HoeffdingNumericSplit(const size_t numClasses,
+                        const HoeffdingNumericSplit& other);
+
+  /**
    * Train the HoeffdingNumericSplit on the given observed value (remember that
    * this object only cares about the information for a single feature, not an
    * entire point).
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
index bcd9f92..6175d89 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
@@ -29,6 +29,21 @@ HoeffdingNumericSplit<FitnessFunction, ObservationType>::HoeffdingNumericSplit(
 }
 
 template<typename FitnessFunction, typename ObservationType>
+HoeffdingNumericSplit<FitnessFunction, ObservationType>::HoeffdingNumericSplit(
+    const size_t numClasses,
+    const HoeffdingNumericSplit& other) :
+    observations(other.observationsBeforeBinning - 1),
+    labels(other.observationsBeforeBinning - 1),
+    bins(other.bins),
+    observationsBeforeBinning(other.observationsBeforeBinning),
+    samplesSeen(0),
+    sufficientStatistics(arma::zeros<arma::Mat<size_t>>(numClasses, other.bins))
+{
+  observations.zeros();
+  labels.zeros();
+}
+
+template<typename FitnessFunction, typename ObservationType>
 void HoeffdingNumericSplit<FitnessFunction, ObservationType>::Train(
     ObservationType value,
     const size_t label)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
index 8e53b50..1d0f81c 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
@@ -86,7 +86,11 @@ class HoeffdingTree
                 const double successProbability = 0.95,
                 const size_t maxSamples = 0,
                 const size_t checkInterval = 100,
-                const size_t minSamples = 100);
+                const size_t minSamples = 100,
+                const CategoricalSplitType<FitnessFunction>& categoricalSplitIn
+                    = CategoricalSplitType<FitnessFunction>(0, 0),
+                const NumericSplitType<FitnessFunction>& numericSplitIn =
+                    NumericSplitType<FitnessFunction>(0));
 
   /**
    * Construct the Hoeffding tree with the given parameters, but training on no
@@ -113,6 +117,10 @@ class HoeffdingTree
                 const size_t maxSamples = 0,
                 const size_t checkInterval = 100,
                 const size_t minSamples = 100,
+                const CategoricalSplitType<FitnessFunction>& categoricalSplitIn
+                    = CategoricalSplitType<FitnessFunction>(0, 0),
+                const NumericSplitType<FitnessFunction>& numericSplitIn =
+                    NumericSplitType<FitnessFunction>(0),
                 std::unordered_map<size_t, std::pair<size_t, size_t>>*
                     dimensionMappings = NULL);
 
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
index ce0164c..24438d6 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
@@ -29,7 +29,10 @@ HoeffdingTree<
                  const double successProbability,
                  const size_t maxSamples,
                  const size_t checkInterval,
-                 const size_t minSamples) :
+                 const size_t minSamples,
+                 const CategoricalSplitType<FitnessFunction>&
+                     categoricalSplitIn,
+                 const NumericSplitType<FitnessFunction>& numericSplitIn) :
     dimensionMappings(new std::unordered_map<size_t,
         std::pair<size_t, size_t>>()),
     ownsMappings(true),
@@ -51,13 +54,14 @@ HoeffdingTree<
     if (datasetInfo.Type(i) == data::Datatype::categorical)
     {
       categoricalSplits.push_back(CategoricalSplitType<FitnessFunction>(
-          datasetInfo.NumMappings(i), numClasses));
+          datasetInfo.NumMappings(i), numClasses, categoricalSplitIn));
       (*dimensionMappings)[i] = std::make_pair(data::Datatype::categorical,
           categoricalSplits.size() - 1);
     }
     else
     {
-      numericSplits.push_back(NumericSplitType<FitnessFunction>(numClasses));
+      numericSplits.push_back(NumericSplitType<FitnessFunction>(numClasses,
+          numericSplitIn));
       (*dimensionMappings)[i] = std::make_pair(data::Datatype::numeric,
           numericSplits.size() - 1);
     }
@@ -80,6 +84,9 @@ HoeffdingTree<
                  const size_t maxSamples,
                  const size_t checkInterval,
                  const size_t minSamples,
+                 const CategoricalSplitType<FitnessFunction>&
+                     categoricalSplitIn,
+                 const NumericSplitType<FitnessFunction>& numericSplitIn,
                  std::unordered_map<size_t, std::pair<size_t, size_t>>*
                      dimensionMappingsIn) :
     dimensionMappings((dimensionMappingsIn != NULL) ? dimensionMappingsIn :
@@ -105,13 +112,14 @@ HoeffdingTree<
       if (datasetInfo.Type(i) == data::Datatype::categorical)
       {
         categoricalSplits.push_back(CategoricalSplitType<FitnessFunction>(
-            datasetInfo.NumMappings(i), numClasses));
+            datasetInfo.NumMappings(i), numClasses, categoricalSplitIn));
         (*dimensionMappings)[i] = std::make_pair(data::Datatype::categorical,
             categoricalSplits.size() - 1);
       }
       else
       {
-        numericSplits.push_back(NumericSplitType<FitnessFunction>(numClasses));
+        numericSplits.push_back(NumericSplitType<FitnessFunction>(numClasses,
+            numericSplitIn));
         (*dimensionMappings)[i] = std::make_pair(data::Datatype::numeric,
             numericSplits.size() - 1);
       }
@@ -124,11 +132,12 @@ HoeffdingTree<
       if (datasetInfo.Type(i) == data::Datatype::categorical)
       {
         categoricalSplits.push_back(CategoricalSplitType<FitnessFunction>(
-            datasetInfo.NumMappings(i), numClasses));
+            datasetInfo.NumMappings(i), numClasses, categoricalSplitIn));
       }
       else
       {
-        numericSplits.push_back(NumericSplitType<FitnessFunction>(numClasses));
+        numericSplits.push_back(NumericSplitType<FitnessFunction>(numClasses,
+            numericSplitIn));
       }
     }
   }
@@ -541,9 +550,35 @@ void HoeffdingTree<
   // We already know what the splitDimension will be.
   for (size_t i = 0; i < childMajorities.n_elem; ++i)
   {
-    children.push_back(new HoeffdingTree(*datasetInfo, numClasses,
-        successProbability, maxSamples, checkInterval, minSamples,
-        dimensionMappings));
+    // We need to also give our split objects to the new children, so that
+    // parameters for the splits can be passed down.  But if we have no
+    // categorical or numeric features, we can't pass anything but the
+    // defaults...
+    if (categoricalSplits.size() == 0)
+    {
+      // Pass a default categorical split.
+      children.push_back(new HoeffdingTree(*datasetInfo, numClasses,
+          successProbability, maxSamples, checkInterval, minSamples,
+          CategoricalSplitType<FitnessFunction>(0, numClasses),
+          numericSplits[0], dimensionMappings));
+    }
+    else if (numericSplits.size() == 0)
+    {
+      // Pass a default numeric split.
+      children.push_back(new HoeffdingTree(*datasetInfo, numClasses,
+          successProbability, maxSamples, checkInterval, minSamples,
+          categoricalSplits[0], NumericSplitType<FitnessFunction>(numClasses),
+          dimensionMappings));
+    }
+    else
+    {
+      // Pass both splits that we already have.
+      children.push_back(new HoeffdingTree(*datasetInfo, numClasses,
+          successProbability, maxSamples, checkInterval, minSamples,
+          categoricalSplits[0], numericSplits[0], dimensionMappings));
+
+    }
+
     children[i]->MajorityClass() = childMajorities[i];
   }
 
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
index 6b6634a..dd6caef 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
@@ -114,6 +114,7 @@ void PerformActions()
   const string labelsFile = CLI::GetParam<string>("labels_file");
   const double confidence = CLI::GetParam<double>("confidence");
   const size_t maxSamples = (size_t) CLI::GetParam<int>("max_samples");
+  const size_t minSamples = (size_t) CLI::GetParam<int>("min_samples");
   const string inputModelFile = CLI::GetParam<string>("input_model_file");
   const string outputModelFile = CLI::GetParam<string>("output_model_file");
   const string testFile = CLI::GetParam<string>("test_file");



More information about the mlpack-git mailing list