[mlpack-git] master: Refactor HoeffdingTree to replace StreamingDecisionTree. (f0abc60)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:45:41 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit f0abc6003d77e8ca9f100aebe9392ae3d97b3612
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Nov 1 17:44:09 2015 +0000

    Refactor HoeffdingTree to replace StreamingDecisionTree.


>---------------------------------------------------------------

f0abc6003d77e8ca9f100aebe9392ae3d97b3612
 .../methods/hoeffding_trees/hoeffding_tree.hpp     |  65 ++++++-
 .../hoeffding_trees/hoeffding_tree_impl.hpp        | 199 +++++++++++++++++----
 2 files changed, 217 insertions(+), 47 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
index 2b780f5..27101a5 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
@@ -65,6 +65,7 @@ class HoeffdingTree
    *
    * @param data Dataset to train on.
    * @param datasetInfo Information on the dataset (types of each feature).
+   * @param labels Labels of each point in the dataset.
    * @param numClasses Number of classes in the dataset.
    * @param batchTraining Whether or not to train in batch.
    * @param successProbability Probability of success required in Hoeffding
@@ -76,8 +77,8 @@ class HoeffdingTree
    */
   template<typename MatType>
   HoeffdingTree(const MatType& data,
-                const arma::Row<size_t>& labels,
                 const data::DatasetInfo& datasetInfo,
+                const arma::Row<size_t>& labels,
                 const size_t numClasses,
                 const bool batchTraining = true,
                 const double successProbability = 0.95,
@@ -103,13 +104,21 @@ class HoeffdingTree
    */
   HoeffdingTree(const data::DatasetInfo& datasetInfo,
                 const size_t numClasses,
-                const double successProbability,
-                const size_t maxSamples,
-                const size_t checkInterval,
+                const double successProbability = 0.95,
+                const size_t maxSamples = 0,
+                const size_t checkInterval = 100,
                 std::unordered_map<size_t, std::pair<size_t, size_t>>*
                     dimensionMappings = NULL);
 
   /**
+   * Copy another tree (warning: this will duplicate the tree entirely, and may
+   * use a lot of memory.  Make sure it's what you want before you do it).
+   *
+   * @param other Tree to copy.
+   */
+  HoeffdingTree(const HoeffdingTree& other);
+
+  /**
    * Clean up memory.
    */
   ~HoeffdingTree();
@@ -147,9 +156,22 @@ class HoeffdingTree
   size_t SplitDimension() const { return splitDimension; }
 
   //! Get the majority class.
-  size_t MajorityClass() const;
+  size_t MajorityClass() const { return majorityClass; }
   //! Modify the majority class.
-  size_t& MajorityClass();
+  size_t& MajorityClass() { return majorityClass; }
+
+  //! Get the probability of the majority class (based on training samples).
+  double MajorityProbability() const { return majorityProbability; }
+  //! Modify the probability of the majority class.
+  double& MajorityProbability() { return majorityProbability; }
+
+  //! Get the number of children.
+  size_t NumChildren() const { return children.size(); }
+
+  //! Get a child.
+  const HoeffdingTree& Child(const size_t i) const { return children[i]; }
+  //! Modify a child.
+  HoeffdingTree& Child(const size_t i) { return children[i]; }
 
   /**
    * Given a point and that this node is not a leaf, calculate the index of the
@@ -187,10 +209,35 @@ class HoeffdingTree
       const;
 
   /**
+   * Classify the given points, using this node and the entire (sub)tree beneath
+   * it.  The predicted labels for each point are returned.
+   *
+   * @param data Points to classify.
+   * @param predictions Predicted labels for each point.
+   */
+  template<typename MatType>
+  void Classify(const MatType& data, arma::Row<size_t>& predictions) const;
+
+  /**
+   * Classify the given points, using this node and the entire (sub)tree beneath
+   * it.  The predicted labels for each point are returned, as well as an
+   * estimate of the probability that the prediction is correct for each point.
+   * This estimate is simply the MajorityProbability() for the leaf that each
+   * point bins to.
+   *
+   * @param data Points to classify.
+   * @param predictions Predicted labels for each point.
+   * @param probabilities Probability estimates for each predicted label.
+   */
+  template<typename MatType>
+  void Classify(const MatType& data,
+                arma::Row<size_t>& predictions,
+                arma::rowvec& probabilities) const;
+
+  /**
    * Given that this node should split, create the children.
    */
-  template<typename StreamingDecisionTreeType>
-  void CreateChildren(std::vector<StreamingDecisionTreeType>& children);
+  void CreateChildren();
 
   //! Serialize the split.
   template<typename Archive>
@@ -235,6 +282,8 @@ class HoeffdingTree
   typename CategoricalSplitType<FitnessFunction>::SplitInfo categoricalSplit;
   //! If the split is numeric, this holds the splitting information.
   typename NumericSplitType<FitnessFunction>::SplitInfo numericSplit;
+  //! If the split has occurred, these are the children.
+  std::vector<HoeffdingTree> children;
 };
 
 } // namespace tree
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
index e7c5300..bed0237 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
@@ -22,15 +22,46 @@ HoeffdingTree<
     NumericSplitType,
     CategoricalSplitType
 >::HoeffdingTree(const MatType& data,
-                 const arma::Row<size_t>& labels,
                  const data::DatasetInfo& datasetInfo,
+                 const arma::Row<size_t>& labels,
                  const size_t numClasses,
                  const bool batchTraining,
                  const double successProbability,
                  const size_t maxSamples,
-                 const size_t checkInterval)
+                 const size_t checkInterval) :
+    dimensionMappings(new std::unordered_map<size_t,
+        std::pair<size_t, size_t>>()),
+    ownsMappings(true),
+    numSamples(0),
+    numClasses(numClasses),
+    maxSamples(maxSamples),
+    checkInterval(checkInterval),
+    datasetInfo(&datasetInfo),
+    successProbability(successProbability),
+    splitDimension(size_t(-1)),
+    categoricalSplit(0),
+    numericSplit()
 {
-  // Not yet implemented.
+  // Generate dimension mappings and create split objects.
+  for (size_t i = 0; i < datasetInfo.Dimensionality(); ++i)
+  {
+    if (datasetInfo.Type(i) == data::Datatype::categorical)
+    {
+      categoricalSplits.push_back(CategoricalSplitType<FitnessFunction>(
+          datasetInfo.NumMappings(i), numClasses));
+      (*dimensionMappings)[i] = std::make_pair(data::Datatype::categorical,
+          categoricalSplits.size() - 1);
+    }
+    else
+    {
+      numericSplits.push_back(NumericSplitType<FitnessFunction>(numClasses));
+      (*dimensionMappings)[i] = std::make_pair(data::Datatype::numeric,
+          numericSplits.size() - 1);
+    }
+  }
+
+  // Now train.
+  Train(data, labels, batchTraining);
 }
 
 template<typename FitnessFunction,
@@ -97,6 +128,33 @@ HoeffdingTree<
   }
 }
 
+// Copy constructor.
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType>
+HoeffdingTree<FitnessFunction, NumericSplitType, CategoricalSplitType>::
+    HoeffdingTree(const HoeffdingTree& other) :
+    numericSplits(other.numericSplits),
+    categoricalSplits(other.categoricalSplits),
+    dimensionMappings(new std::unordered_map<size_t,
+        std::pair<size_t, size_t>>(*other.dimensionMappings)),
+    ownsMappings(true),
+    numSamples(other.numSamples),
+    numClasses(other.numClasses),
+    maxSamples(other.maxSamples),
+    checkInterval(other.checkInterval),
+    datasetInfo(new data::DatasetInfo(*other.datasetInfo)),
+    successProbability(other.successProbability),
+    splitDimension(other.splitDimension),
+    majorityClass(other.majorityClass),
+    majorityProbability(other.majorityProbability),
+    categoricalSplit(other.categoricalSplit),
+    numericSplit(other.numericSplit),
+    children(other.children)
+{
+  // Nothing left to copy.
+}
+
 template<typename FitnessFunction,
          template<typename> class NumericSplitType,
          template<typename> class CategoricalSplitType>
@@ -121,6 +179,16 @@ void HoeffdingTree<
          const bool batchTraining)
 {
   // Not yet implemented.
+  if (batchTraining)
+  {
+    throw std::invalid_argument("batch training not yet implemented");
+  }
+  else
+  {
+    // We aren't training in batch mode; loop through the points.
+    for (size_t i = 0; i < data.n_cols; ++i)
+      Train(data.col(i), labels[i]);
+  }
 }
 
 //! Train on one point.
@@ -158,11 +226,25 @@ void HoeffdingTree<
       majorityClass = numericSplits[0].MajorityClass();
       majorityProbability = numericSplits[0].MajorityProbability();
     }
+
+    // Check for a split, if we should.
+    if (numSamples % checkInterval == 0)
+    {
+      const size_t numChildren = SplitCheck();
+      if (numChildren > 0)
+      {
+        // We need to add a bunch of children.
+        // Delete children, if we have them.
+        children.clear();
+        CreateChildren();
+      }
+    }
   }
   else
   {
-    // Already split.
-    // But we should probably pass it down anyway.
+    // Already split.  Pass the training point to the relevant child.
+    size_t direction = CalculateDirection(point);
+    children[direction].Train(point, label);
   }
 }
 
@@ -175,10 +257,6 @@ size_t HoeffdingTree<
     CategoricalSplitType
 >::SplitCheck()
 {
-  // If we have not seen enough samples to check, don't check.
-  if (numSamples % checkInterval != 0)
-    return 0;
-
   // Do nothing if we've already split.
   if (splitDimension != size_t(-1))
     return 0;
@@ -250,13 +328,20 @@ template<
     template<typename> class NumericSplitType,
     template<typename> class CategoricalSplitType
 >
+template<typename VecType>
 size_t HoeffdingTree<
     FitnessFunction,
     NumericSplitType,
     CategoricalSplitType
->::MajorityClass() const
+>::CalculateDirection(const VecType& point) const
 {
-  return majorityClass;
+  // Don't call this before the node is split...
+  if (datasetInfo->Type(splitDimension) == data::Datatype::numeric)
+    return numericSplit.CalculateDirection(point[splitDimension]);
+  else if (datasetInfo->Type(splitDimension) == data::Datatype::categorical)
+    return categoricalSplit.CalculateDirection(point[splitDimension]);
+  else
+    return 0; // Not sure what to do here...
 }
 
 template<
@@ -264,13 +349,24 @@ template<
     template<typename> class NumericSplitType,
     template<typename> class CategoricalSplitType
 >
-size_t& HoeffdingTree<
+template<typename VecType>
+size_t HoeffdingTree<
     FitnessFunction,
     NumericSplitType,
     CategoricalSplitType
->::MajorityClass()
+>::Classify(const VecType& point) const
 {
-  return majorityClass;
+  if (children.size() == 0)
+  {
+    // If we're a leaf (or being considered a leaf), classify based on what we
+    // know.
+    return majorityClass;
+  }
+  else
+  {
+    // Otherwise, pass to the right child and let them classify.
+    return children[CalculateDirection(point)].Classify(point);
+  }
 }
 
 template<
@@ -279,54 +375,65 @@ template<
     template<typename> class CategoricalSplitType
 >
 template<typename VecType>
-size_t HoeffdingTree<
+void HoeffdingTree<
     FitnessFunction,
     NumericSplitType,
     CategoricalSplitType
->::CalculateDirection(const VecType& point) const
+>::Classify(const VecType& point,
+            size_t& prediction,
+            double& probability) const
 {
-  // Don't call this before the node is split...
-  if (datasetInfo->Type(splitDimension) == data::Datatype::numeric)
-    return numericSplit.CalculateDirection(point[splitDimension]);
-  else if (datasetInfo->Type(splitDimension) == data::Datatype::categorical)
-    return categoricalSplit.CalculateDirection(point[splitDimension]);
+  if (children.size() == 0)
+  {
+    // We are a leaf, so classify accordingly.
+    prediction = majorityClass;
+    probability = majorityProbability;
+  }
   else
-    return 0; // Not sure what to do here...
+  {
+    // Pass to the right child and let them do the classification.
+    children[CalculateDirection(point)].Classify(point, prediction,
+        probability);
+  }
 }
 
+//! Batch classification.
 template<
     typename FitnessFunction,
     template<typename> class NumericSplitType,
     template<typename> class CategoricalSplitType
 >
-template<typename VecType>
-size_t HoeffdingTree<
+template<typename MatType>
+void HoeffdingTree<
     FitnessFunction,
     NumericSplitType,
     CategoricalSplitType
->::Classify(const VecType& /* point */) const
+>::Classify(const MatType& data, arma::Row<size_t>& predictions) const
 {
-  // We're a leaf (or being considered a leaf), so classify based on what we
-  // know.
-  return majorityClass;
+  predictions.set_size(data.n_cols);
+  for (size_t i = 0; i < data.n_cols; ++i)
+    predictions[i] = Classify(data.col(i));
 }
 
+//! Batch classification with probabilities.
 template<
     typename FitnessFunction,
     template<typename> class NumericSplitType,
     template<typename> class CategoricalSplitType
 >
-template<typename VecType>
+template<typename MatType>
 void HoeffdingTree<
     FitnessFunction,
     NumericSplitType,
     CategoricalSplitType
->::Classify(const VecType& /* point */,
-            size_t& prediction,
-            double& probability) const
+>::Classify(const MatType& data,
+            arma::Row<size_t>& predictions,
+            arma::rowvec& probabilities) const
 {
-  prediction = majorityClass;
-  probability = majorityProbability;
+  predictions.set_size(data.n_cols);
+  probabilities.set_size(data.n_cols);
+  for (size_t i = 0; i < data.n_cols; ++i)
+    Classify(data.col(i), predictions[i], probabilities[i]);
 }
 
 template<
@@ -334,12 +441,11 @@ template<
     template<typename> class NumericSplitType,
     template<typename> class CategoricalSplitType
 >
-template<typename StreamingDecisionTreeType>
 void HoeffdingTree<
     FitnessFunction,
     NumericSplitType,
     CategoricalSplitType
->::CreateChildren(std::vector<StreamingDecisionTreeType>& children)
+>::CreateChildren()
 {
   // Create the children.
   arma::Col<size_t> childMajorities;
@@ -359,8 +465,8 @@ void HoeffdingTree<
   // We already know what the splitDimension will be.
   for (size_t i = 0; i < childMajorities.n_elem; ++i)
   {
-    children.push_back(StreamingDecisionTreeType(*datasetInfo, numClasses,
-        successProbability, maxSamples, checkInterval, dimensionMappings);
+    children.push_back(HoeffdingTree(*datasetInfo, numClasses,
+        successProbability, maxSamples, checkInterval, dimensionMappings));
     children[i].MajorityClass() = childMajorities[i];
   }
 
@@ -454,12 +560,27 @@ void HoeffdingTree<
   }
   else
   {
-    // We have split, so we only need to save the split.
+    // We have split, so we only need to save the split and the children.
     if (datasetInfo->Type(splitDimension) == data::Datatype::categorical)
       ar & CreateNVP(categoricalSplit, "categoricalSplit");
     else
       ar & CreateNVP(numericSplit, "numericSplit");
 
+    // Serialize the children, because we have split.
+    size_t numChildren;
+    if (Archive::is_saving::value)
+      numChildren = children.size();
+    ar & CreateNVP(numChildren, "numChildren");
+    if (Archive::is_loading::value) // If needed, allocate space.
+      children.resize(numChildren, HoeffdingTree(data::DatasetInfo(0), 0));
+
+    for (size_t i = 0; i < numChildren; ++i)
+    {
+      std::ostringstream name;
+      name << "child" << i;
+      ar & data::CreateNVP(children[i], name.str());
+    }
+
     if (Archive::is_loading::value)
     {
       numericSplits.clear();



More information about the mlpack-git mailing list