[mlpack-git] master: Add some more tests for HoeffdingSplit. (a5ddd52)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:42:30 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit a5ddd523e8006d4246808a9661c2ea2822abfd8c
Author: ryan <ryan at ratml.org>
Date:   Tue Sep 22 16:30:47 2015 -0400

    Add some more tests for HoeffdingSplit.


>---------------------------------------------------------------

a5ddd523e8006d4246808a9661c2ea2822abfd8c
 .../methods/hoeffding_trees/hoeffding_split.hpp    |  3 +
 src/mlpack/tests/hoeffding_tree_test.cpp           | 93 ++++++++++++++++++++++
 2 files changed, 96 insertions(+)

diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
index 5ca5f33..c16af0a 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
@@ -34,6 +34,9 @@ class HoeffdingSplit
   // 0 if split should not happen; number of splits otherwise.
   size_t SplitCheck();
 
+  //! Get the dimension used for the split (size_t(-1) when there is no split).
+  size_t SplitDimension() const { return splitDimension; }
+
   // Return index that we should go towards.
   template<typename VecType>
   size_t CalculateDirection(const VecType& point) const;
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 7c3e6ab..f24c2dc 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -254,4 +254,97 @@ BOOST_AUTO_TEST_CASE(HoeffdingSplitNoSplitTest)
   }
 }
 
+/**
+ * Feed the HoeffdingSplit many points from two classes that one feature
+ * separates perfectly; it should then clearly recommend a split on it.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingSplitEasySplitTest)
+{
+  // A two-dimensional categorical dataset.  Dimension 0 has two categories:
+  // category 0 only ever receives class 0 points and category 1 only ever
+  // receives class 1 points.  Dimension 1 has a single mapped category that
+  // every point takes, so it carries no information about the class.
+  data::DatasetInfo info;
+  info.MapString("cat0", 0);
+  info.MapString("cat1", 0);
+  info.MapString("cat0", 1);
+
+  HoeffdingSplit<> split(2, 2, info, 0.95);
+
+  // Feed 500 samples of each class, alternating between the two.
+  for (size_t i = 0; i < 500; ++i)
+  {
+    split.Train(arma::Col<size_t>("0 0"), 0);
+    split.Train(arma::Col<size_t>("1 0"), 1);
+  }
+
+  // SplitCheck() should now report 2, and the chosen dimension should be 0.
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), 2);
+  BOOST_REQUIRE_EQUAL(split.SplitDimension(), 0);
+}
+
+/**
+ * With a required success probability of 1, a split must never be suggested.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingSplitProbability1SplitTest)
+{
+  // A two-dimensional categorical dataset.  Dimension 0 has two categories:
+  // category 0 only ever receives class 0 points and category 1 only ever
+  // receives class 1 points.  Dimension 1 has a single mapped category that
+  // every point takes, so it carries no information about the class.
+  data::DatasetInfo info;
+  info.MapString("cat0", 0);
+  info.MapString("cat1", 0);
+  info.MapString("cat0", 1);
+
+  HoeffdingSplit<> split(2, 2, info, 1.0);
+
+  // Feed 5000 samples of each class, alternating between the two.
+  for (size_t i = 0; i < 5000; ++i)
+  {
+    split.Train(arma::Col<size_t>("0 0"), 0);
+    split.Train(arma::Col<size_t>("1 0"), 1);
+  }
+
+  // With success probability 1, expect no split and no split dimension.
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), 0);
+  BOOST_REQUIRE_EQUAL(split.SplitDimension(), size_t(-1));
+}
+
+/**
+ * A slightly harder splitting problem with two features: dimension 1 always
+ * equals the class, while dimension 0 equals it only when math::Random() <= 0.9
+ * (about 10% error, assuming uniform [0, 1) draws).  Expect a split on dim 1.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingSplitAlmostPerfectSplit)
+{
+  // Two dimensions, each with two mapped categories.
+  data::DatasetInfo info;
+  info.MapString("cat0", 0);
+  info.MapString("cat1", 0);
+  info.MapString("cat0", 1);
+  info.MapString("cat1", 1);
+
+  HoeffdingSplit<> split(2, 2, info, 0.95);
+
+  // Feed 500 samples of each class; only dimension 0 is ever flipped.
+  for (size_t i = 0; i < 500; ++i)
+  {
+    if (math::Random() <= 0.9)
+      split.Train(arma::Col<size_t>("0 0"), 0);
+    else
+      split.Train(arma::Col<size_t>("1 0"), 0);
+
+    if (math::Random() <= 0.9)
+      split.Train(arma::Col<size_t>("1 1"), 1);
+    else
+      split.Train(arma::Col<size_t>("0 1"), 1);
+  }
+
+  // A split should be suggested: SplitCheck() must report 2.
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), 2);
+  // The split must use the noise-free dimension (dimension 1).
+  BOOST_REQUIRE_EQUAL(split.SplitDimension(), 1);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list