[mlpack-git] master: Add some more tests for HoeffdingSplit. (a5ddd52)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:42:30 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit a5ddd523e8006d4246808a9661c2ea2822abfd8c
Author: ryan <ryan at ratml.org>
Date: Tue Sep 22 16:30:47 2015 -0400
Add some more tests for HoeffdingSplit.
>---------------------------------------------------------------
a5ddd523e8006d4246808a9661c2ea2822abfd8c
.../methods/hoeffding_trees/hoeffding_split.hpp | 3 +
src/mlpack/tests/hoeffding_tree_test.cpp | 93 ++++++++++++++++++++++
2 files changed, 96 insertions(+)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
index 5ca5f33..c16af0a 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
@@ -34,6 +34,9 @@ class HoeffdingSplit
// 0 if split should not happen; number of splits otherwise.
size_t SplitCheck();
+ //! Get the splitting dimension; this is size_t(-1) if there is no split.
+ size_t SplitDimension() const { return splitDimension; }
+
// Return index that we should go towards.
template<typename VecType>
size_t CalculateDirection(const VecType& point) const;
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 7c3e6ab..f24c2dc 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -254,4 +254,97 @@ BOOST_AUTO_TEST_CASE(HoeffdingSplitNoSplitTest)
}
}
+/**
+ * If we feed the HoeffdingSplit many points from two classes that one
+ * dimension separates perfectly, it should clearly suggest that we split.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingSplitEasySplitTest)
+{
+ // A two-dimensional categorical dataset. The first dimension has two
+ // categories: category 0 only receives points with class 0, and category 1
+ // only receives points with class 1. The second dimension has one category,
+ // shared by every point, so it is useless for splitting.
+ data::DatasetInfo info;
+ info.MapString("cat0", 0);
+ info.MapString("cat1", 0);
+ info.MapString("cat0", 1);
+
+ HoeffdingSplit<> split(2, 2, info, 0.95);
+
+ // Feed 500 samples from each class, alternating between the two classes.
+ for (size_t i = 0; i < 500; ++i)
+ {
+ split.Train(arma::Col<size_t>("0 0"), 0);
+ split.Train(arma::Col<size_t>("1 0"), 1);
+ }
+
+ // Now it should be ready to split, and the split should be on dimension 0.
+ BOOST_REQUIRE_EQUAL(split.SplitCheck(), 2);
+ BOOST_REQUIRE_EQUAL(split.SplitDimension(), 0);
+}
+
+/**
+ * With a success probability of 1, the HoeffdingSplit should never split.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingSplitProbability1SplitTest)
+{
+ // A two-dimensional categorical dataset. The first dimension has two
+ // categories: category 0 only receives points with class 0, and category 1
+ // only receives points with class 1. The second dimension has one category,
+ // shared by every point, so it is useless for splitting.
+ data::DatasetInfo info;
+ info.MapString("cat0", 0);
+ info.MapString("cat1", 0);
+ info.MapString("cat0", 1);
+
+ HoeffdingSplit<> split(2, 2, info, 1.0);
+
+ // Feed 5000 samples from each class; the first dimension separates them.
+ for (size_t i = 0; i < 5000; ++i)
+ {
+ split.Train(arma::Col<size_t>("0 0"), 0);
+ split.Train(arma::Col<size_t>("1 0"), 1);
+ }
+
+ // Because the success probability is 1, no split is reported or chosen.
+ BOOST_REQUIRE_EQUAL(split.SplitCheck(), 0);
+ BOOST_REQUIRE_EQUAL(split.SplitDimension(), size_t(-1));
+}
+
+/**
+ * A slightly harder splitting problem with two features: dimension 1 gives
+ * perfect classification and dimension 0 gives almost perfect classification
+ * (about 10% error). The split should happen, and on the perfect dimension.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingSplitAlmostPerfectSplit)
+{
+ // Two dimensions, each with two categories.
+ data::DatasetInfo info;
+ info.MapString("cat0", 0);
+ info.MapString("cat1", 0);
+ info.MapString("cat0", 1);
+ info.MapString("cat1", 1);
+
+ HoeffdingSplit<> split(2, 2, info, 0.95);
+
+ // Each iteration trains one class 0 point and one class 1 point.
+ for (size_t i = 0; i < 500; ++i)
+ {
+ if (math::Random() <= 0.9)
+ split.Train(arma::Col<size_t>("0 0"), 0);
+ else
+ split.Train(arma::Col<size_t>("1 0"), 0);
+
+ if (math::Random() <= 0.9)
+ split.Train(arma::Col<size_t>("1 1"), 1);
+ else
+ split.Train(arma::Col<size_t>("0 1"), 1);
+ }
+
+ // Ensure that a split is suggested.
+ BOOST_REQUIRE_EQUAL(split.SplitCheck(), 2);
+ // Make sure the split is on the perfectly informative dimension (1).
+ BOOST_REQUIRE_EQUAL(split.SplitDimension(), 1);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list