[mlpack-git] master: Add Range() to GiniImpurity, and test HoeffdingSplit. (924d086)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:42:18 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 924d0869393cbb5a4c4014ac4fba0b493a5ab223
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Sep 22 11:00:49 2015 -0700

    Add Range() to GiniImpurity, and test HoeffdingSplit.


>---------------------------------------------------------------

924d0869393cbb5a4c4014ac4fba0b493a5ab223
 .../methods/hoeffding_trees/gini_impurity.hpp      | 13 +++++++++
 .../methods/hoeffding_trees/hoeffding_split.hpp    |  2 +-
 .../hoeffding_trees/hoeffding_split_impl.hpp       | 18 ++++++++----
 src/mlpack/tests/hoeffding_tree_test.cpp           | 34 ++++++++++++++++++++++
 4 files changed, 60 insertions(+), 7 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp b/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp
index 39114d9..6c295fe 100644
--- a/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp
+++ b/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp
@@ -66,6 +66,19 @@ class GiniImpurity
 
     return impurity;
   }
+
+  /**
+   * Get the range of the Gini impurity (maximum possible value minus minimum
+   * possible value) for a problem with the given number of classes; this works
+   * out to 1 - 1/numClasses.
+   */
+  static double Range(const size_t numClasses)
+  {
+    // A pure node (one class) has impurity 0; an even split over n classes is
+    // the worst case, with impurity n * (1/n * (1 - 1/n)) = 1 - 1/n.
+    const double uniformProbability = 1.0 / double(numClasses);
+    return 1.0 - uniformProbability;
+  }
 };
 
 } // namespace tree
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
index a60f06c..716e164 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
@@ -32,7 +32,7 @@ class HoeffdingSplit
   void Train(const VecType& point, const size_t label);
 
   // 0 if split should not happen; number of splits otherwise.
-  size_t SplitCheck() const;
+  size_t SplitCheck();
 
   // Return index that we should go towards.
   template<typename VecType>
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
index c0d3788..35fb6d6 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
@@ -25,7 +25,8 @@ HoeffdingSplit<
     numClasses(numClasses),
     datasetInfo(datasetInfo),
     successProbability(successProbability),
-    categoricalSplit(0)
+    categoricalSplit(0),
+    splitDimension(size_t(-1))
 {
   for (size_t i = 0; i < dimensionality; ++i)
   {
@@ -73,10 +74,10 @@ size_t HoeffdingSplit<
     FitnessFunction,
     NumericSplitType,
     CategoricalSplitType
->::SplitCheck() const
+>::SplitCheck()
 {
   // Do nothing if we've already split.
-  if (splitDimension == size_t(-1))
+  if (splitDimension != size_t(-1))
     return 0;
 
   // Check the fitness of each dimension.  Then we'll use a Hoeffding bound
@@ -117,11 +118,12 @@ size_t HoeffdingSplit<
   {
     // Split!
     splitDimension = largestIndex;
-    if (datasetInfo[largestIndex].Type == data::Datatype::categorical)
+    if (datasetInfo.Type(largestIndex) == data::Datatype::categorical)
     {
       // I don't know if this should be here.
-      majorityClass = categoricalSplit[largestIndex].MajorityClass();
-      return datasetInfo[largestIndex].NumMappings();
+      std::cout << "split: largest index " << largestIndex << ".\n";
+      majorityClass = categoricalSplits[largestIndex].MajorityClass();
+      return datasetInfo.NumMappings(largestIndex);
     }
     else
     {
@@ -129,6 +131,10 @@ size_t HoeffdingSplit<
       return 0; // I have no idea what to do yet.
     }
   }
+  else
+  {
+    return 0; // Don't split.
+  }
 }
 
 template<
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 58d34d9..064b21c 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -197,4 +197,38 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitSplitTest)
   BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(2), 2);
 }
 
+/**
+ * If we feed the HoeffdingSplit a ton of points that all share the same class,
+ * it should never suggest that we split (that is, SplitCheck() returns 0).
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingSplitNoSplitTest)
+{
+  // Make all dimensions categorical (4, 3, and 2 categories, respectively).
+  data::DatasetInfo info;
+  info.MapString("cat1", 0);
+  info.MapString("cat2", 0);
+  info.MapString("cat3", 0);
+  info.MapString("cat4", 0);
+  info.MapString("cat1", 1);
+  info.MapString("cat2", 1);
+  info.MapString("cat3", 1);
+  info.MapString("cat1", 2);
+  info.MapString("cat2", 2);
+
+  HoeffdingSplit<> split(3, 2, info, 0.95);
+
+  // Feed it samples, reusing one point buffer across iterations.
+  arma::Col<size_t> testPoint(3);
+  for (size_t i = 0; i < 1000; ++i)
+  {
+    // Fill in a random category for every dimension of the test point.
+    testPoint(0) = mlpack::math::RandInt(0, 4);
+    testPoint(1) = mlpack::math::RandInt(0, 3);
+    testPoint(2) = mlpack::math::RandInt(0, 2);
+    split.Train(testPoint, 0); // Always label 0.
+    // SplitCheck() returns size_t; compare with size_t(0) to avoid sign warnings.
+    BOOST_REQUIRE_EQUAL(split.SplitCheck(), size_t(0));
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list