[mlpack-git] master: Simple tests for the HoeffdingCategoricalSplit. (34d521c)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:42:03 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 34d521c12e655ccc298850c5112e964ee4442b71
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Sep 21 15:13:27 2015 +0000

    Simple tests for the HoeffdingCategoricalSplit.


>---------------------------------------------------------------

34d521c12e655ccc298850c5112e964ee4442b71
 .../hoeffding_categorical_split.hpp                |  1 +
 .../hoeffding_categorical_split_impl.hpp           | 10 ++--
 src/mlpack/tests/hoeffding_tree_test.cpp           | 56 ++++++++++++++++++++++
 3 files changed, 62 insertions(+), 5 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
index 14da600..551744a 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
@@ -49,6 +49,7 @@ class HoeffdingCategoricalSplit
 
   template<typename StreamingDecisionTreeType>
   void CreateChildren(std::vector<StreamingDecisionTreeType*>& children,
+                      data::DatasetInfo& datasetInfo,
                       SplitInfo& splitInfo);
 
   size_t MajorityClass() const;
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
index 4faf88a..7d5b74e 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
@@ -22,8 +22,8 @@ HoeffdingCategoricalSplit<FitnessFunction>::HoeffdingCategoricalSplit(
   sufficientStatistics.zeros();
 }
 
-template<typename eT>
 template<typename FitnessFunction>
+template<typename eT>
 void HoeffdingCategoricalSplit<FitnessFunction>::Train(eT value,
                                                        const size_t label)
 {
@@ -39,24 +39,24 @@ double HoeffdingCategoricalSplit<FitnessFunction>::EvaluateFitnessFunction()
   return FitnessFunction::Evaluate(sufficientStatistics);
 }
 
-template<typename StreamingDecisionTreeType>
 template<typename FitnessFunction>
+template<typename StreamingDecisionTreeType>
 void HoeffdingCategoricalSplit<FitnessFunction>::CreateChildren(
     std::vector<StreamingDecisionTreeType*>& children,
+    data::DatasetInfo& datasetInfo,
     SplitInfo& splitInfo)
 {
   // We'll make one child for each category.
-  children.push_back(StreamingDecisionTree(datasetInfo));
+  children.push_back(StreamingDecisionTreeType(datasetInfo));
   // Create the according SplitInfo object.
   splitInfo = SplitInfo(sufficientStatistics.n_cols);
 }
 
-template<typename StreamingDecisionTreeType>
 template<typename FitnessFunction>
 size_t HoeffdingCategoricalSplit<FitnessFunction>::MajorityClass() const
 {
   // Calculate the class that we have seen the most of.
-  arma::Row<size_t> classCounts = sum(sufficientStatistics);
+  arma::Col<size_t> classCounts = sum(sufficientStatistics, 1);
 
   arma::uword maxIndex;
   classCounts.max(maxIndex);
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 2659f36..91a76bf 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -7,6 +7,7 @@
 #include <mlpack/core.hpp>
 #include <mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp>
 #include <mlpack/methods/hoeffding_trees/gini_impurity.hpp>
+#include <mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp>
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
@@ -97,4 +98,59 @@ BOOST_AUTO_TEST_CASE(GiniImpurityZeroTest)
   BOOST_REQUIRE_SMALL(GiniImpurity::Evaluate(counts), 1e-10);
 }
 
+/**
+ * Feed the HoeffdingCategoricalSplit class many examples, all from the same
+ * class, and verify that the majority class is correct.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitMajorityClassTest)
+{
+  // Ten categories, three classes.
+  HoeffdingCategoricalSplit<GiniImpurity> split(10, 3);
+
+  for (size_t i = 0; i < 500; ++i)
+  {
+    split.Train(math::RandInt(0, 10), 1);
+    BOOST_REQUIRE_EQUAL(split.MajorityClass(), 1);
+  }
+}
+
+/**
+ * A harder majority class example: class 1 leads class 2 by only one sample.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitHarderMajorityClassTest)
+{
+  // Ten categories, three classes.
+  HoeffdingCategoricalSplit<GiniImpurity> split(10, 3);
+
+  split.Train(math::RandInt(0, 10), 1);
+  for (size_t i = 0; i < 250; ++i)
+  {
+    split.Train(math::RandInt(0, 10), 1);
+    split.Train(math::RandInt(0, 10), 2);
+    BOOST_REQUIRE_EQUAL(split.MajorityClass(), 1);
+  }
+}
+
+/**
+ * Ensure that the fitness function is positive when we pass some data that
+ * would result in an improvement if it were split.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitEasyFitnessCheck)
+{
+  HoeffdingCategoricalSplit<GiniImpurity> split(5, 3);
+
+  for (size_t i = 0; i < 100; ++i)
+    split.Train(0, 0);
+  for (size_t i = 0; i < 100; ++i)
+    split.Train(1, 1);
+  for (size_t i = 0; i < 100; ++i)
+    split.Train(2, 1);
+  for (size_t i = 0; i < 100; ++i)
+    split.Train(3, 2);
+  for (size_t i = 0; i < 100; ++i)
+    split.Train(4, 2);
+
+  BOOST_REQUIRE_GT(split.EvaluateFitnessFunction(), 0.0);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list