[mlpack-git] master: Add better test; fix bug. (1687a1c)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:44:12 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 1687a1c0aee16a2e1e885daf9be0fdc277104aba
Author: ryan <ryan at ratml.org>
Date:   Thu Oct 8 15:09:23 2015 -0400

    Add better test; fix bug.


>---------------------------------------------------------------

1687a1c0aee16a2e1e885daf9be0fdc277104aba
 .../hoeffding_trees/binary_numeric_split_impl.hpp  |  3 +-
 src/mlpack/tests/hoeffding_tree_test.cpp           | 46 ++++++++++++++++++++++
 2 files changed, 47 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
index 1c45829..776aa8d 100644
--- a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
@@ -59,7 +59,6 @@ double BinaryNumericSplit<FitnessFunction, ObservationType>::
     ++counts((*it).second, 0);
 
     // TODO: skip ahead if the next value is the same.
-
     double value = FitnessFunction::Evaluate(counts);
     if (value > bestValue)
     {
@@ -88,7 +87,7 @@ void BinaryNumericSplit<FitnessFunction, ObservationType>::Split(
   counts.col(1) = classCounts;
 
   for (typename std::multimap<ObservationType, size_t>::const_iterator it =
-      sortedElements.begin(); (*it).second <= bestSplit; ++it)
+      sortedElements.begin(); (*it).first <= bestSplit; ++it)
   {
     // Move the point to the correct side of the split.
     --counts((*it).second, 1);
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 3267feb..6d32826 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -544,6 +544,52 @@ BOOST_AUTO_TEST_CASE(BinaryNumericSplitSimpleSplitTest)
     // impurity for the children is 0.
     BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(), 0.5, 1e-5);
   }
+
+  // Now, when we ask it to split, ensure that the split value is reasonable.
+  arma::Col<size_t> childMajorities;
+  NumericSplitInfo<> splitInfo;
+  split.Split(childMajorities, splitInfo);
+
+  BOOST_REQUIRE_EQUAL(childMajorities[0], 0);
+  BOOST_REQUIRE_EQUAL(childMajorities[1], 1);
+  BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(0.5), 0);
+  BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(1.5), 1);
+  BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(0.0), 0);
+  BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(-1.0), 0);
+  BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(0.9), 0);
+  BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(1.1), 1);
+}
+
+/**
+ * Create a BinaryNumericSplit object, feed it samples in the same way as the
+ * previous test, but with four classes, and check that it splits two ways.
+ */
+BOOST_AUTO_TEST_CASE(BinaryNumericSplitSimpleFourClassSplitTest)
+{
+  BinaryNumericSplit<GiniImpurity> split(4); // 4 classes.
+
+  // Feed it samples.
+  for (size_t i = 0; i < 250; ++i)
+  {
+    split.Train(mlpack::math::Random(), 0);
+    split.Train(mlpack::math::Random() + 2.0, 1);
+    split.Train(mlpack::math::Random() - 1.0, 2);
+    split.Train(mlpack::math::Random() + 1.0, 3);
+
+    // Same as the previous test, but with four classes: the parent's Gini
+    // impurity is 4 * (0.25 * 0.75) = 0.75.  A split with one perfect child
+    // leaves 3/4 * (3 * (1/3 * 2/3)) = 0.5, so the gain is 0.75 - 0.5 = 0.25.
+    BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(), 0.25, 1e-5);
+  }
+
+  // Now, when we ask it to split, ensure that the split value is reasonable.
+  arma::Col<size_t> childMajorities;
+  NumericSplitInfo<> splitInfo;
+  split.Split(childMajorities, splitInfo);
+
+  // Every class boundary yields the same gain, so any split point is fine.
+  // But it must be a binary split: exactly two children.
+  BOOST_REQUIRE_EQUAL(childMajorities.n_elem, 2);
 }
 
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list