[mlpack-git] master: Make sure training propagates through the tree. (7c9f5b6)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:42:42 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 7c9f5b68f9872519b425c983b6105a3f7204d9c8
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Sep 23 15:19:41 2015 -0400
Make sure training propagates through the tree.
This makes the HoeffdingTreeTest pass. I am not happy with storing the
class counts in HoeffdingSplit in addition to the sufficient statistics in
HoeffdingCategoricalSplit -- this is redundant and should probably be avoided.
>---------------------------------------------------------------
7c9f5b68f9872519b425c983b6105a3f7204d9c8
.../methods/hoeffding_trees/hoeffding_split.hpp | 1 +
.../hoeffding_trees/hoeffding_split_impl.hpp | 11 ++++++--
.../streaming_decision_tree_impl.hpp | 29 ++++++++++++++--------
3 files changed, 29 insertions(+), 12 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
index c16af0a..bc96870 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
@@ -55,6 +55,7 @@ class HoeffdingSplit
size_t numSamples;
size_t numClasses;
+ arma::Col<size_t> classCounts;
const data::DatasetInfo& datasetInfo;
double successProbability;
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
index 609971b..2387b51 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
@@ -23,6 +23,7 @@ HoeffdingSplit<
const double successProbability) :
numSamples(0),
numClasses(numClasses),
+ classCounts(arma::zeros<arma::Col<size_t>>(numClasses)),
datasetInfo(datasetInfo),
successProbability(successProbability),
splitDimension(size_t(-1)),
@@ -49,6 +50,12 @@ void HoeffdingSplit<
{
if (splitDimension == size_t(-1))
{
+ // Update majority counts.
+ classCounts(label)++;
+ arma::uword tmp;
+ classCounts.max(tmp);
+ majorityClass = size_t(tmp);
+
++numSamples;
size_t numericIndex = 0;
size_t categoricalIndex = 0;
@@ -194,12 +201,12 @@ void HoeffdingSplit<
if (datasetInfo.Type(splitDimension) == data::Datatype::numeric)
{
- numericSplits[numericSplitIndex + 1].CreateChildren(children, datasetInfo,
+ numericSplits[numericSplitIndex].CreateChildren(children, datasetInfo,
numericSplit);
}
else if (datasetInfo.Type(splitDimension) == data::Datatype::categorical)
{
- categoricalSplits[categoricalSplitIndex + 1].CreateChildren(children,
+ categoricalSplits[categoricalSplitIndex].CreateChildren(children,
datasetInfo, categoricalSplit);
}
}
diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
index 61e3f59..295c237 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
@@ -47,18 +47,27 @@ template<typename VecType>
void StreamingDecisionTree<SplitType, MatType>::Train(const VecType& data,
const size_t label)
{
- split.Train(data, label);
-
- const size_t numChildren = split.SplitCheck();
- if (numChildren > 0)
+ if (children.size() == 0)
{
- // We need to add a bunch of children.
- // Delete children, if we have them.
- if (children.size() > 0)
- children.clear();
+ split.Train(data, label);
+
+ const size_t numChildren = split.SplitCheck();
+ if (numChildren > 0)
+ {
+ // We need to add a bunch of children.
+ // Delete children, if we have them.
+ if (children.size() > 0)
+ children.clear();
- // The split knows how to add the children.
- split.CreateChildren(children);
+ // The split knows how to add the children.
+ split.CreateChildren(children);
+ }
+ }
+ else
+ {
+ // We've already split this node. But we need to train the child nodes.
+ size_t direction = split.CalculateDirection(data);
+ children[direction].Train(data, label);
}
}
More information about the mlpack-git mailing list