[mlpack-git] master: Clean up Evaluate() and fix test. (90d3b42)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:42:20 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 90d3b42d92b0df09211343ebb568d6e1a4f05083
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Sep 22 11:33:44 2015 -0700

    Clean up Evaluate() and fix test.


>---------------------------------------------------------------

90d3b42d92b0df09211343ebb568d6e1a4f05083
 .../methods/hoeffding_trees/gini_impurity.hpp      |  9 ++---
 src/mlpack/tests/hoeffding_tree_test.cpp           | 40 +++++++++++++++-------
 2 files changed, 31 insertions(+), 18 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp b/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp
index 6c295fe..2f6ab31 100644
--- a/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp
+++ b/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp
@@ -37,13 +37,10 @@ class GiniImpurity
 
     // Calculate the Gini impurity of the un-split node.
     double impurity = 0.0;
-    if (numElem > 0)
+    for (size_t i = 0; i < classCounts.n_elem; ++i)
     {
-      for (size_t i = 0; i < classCounts.n_elem; ++i)
-      {
-        const double f = ((double) classCounts[i] / (double) numElem);
-        impurity += f * (1.0 - f);
-      }
+      const double f = ((double) classCounts[i] / (double) numElem);
+      impurity += f * (1.0 - f);
     }
 
     // Now calculate the impurity of the split nodes and subtract them from the
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 064b21c..35952de 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -71,25 +71,32 @@ BOOST_AUTO_TEST_CASE(GiniImpurityBadSplitTest)
  */
 BOOST_AUTO_TEST_CASE(GiniImpurityThreeClassTest)
 {
-  arma::Mat<size_t> counts(4, 3);
+  arma::Mat<size_t> counts(3, 4);
 
   counts(0, 0) = 0;
-  counts(0, 1) = 0;
-  counts(0, 2) = 10;
+  counts(1, 0) = 0;
+  counts(2, 0) = 10;
 
-  counts(1, 0) = 5;
+  counts(0, 1) = 5;
   counts(1, 1) = 5;
-  counts(1, 2) = 0;
+  counts(2, 1) = 0;
 
-  counts(2, 0) = 4;
-  counts(2, 1) = 4;
+  counts(0, 2) = 4;
+  counts(1, 2) = 4;
   counts(2, 2) = 4;
 
-  counts(3, 0) = 8;
-  counts(3, 1) = 1;
-  counts(3, 2) = 1;
-
-  // The Gini impurity of the whole thing is ... not yet calculated.
+  counts(0, 3) = 8;
+  counts(1, 3) = 1;
+  counts(2, 3) = 1;
+
+  // The Gini impurity of the whole thing is:
+  // (overall sum) 0.65193 -
+  // (category 0)  0.23810 * 0       -
+  // (category 1)  0.23810 * 0.5     -
+  // (category 2)  0.28571 * 0.66667 -
+  // (category 3)  0.23810 * 0.34
+  //   = 0.26145
+  BOOST_REQUIRE_CLOSE(GiniImpurity::Evaluate(counts), 0.26145, 1e-3);
 }
 
 BOOST_AUTO_TEST_CASE(GiniImpurityZeroTest)
@@ -101,6 +108,15 @@ BOOST_AUTO_TEST_CASE(GiniImpurityZeroTest)
 }
 
 /**
+ * Test that the range of Gini impurities is correct for a handful of class
+ * sizes.
+ */
+BOOST_AUTO_TEST_CASE(GiniImpurityRangeTest)
+{
+  BOOST_REQUIRE_CLOSE(GiniImpurity::Range(0), 1, 1e-5);
+}
+
+/**
  * Feed the HoeffdingCategoricalSplit class many examples, all from the same
  * class, and verify that the majority class is correct.
  */



More information about the mlpack-git mailing list