[mlpack-git] master: A more complex test for HoeffdingNumericSplit. (ebc6c94)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:42:50 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit ebc6c9451877243a892e304d9cd8b27a53767ff2
Author: ryan <ryan at ratml.org>
Date:   Thu Sep 24 17:29:08 2015 -0400

    A more complex test for HoeffdingNumericSplit.


>---------------------------------------------------------------

ebc6c9451877243a892e304d9cd8b27a53767ff2
 .../hoeffding_numeric_split_impl.hpp               |  9 +++--
 .../methods/hoeffding_trees/numeric_split_info.hpp |  4 +-
 src/mlpack/tests/hoeffding_tree_test.cpp           | 46 ++++++++++++++++++++++
 3 files changed, 54 insertions(+), 5 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
index 7d6de50..186de09 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
@@ -19,7 +19,6 @@ HoeffdingNumericSplit<FitnessFunction, ObservationType>::HoeffdingNumericSplit(
     const size_t observationsBeforeBinning) :
     observations(observationsBeforeBinning - 1),
     labels(observationsBeforeBinning - 1),
-    splitPoints(bins),
     bins(bins),
     observationsBeforeBinning(observationsBeforeBinning),
     samplesSeen(0),
@@ -54,8 +53,12 @@ void HoeffdingNumericSplit<FitnessFunction, ObservationType>::Train(
         max = observations[i];
     }
 
-    // Now split these.
-    splitPoints = arma::linspace<arma::Col<ObservationType>>(min, max, bins);
+    // Now split these.  We can't use linspace, because we don't want to include
+    // the endpoints.
+    splitPoints.resize(bins - 1);
+    const ObservationType binWidth = (max - min) / bins;
+    for (size_t i = 0; i < bins - 1; ++i)
+      splitPoints[i] = min + (i + 1) * binWidth;
     ++samplesSeen;
 
     // Now, add all of the points we've seen to the sufficient statistics.
diff --git a/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp b/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
index c7f2c12..ce58e58 100644
--- a/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
+++ b/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
@@ -25,14 +25,14 @@ class NumericSplitInfo
   {
     // What bin does the point fall into?
     size_t bin = 0;
-    while (value > splitPoints[bin] && bin < splitPoints.n_elem - 1)
+    while (value > splitPoints[bin] && bin < splitPoints.n_elem)
       ++bin;
 
     return bin;
   }
 
  private:
-  arma::Col<size_t> splitPoints;
+  arma::Col<ObservationType> splitPoints;
 };
 
 } // namespace tree
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 5cbb2d0..eb9a361 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -478,4 +478,50 @@ BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitPreBinningMajorityClassTest)
   }
 }
 
+/**
+ * Train a 2-bin HoeffdingNumericSplit on a bimodal feature (class 0 drawn as
+ * Random() + 0.3, class 1 as -Random() - 0.3), then check its majority class,
+ * fitness value, created children, and the NumericSplitInfo directions.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitBimodalTest)
+{
+  // Constructor arguments: 2 classes, 2 bins, 200 samples before binning.
+  HoeffdingNumericSplit<GiniImpurity> split(2, 2, 200);
+
+  for (size_t i = 0; i < 100; ++i)
+  {
+    split.Train(mlpack::math::Random() + 0.3, 0);
+    split.Train(-mlpack::math::Random() - 0.3, 1);
+  }
+
+  // One extra class-1 sample (101 vs. 100) makes class 1 the majority.
+  split.Train(-mlpack::math::Random() - 0.3, 1);
+  BOOST_REQUIRE_EQUAL(split.MajorityClass(), 1);
+
+  // Two extra class-0 samples (102 vs. 101) return the majority to class 0.
+  split.Train(mlpack::math::Random() + 0.3, 0);
+  split.Train(mlpack::math::Random() + 0.3, 0);
+  BOOST_REQUIRE_EQUAL(split.MajorityClass(), 0);
+
+  // Binning should be complete now, so the fitness should be (0.5 * (1 - 0.5))
+  // * 2 = 0.50, since impurity will be 0 in the two created children.
+  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(), 0.50, 0.01);
+
+  // If children are created, there must be exactly two (matching the two
+  // bins), and the returned split info must route each mode to its own bin.
+  std::vector<StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
+      HoeffdingNumericSplit<double>>>> children;
+  data::DatasetInfo datasetInfo; // All numeric features -- no change necessary.
+  NumericSplitInfo<> info;
+  split.CreateChildren(children, datasetInfo, 1, info);
+  BOOST_REQUIRE_EQUAL(children.size(), 2);
+
+  // Samples drawn from opposite modes must never get the same direction.
+  for (size_t i = 0; i < 10; ++i)
+  {
+    BOOST_REQUIRE_NE(info.CalculateDirection(mlpack::math::Random() + 0.3),
+                     info.CalculateDirection(-mlpack::math::Random() - 0.3));
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list