[mlpack-git] master: A more complex test for HoeffdingNumericSplit. (ebc6c94)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:42:50 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit ebc6c9451877243a892e304d9cd8b27a53767ff2
Author: ryan <ryan at ratml.org>
Date: Thu Sep 24 17:29:08 2015 -0400
A more complex test for HoeffdingNumericSplit.
>---------------------------------------------------------------
ebc6c9451877243a892e304d9cd8b27a53767ff2
.../hoeffding_numeric_split_impl.hpp | 9 +++--
.../methods/hoeffding_trees/numeric_split_info.hpp | 4 +-
src/mlpack/tests/hoeffding_tree_test.cpp | 46 ++++++++++++++++++++++
3 files changed, 54 insertions(+), 5 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
index 7d6de50..186de09 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
@@ -19,7 +19,6 @@ HoeffdingNumericSplit<FitnessFunction, ObservationType>::HoeffdingNumericSplit(
const size_t observationsBeforeBinning) :
observations(observationsBeforeBinning - 1),
labels(observationsBeforeBinning - 1),
- splitPoints(bins),
bins(bins),
observationsBeforeBinning(observationsBeforeBinning),
samplesSeen(0),
@@ -54,8 +53,12 @@ void HoeffdingNumericSplit<FitnessFunction, ObservationType>::Train(
max = observations[i];
}
- // Now split these.
- splitPoints = arma::linspace<arma::Col<ObservationType>>(min, max, bins);
+ // Now split these. We can't use linspace, because we don't want to include
+ // the endpoints.
+ splitPoints.resize(bins - 1);
+ const ObservationType binWidth = (max - min) / bins;
+ for (size_t i = 0; i < bins - 1; ++i)
+ splitPoints[i] = min + (i + 1) * binWidth;
++samplesSeen;
// Now, add all of the points we've seen to the sufficient statistics.
diff --git a/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp b/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
index c7f2c12..ce58e58 100644
--- a/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
+++ b/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
@@ -25,14 +25,14 @@ class NumericSplitInfo
{
// What bin does the point fall into?
size_t bin = 0;
- while (value > splitPoints[bin] && bin < splitPoints.n_elem - 1)
+ while (value > splitPoints[bin] && bin < splitPoints.n_elem)
++bin;
return bin;
}
private:
- arma::Col<size_t> splitPoints;
+ arma::Col<ObservationType> splitPoints;
};
} // namespace tree
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 5cbb2d0..eb9a361 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -478,4 +478,50 @@ BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitPreBinningMajorityClassTest)
}
}
+/**
+ * Use a numeric feature that is bimodal (with a margin), and make sure that the
+ * HoeffdingNumericSplit bins it reasonably into two bins and returns sensible
+ * Gini impurity numbers.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitBimodalTest)
+{
+ // 2 classes, 2 bins, 200 samples before binning.
+ HoeffdingNumericSplit<GiniImpurity> split(2, 2, 200);
+
+ for (size_t i = 0; i < 100; ++i)
+ {
+ split.Train(mlpack::math::Random() + 0.3, 0);
+ split.Train(-mlpack::math::Random() - 0.3, 1);
+ }
+
+ // Push the majority class to 1.
+ split.Train(-mlpack::math::Random() - 0.3, 1);
+ BOOST_REQUIRE_EQUAL(split.MajorityClass(), 1);
+
+ // Push the majority class back to 0.
+ split.Train(mlpack::math::Random() + 0.3, 0);
+ split.Train(mlpack::math::Random() + 0.3, 0);
+ BOOST_REQUIRE_EQUAL(split.MajorityClass(), 0);
+
+ // Now the binning should be complete, and so the impurity should be
+ // (0.5 * (1 - 0.5)) * 2 = 0.50 (it will be 0 in the two created children).
+ BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(), 0.50, 0.01);
+
+ // Make sure that if we do create children, that the correct number of
+ // children is created, and that the bins end up in the right place.
+ std::vector<StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
+ HoeffdingNumericSplit<double>>>> children;
+ data::DatasetInfo datasetInfo; // All numeric features -- no change necessary.
+ NumericSplitInfo<> info;
+ split.CreateChildren(children, datasetInfo, 1, info);
+ BOOST_REQUIRE_EQUAL(children.size(), 2);
+
+ // Now check the split info.
+ for (size_t i = 0; i < 10; ++i)
+ {
+ BOOST_REQUIRE_NE(info.CalculateDirection(mlpack::math::Random() + 0.3),
+ info.CalculateDirection(-mlpack::math::Random() - 0.3));
+ }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list