[mlpack-git] master: Make sure training propagates through the tree. (7c9f5b6)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:42:42 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 7c9f5b68f9872519b425c983b6105a3f7204d9c8
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Sep 23 15:19:41 2015 -0400

    Make sure training propagates through the tree.
    
    This makes HoeffdingTreeTest pass.  I am not happy with storing the
    class counts in HoeffdingSplit in addition to the sufficient statistics in
    HoeffdingCategoricalSplit -- this is redundant and should probably be avoided.


>---------------------------------------------------------------

7c9f5b68f9872519b425c983b6105a3f7204d9c8
 .../methods/hoeffding_trees/hoeffding_split.hpp    |  1 +
 .../hoeffding_trees/hoeffding_split_impl.hpp       | 11 ++++++--
 .../streaming_decision_tree_impl.hpp               | 29 ++++++++++++++--------
 3 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
index c16af0a..bc96870 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
@@ -55,6 +55,7 @@ class HoeffdingSplit
 
   size_t numSamples;
   size_t numClasses;
+  arma::Col<size_t> classCounts;
   const data::DatasetInfo& datasetInfo;
   double successProbability;
 
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
index 609971b..2387b51 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
@@ -23,6 +23,7 @@ HoeffdingSplit<
                   const double successProbability) :
     numSamples(0),
     numClasses(numClasses),
+    classCounts(arma::zeros<arma::Col<size_t>>(numClasses)),
     datasetInfo(datasetInfo),
     successProbability(successProbability),
     splitDimension(size_t(-1)),
@@ -49,6 +50,12 @@ void HoeffdingSplit<
 {
   if (splitDimension == size_t(-1))
   {
+    // Update majority counts.
+    classCounts(label)++;
+    arma::uword tmp;
+    classCounts.max(tmp);
+    majorityClass = size_t(tmp);
+
     ++numSamples;
     size_t numericIndex = 0;
     size_t categoricalIndex = 0;
@@ -194,12 +201,12 @@ void HoeffdingSplit<
 
   if (datasetInfo.Type(splitDimension) == data::Datatype::numeric)
   {
-    numericSplits[numericSplitIndex + 1].CreateChildren(children, datasetInfo,
+    numericSplits[numericSplitIndex].CreateChildren(children, datasetInfo,
         numericSplit);
   }
   else if (datasetInfo.Type(splitDimension) == data::Datatype::categorical)
   {
-    categoricalSplits[categoricalSplitIndex + 1].CreateChildren(children,
+    categoricalSplits[categoricalSplitIndex].CreateChildren(children,
         datasetInfo, categoricalSplit);
   }
 }
diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
index 61e3f59..295c237 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
@@ -47,18 +47,27 @@ template<typename VecType>
 void StreamingDecisionTree<SplitType, MatType>::Train(const VecType& data,
                                                       const size_t label)
 {
-  split.Train(data, label);
-
-  const size_t numChildren = split.SplitCheck();
-  if (numChildren > 0)
+  if (children.size() == 0)
   {
-    // We need to add a bunch of children.
-    // Delete children, if we have them.
-    if (children.size() > 0)
-      children.clear();
+    split.Train(data, label);
+
+    const size_t numChildren = split.SplitCheck();
+    if (numChildren > 0)
+    {
+      // We need to add a bunch of children.
+      // Delete children, if we have them.
+      if (children.size() > 0)
+        children.clear();
 
-    // The split knows how to add the children.
-    split.CreateChildren(children);
+      // The split knows how to add the children.
+      split.CreateChildren(children);
+    }
+  }
+  else
+  {
+    // We've already split this node.  But we need to train the child nodes.
+    size_t direction = split.CalculateDirection(data);
+    children[direction].Train(data, label);
   }
 }
 



More information about the mlpack-git mailing list