[mlpack-git] master: Fix compilation errors; add a first test. (06eff7a)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:44:10 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 06eff7a081d51b91e4a02ffd9176c42499ab614e
Author: ryan <ryan at ratml.org>
Date: Thu Oct 8 11:23:52 2015 -0400
Fix compilation errors; add a first test.
>---------------------------------------------------------------
06eff7a081d51b91e4a02ffd9176c42499ab614e
src/mlpack/methods/hoeffding_trees/CMakeLists.txt | 2 ++
.../hoeffding_trees/binary_numeric_split.hpp | 4 ++--
.../hoeffding_trees/binary_numeric_split_impl.hpp | 10 +++++-----
src/mlpack/tests/hoeffding_tree_test.cpp | 23 ++++++++++++++++++++++
4 files changed, 32 insertions(+), 7 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/CMakeLists.txt b/src/mlpack/methods/hoeffding_trees/CMakeLists.txt
index 957e609..9b79011 100644
--- a/src/mlpack/methods/hoeffding_trees/CMakeLists.txt
+++ b/src/mlpack/methods/hoeffding_trees/CMakeLists.txt
@@ -1,6 +1,8 @@
# Define the files we need to compile
# Anything not in this list will not be compiled into MLPACK.
set(SOURCES
+ binary_numeric_split.hpp
+ binary_numeric_split_impl.hpp
categorical_split_info.hpp
gini_impurity.hpp
hoeffding_categorical_split.hpp
diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
index d2d7662..5873594 100644
--- a/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
@@ -43,9 +43,9 @@ class BinaryNumericSplit
void Train(ObservationType value, const size_t label);
- double EvaluateFitnessFunction() const;
+ double EvaluateFitnessFunction();
- void Split(arma::Col<size_t>& childMajorities, SplitInfo& splitInfo) const;
+ void Split(arma::Col<size_t>& childMajorities, SplitInfo& splitInfo);
size_t MajorityClass() const;
double MajorityProbability() const;
diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
index 15da9b2..1c45829 100644
--- a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
@@ -39,7 +39,7 @@ void BinaryNumericSplit<FitnessFunction, ObservationType>::Train(
template<typename FitnessFunction, typename ObservationType>
double BinaryNumericSplit<FitnessFunction, ObservationType>::
- EvaluateFitnessFunction() const
+ EvaluateFitnessFunction()
{
// Unfortunately, we have to iterate over the map.
bestSplit = std::numeric_limits<ObservationType>::min();
@@ -51,7 +51,7 @@ double BinaryNumericSplit<FitnessFunction, ObservationType>::
double bestValue = FitnessFunction::Evaluate(counts);
- for (std::multimap<ObservationType, size_t>::const_iterator it =
+ for (typename std::multimap<ObservationType, size_t>::const_iterator it =
sortedElements.begin(); it != sortedElements.end(); ++it)
{
// Move the point to the right side of the split.
@@ -75,7 +75,7 @@ double BinaryNumericSplit<FitnessFunction, ObservationType>::
template<typename FitnessFunction, typename ObservationType>
void BinaryNumericSplit<FitnessFunction, ObservationType>::Split(
arma::Col<size_t>& childMajorities,
- SplitInfo& splitInfo) const
+ SplitInfo& splitInfo)
{
if (!isAccurate)
EvaluateFitnessFunction();
@@ -87,8 +87,8 @@ void BinaryNumericSplit<FitnessFunction, ObservationType>::Split(
counts.col(0).zeros();
counts.col(1) = classCounts;
- for (std::multimap<ObservationType, size_t>::const_iterator it =
- sortedElements.begin(); (*it).second <= bestValue; ++it)
+ for (typename std::multimap<ObservationType, size_t>::const_iterator it =
+ sortedElements.begin(); (*it).second <= bestSplit; ++it)
{
// Move the point to the correct side of the split.
--counts((*it).second, 1);
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 6d628c5..3267feb 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -9,6 +9,7 @@
#include <mlpack/methods/hoeffding_trees/gini_impurity.hpp>
#include <mlpack/methods/hoeffding_trees/hoeffding_split.hpp>
#include <mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp>
+#include <mlpack/methods/hoeffding_trees/binary_numeric_split.hpp>
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
@@ -523,4 +524,26 @@ BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitBimodalTest)
}
}
+/**
+ * Create a BinaryNumericSplit object, feed it a bunch of samples where anything
+ * less than 1.0 is class 0 and anything greater is class 1. Then make sure it
+ * can perform a perfect split.
+ */
+BOOST_AUTO_TEST_CASE(BinaryNumericSplitSimpleSplitTest)
+{
+ BinaryNumericSplit<GiniImpurity> split(2); // 2 classes.
+
+ // Feed it samples.
+ for (size_t i = 0; i < 500; ++i)
+ {
+ split.Train(mlpack::math::Random(), 0);
+ split.Train(mlpack::math::Random() + 1.0, 1);
+
+ // Now ensure the fitness function gives good gain.
+ // The Gini impurity for the unsplit node is 2 * (0.5^2) = 0.5, and the Gini
+ // impurity for the children is 0.
+ BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(), 0.5, 1e-5);
+ }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list