[mlpack-git] master: Clean up Evaluate() and fix test. (90d3b42)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:42:20 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 90d3b42d92b0df09211343ebb568d6e1a4f05083
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Sep 22 11:33:44 2015 -0700
Clean up Evaluate() and fix test.
>---------------------------------------------------------------
90d3b42d92b0df09211343ebb568d6e1a4f05083
.../methods/hoeffding_trees/gini_impurity.hpp | 9 ++---
src/mlpack/tests/hoeffding_tree_test.cpp | 40 +++++++++++++++-------
2 files changed, 31 insertions(+), 18 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp b/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp
index 6c295fe..2f6ab31 100644
--- a/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp
+++ b/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp
@@ -37,13 +37,10 @@ class GiniImpurity
// Calculate the Gini impurity of the un-split node.
double impurity = 0.0;
- if (numElem > 0)
+ for (size_t i = 0; i < classCounts.n_elem; ++i)
{
- for (size_t i = 0; i < classCounts.n_elem; ++i)
- {
- const double f = ((double) classCounts[i] / (double) numElem);
- impurity += f * (1.0 - f);
- }
+ const double f = ((double) classCounts[i] / (double) numElem);
+ impurity += f * (1.0 - f);
}
// Now calculate the impurity of the split nodes and subtract them from the
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 064b21c..35952de 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -71,25 +71,32 @@ BOOST_AUTO_TEST_CASE(GiniImpurityBadSplitTest)
*/
BOOST_AUTO_TEST_CASE(GiniImpurityThreeClassTest)
{
- arma::Mat<size_t> counts(4, 3);
+ arma::Mat<size_t> counts(3, 4);
counts(0, 0) = 0;
- counts(0, 1) = 0;
- counts(0, 2) = 10;
+ counts(1, 0) = 0;
+ counts(2, 0) = 10;
- counts(1, 0) = 5;
+ counts(0, 1) = 5;
counts(1, 1) = 5;
- counts(1, 2) = 0;
+ counts(2, 1) = 0;
- counts(2, 0) = 4;
- counts(2, 1) = 4;
+ counts(0, 2) = 4;
+ counts(1, 2) = 4;
counts(2, 2) = 4;
- counts(3, 0) = 8;
- counts(3, 1) = 1;
- counts(3, 2) = 1;
-
- // The Gini impurity of the whole thing is ... not yet calculated.
+ counts(0, 3) = 8;
+ counts(1, 3) = 1;
+ counts(2, 3) = 1;
+
+ // The Gini impurity gain of this split is:
+ // (overall sum) 0.65193 -
+ // (category 0) 0.23810 * 0 -
+ // (category 1) 0.23810 * 0.5 -
+ // (category 2) 0.28571 * 0.66667 -
+ // (category 3) 0.23810 * 0.34
+ // = 0.26145
+ BOOST_REQUIRE_CLOSE(GiniImpurity::Evaluate(counts), 0.26145, 1e-3);
}
BOOST_AUTO_TEST_CASE(GiniImpurityZeroTest)
@@ -101,6 +108,15 @@ BOOST_AUTO_TEST_CASE(GiniImpurityZeroTest)
}
/**
+ * Test that the range of Gini impurities is correct; NOTE(review): only
+ * Range(0) is checked so far -- more class sizes should be added.
+ */
+BOOST_AUTO_TEST_CASE(GiniImpurityRangeTest)
+{
+ BOOST_REQUIRE_CLOSE(GiniImpurity::Range(0), 1, 1e-5);
+}
+
+/**
* Feed the HoeffdingCategoricalSplit class many examples, all from the same
* class, and verify that the majority class is correct.
*/
More information about the mlpack-git
mailing list