[mlpack-git] master: Add better test; fix bug. (1687a1c)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:44:12 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 1687a1c0aee16a2e1e885daf9be0fdc277104aba
Author: ryan <ryan at ratml.org>
Date: Thu Oct 8 15:09:23 2015 -0400
Add better test; fix bug.
>---------------------------------------------------------------
1687a1c0aee16a2e1e885daf9be0fdc277104aba
.../hoeffding_trees/binary_numeric_split_impl.hpp | 3 +-
src/mlpack/tests/hoeffding_tree_test.cpp | 46 ++++++++++++++++++++++
2 files changed, 47 insertions(+), 2 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
index 1c45829..776aa8d 100644
--- a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
@@ -59,7 +59,6 @@ double BinaryNumericSplit<FitnessFunction, ObservationType>::
++counts((*it).second, 0);
// TODO: skip ahead if the next value is the same.
-
double value = FitnessFunction::Evaluate(counts);
if (value > bestValue)
{
@@ -88,7 +87,7 @@ void BinaryNumericSplit<FitnessFunction, ObservationType>::Split(
counts.col(1) = classCounts;
for (typename std::multimap<ObservationType, size_t>::const_iterator it =
- sortedElements.begin(); (*it).second <= bestSplit; ++it)
+ sortedElements.begin(); (*it).first <= bestSplit; ++it)
{
// Move the point to the correct side of the split.
--counts((*it).second, 1);
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 3267feb..6d32826 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -544,6 +544,52 @@ BOOST_AUTO_TEST_CASE(BinaryNumericSplitSimpleSplitTest)
// impurity for the children is 0.
BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(), 0.5, 1e-5);
}
+
+ // Now, when we ask it to split, ensure that the split value is reasonable.
+ arma::Col<size_t> childMajorities;
+ NumericSplitInfo<> splitInfo;
+ split.Split(childMajorities, splitInfo);
+
+ BOOST_REQUIRE_EQUAL(childMajorities[0], 0);
+ BOOST_REQUIRE_EQUAL(childMajorities[1], 1);
+ BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(0.5), 0);
+ BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(1.5), 1);
+ BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(0.0), 0);
+ BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(-1.0), 0);
+ BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(0.9), 0);
+ BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(1.1), 1);
+}
+
+/**
+ * Create a BinaryNumericSplit object and feed it samples in the same way as
+ * the previous test, but with four classes, each shifted by its own offset.
+ */
+BOOST_AUTO_TEST_CASE(BinaryNumericSplitSimpleFourClassSplitTest)
+{
+ BinaryNumericSplit<GiniImpurity> split(4); // 4 classes.
+
+ // Feed it samples.
+ for (size_t i = 0; i < 250; ++i)
+ {
+ split.Train(mlpack::math::Random(), 0);
+ split.Train(mlpack::math::Random() + 2.0, 1);
+ split.Train(mlpack::math::Random() - 1.0, 2);
+ split.Train(mlpack::math::Random() + 1.0, 3);
+
+ // The same as the previous test, but with four classes: 4 * (0.25 * 0.75) =
+ // 0.75. One split point at best leaves one pure child (1/4 of the points)
+ // and one with impurity 2/3 (3/4 of them): gain 0.75 - 3/4 * 2/3 = 0.25.
+ BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(), 0.25, 1e-5);
+ }
+
+ // Now, when we ask it to split, ensure that the split value is reasonable.
+ arma::Col<size_t> childMajorities;
+ NumericSplitInfo<> splitInfo;
+ split.Split(childMajorities, splitInfo);
+
+ // Several split points tie for the best gain, so we don't check where it
+ // splits; but a binary split must always produce exactly two children.
+ BOOST_REQUIRE_EQUAL(childMajorities.n_elem, 2);
}
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list