[mlpack-git] master: Simple tests for the HoeffdingCategoricalSplit. (34d521c)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:42:03 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 34d521c12e655ccc298850c5112e964ee4442b71
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Sep 21 15:13:27 2015 +0000
Simple tests for the HoeffdingCategoricalSplit.
>---------------------------------------------------------------
34d521c12e655ccc298850c5112e964ee4442b71
.../hoeffding_categorical_split.hpp | 1 +
.../hoeffding_categorical_split_impl.hpp | 10 ++--
src/mlpack/tests/hoeffding_tree_test.cpp | 56 ++++++++++++++++++++++
3 files changed, 62 insertions(+), 5 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
index 14da600..551744a 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
@@ -49,6 +49,7 @@ class HoeffdingCategoricalSplit
template<typename StreamingDecisionTreeType>
void CreateChildren(std::vector<StreamingDecisionTreeType*>& children,
+ data::DatasetInfo& datasetInfo,
SplitInfo& splitInfo);
size_t MajorityClass() const;
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
index 4faf88a..7d5b74e 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
@@ -22,8 +22,8 @@ HoeffdingCategoricalSplit<FitnessFunction>::HoeffdingCategoricalSplit(
sufficientStatistics.zeros();
}
-template<typename eT>
template<typename FitnessFunction>
+template<typename eT>
void HoeffdingCategoricalSplit<FitnessFunction>::Train(eT value,
const size_t label)
{
@@ -39,24 +39,24 @@ double HoeffdingCategoricalSplit<FitnessFunction>::EvaluateFitnessFunction()
return FitnessFunction::Evaluate(sufficientStatistics);
}
-template<typename StreamingDecisionTreeType>
template<typename FitnessFunction>
+template<typename StreamingDecisionTreeType>
void HoeffdingCategoricalSplit<FitnessFunction>::CreateChildren(
std::vector<StreamingDecisionTreeType*>& children,
+ data::DatasetInfo& datasetInfo,
SplitInfo& splitInfo)
{
// We'll make one child for each category.
- children.push_back(StreamingDecisionTree(datasetInfo));
+ // One child per column of sufficientStatistics (the same count handed to
+ // SplitInfo below). `children` stores pointers, so each child must be
+ // heap-allocated; nothing here frees them, so whoever owns `children` is
+ // responsible for deleting every pointer appended here.
+ for (size_t i = 0; i < sufficientStatistics.n_cols; ++i)
+   children.push_back(new StreamingDecisionTreeType(datasetInfo));
// Create the according SplitInfo object.
splitInfo = SplitInfo(sufficientStatistics.n_cols);
}
-template<typename StreamingDecisionTreeType>
template<typename FitnessFunction>
size_t HoeffdingCategoricalSplit<FitnessFunction>::MajorityClass() const
{
// Calculate the class that we have seen the most of.
- arma::Row<size_t> classCounts = sum(sufficientStatistics);
+ arma::Col<size_t> classCounts = sum(sufficientStatistics, 1);
arma::uword maxIndex;
classCounts.max(maxIndex);
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 2659f36..91a76bf 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -7,6 +7,7 @@
#include <mlpack/core.hpp>
#include <mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp>
#include <mlpack/methods/hoeffding_trees/gini_impurity.hpp>
+#include <mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp>
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
@@ -97,4 +98,59 @@ BOOST_AUTO_TEST_CASE(GiniImpurityZeroTest)
BOOST_REQUIRE_SMALL(GiniImpurity::Evaluate(counts), 1e-10);
}
+/**
+ * Feed the HoeffdingCategoricalSplit class many examples, all from the same
+ * class, and verify after every example that the majority class is correct.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitMajorityClassTest)
+{
+  // Ten categories, three classes.
+  HoeffdingCategoricalSplit<GiniImpurity> split(10, 3);
+
+  // MajorityClass() returns size_t; comparing against a size_t (rather than
+  // the int literal 1) avoids a signed/unsigned comparison in
+  // BOOST_REQUIRE_EQUAL.
+  const size_t expectedClass = 1;
+
+  for (size_t i = 0; i < 500; ++i)
+  {
+    // Random category, but always the same label.
+    split.Train(math::RandInt(0, 10), expectedClass);
+    BOOST_REQUIRE_EQUAL(split.MajorityClass(), expectedClass);
+  }
+}
+
+/**
+ * A harder majority class example: classes 1 and 2 are trained equally often,
+ * but class 1 gets a one-example head start, so it must remain the majority
+ * class after every training step.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitHarderMajorityClassTest)
+{
+  // Ten categories, three classes.
+  HoeffdingCategoricalSplit<GiniImpurity> split(10, 3);
+
+  // Use size_t constants so BOOST_REQUIRE_EQUAL compares two unsigned values
+  // (MajorityClass() returns size_t).
+  const size_t leadingClass = 1;
+  const size_t trailingClass = 2;
+
+  // Head start: after this, class 1 always has exactly one more example than
+  // class 2 at each check below.
+  split.Train(math::RandInt(0, 10), leadingClass);
+  for (size_t i = 0; i < 250; ++i)
+  {
+    split.Train(math::RandInt(0, 10), leadingClass);
+    split.Train(math::RandInt(0, 10), trailingClass);
+    BOOST_REQUIRE_EQUAL(split.MajorityClass(), leadingClass);
+  }
+}
+
+/**
+ * Check that the fitness function is positive when every category is
+ * perfectly predictive of a single class, i.e. data for which splitting on
+ * this categorical dimension would be an improvement.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitEasyFitnessCheck)
+{
+  // Five categories, three classes.
+  HoeffdingCategoricalSplit<GiniImpurity> split(5, 3);
+
+  // Category c is always labeled classOfCategory[c]. Each category receives
+  // 100 points, trained one category at a time in ascending order.
+  const size_t classOfCategory[] = { 0, 1, 1, 2, 2 };
+  for (int category = 0; category < 5; ++category)
+  {
+    for (size_t i = 0; i < 100; ++i)
+      split.Train(category, classOfCategory[category]);
+  }
+
+  BOOST_REQUIRE_GT(split.EvaluateFitnessFunction(), 0.0);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list