[mlpack-git] master: Clarify where splitting actually occurs; fix bugs. (9a4fb1d)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:46:01 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 9a4fb1d85d8997487f5dd8c94cdae2215d1e82e8
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Nov 5 14:34:18 2015 -0800
Clarify where splitting actually occurs; fix bugs.
>---------------------------------------------------------------
9a4fb1d85d8997487f5dd8c94cdae2215d1e82e8
.../hoeffding_trees/binary_numeric_split.hpp | 4 +-
.../hoeffding_trees/binary_numeric_split_impl.hpp | 45 +++++++++++++++-------
2 files changed, 34 insertions(+), 15 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
index 07f8d7b..f83c95e 100644
--- a/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
@@ -8,6 +8,8 @@
#ifndef __MLPACK_METHODS_HOEFFDING_SPLIT_BINARY_NUMERIC_SPLIT_HPP
#define __MLPACK_METHODS_HOEFFDING_SPLIT_BINARY_NUMERIC_SPLIT_HPP
+#include "binary_numeric_split_info.hpp"
+
namespace mlpack {
namespace tree {
@@ -41,7 +43,7 @@ class BinaryNumericSplit
{
public:
//! The splitting information required by the BinaryNumericSplit.
- typedef NumericSplitInfo<ObservationType> SplitInfo;
+ typedef BinaryNumericSplitInfo<ObservationType> SplitInfo;
/**
* Create the BinaryNumericSplit object with the given number of classes.
diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
index f3e587d..468bfdd 100644
--- a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
@@ -51,20 +51,29 @@ double BinaryNumericSplit<FitnessFunction, ObservationType>::
double bestValue = FitnessFunction::Evaluate(counts);
+ // Initialize to the first observation, so we don't calculate gain on the
+ // first iteration (it will be 0).
+ ObservationType lastObservation = (*sortedElements.begin()).first;
for (typename std::multimap<ObservationType, size_t>::const_iterator it =
sortedElements.begin(); it != sortedElements.end(); ++it)
{
+ // If this value is the same as the last, or if this is the first value,
+ // don't calculate the gain.
+ if ((*it).first != lastObservation)
+ {
+ lastObservation = (*it).first;
+
+ const double value = FitnessFunction::Evaluate(counts);
+ if (value > bestValue)
+ {
+ bestValue = value;
+ bestSplit = (*it).first;
+ }
+ }
+
// Move the point to the right side of the split.
--counts((*it).second, 1);
++counts((*it).second, 0);
-
- // TODO: skip ahead if the next value is the same.
- const double value = FitnessFunction::Evaluate(counts);
- if (value > bestValue)
- {
- bestValue = value;
- bestSplit = (*it).first;
- }
}
isAccurate = true;
@@ -86,12 +95,22 @@ void BinaryNumericSplit<FitnessFunction, ObservationType>::Split(
counts.col(0).zeros();
counts.col(1) = classCounts;
+ double min = DBL_MAX;
+ double max = -DBL_MAX;
for (typename std::multimap<ObservationType, size_t>::const_iterator it =
- sortedElements.begin(); (*it).first < bestSplit; ++it)
+ sortedElements.begin();// (*it).first < bestSplit; ++it)
+ it != sortedElements.end(); ++it)
{
// Move the point to the correct side of the split.
- --counts((*it).second, 1);
- ++counts((*it).second, 0);
+ if ((*it).first < bestSplit)
+ {
+ --counts((*it).second, 1);
+ ++counts((*it).second, 0);
+ }
+ if ((*it).first < min)
+ min = (*it).first;
+ if ((*it).first > max)
+ max = (*it).first;
}
// Calculate the majority classes of the children.
@@ -102,9 +121,7 @@ void BinaryNumericSplit<FitnessFunction, ObservationType>::Split(
childMajorities[1] = size_t(maxIndex);
// Create the according SplitInfo object.
- arma::vec splitPoints(1);
- splitPoints[0] = double(bestSplit);
- splitInfo = SplitInfo(splitPoints);
+ splitInfo = SplitInfo(bestSplit);
}
template<typename FitnessFunction, typename ObservationType>
More information about the mlpack-git
mailing list