[mlpack-git] master: Add information gain and tests. (66592fa)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:46:07 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 66592fa4d7aac09839175775a094670d128ae5d0
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Nov 12 11:23:40 2015 -0500
Add information gain and tests.
>---------------------------------------------------------------
66592fa4d7aac09839175775a094670d128ae5d0
src/mlpack/methods/hoeffding_trees/CMakeLists.txt | 1 +
.../methods/hoeffding_trees/information_gain.hpp | 91 ++++++++++++++++++
src/mlpack/tests/hoeffding_tree_test.cpp | 102 +++++++++++++++++++++
3 files changed, 194 insertions(+)
diff --git a/src/mlpack/methods/hoeffding_trees/CMakeLists.txt b/src/mlpack/methods/hoeffding_trees/CMakeLists.txt
index f457856..ecb0d8e 100644
--- a/src/mlpack/methods/hoeffding_trees/CMakeLists.txt
+++ b/src/mlpack/methods/hoeffding_trees/CMakeLists.txt
@@ -12,6 +12,7 @@ set(SOURCES
hoeffding_numeric_split_impl.hpp
hoeffding_tree.hpp
hoeffding_tree_impl.hpp
+ information_gain.hpp
numeric_split_info.hpp
typedef.hpp
)
diff --git a/src/mlpack/methods/hoeffding_trees/information_gain.hpp b/src/mlpack/methods/hoeffding_trees/information_gain.hpp
new file mode 100644
index 0000000..39e017a
--- /dev/null
+++ b/src/mlpack/methods/hoeffding_trees/information_gain.hpp
@@ -0,0 +1,91 @@
+/**
+ * @file information_gain.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of information gain, which can be used in place of Gini
+ * impurity.
+ */
+#ifndef __MLPACK_METHODS_HOEFFDING_TREES_INFORMATION_GAIN_HPP
+#define __MLPACK_METHODS_HOEFFDING_TREES_INFORMATION_GAIN_HPP
+
+namespace mlpack {
+namespace tree {
+
+class InformationGain
+{
+ public:
+  /**
+   * Given the sufficient statistics of a proposed split, calculate the
+   * information gain if that split was to be used.  The 'counts' matrix should
+   * contain the number of points in each class (one row per class) for each
+   * proposed child (one column per child), so the size of 'counts' is
+   * classes x children, where 'children' is the number of child nodes in the
+   * proposed split.
+   *
+   * The gain is the entropy (in bits) of the unsplit node minus the entropy of
+   * each child weighted by the fraction of points that child receives.  If
+   * 'counts' contains no points at all, the gain is 0.
+   *
+   * @param counts Matrix of sufficient statistics (classes x children).
+   * @return Information gain of the proposed split.
+   */
+  static double Evaluate(const arma::Mat<size_t>& counts)
+  {
+    // Calculate the number of elements in each proposed child (the column
+    // sums) and in the unsplit node.  The total is accumulated as an integer so
+    // that numElem is exact.
+    arma::vec splitCounts(counts.n_cols);
+    size_t numElem = 0;
+    for (size_t i = 0; i < counts.n_cols; ++i)
+    {
+      const size_t childElem = arma::accu(counts.col(i));
+      splitCounts[i] = static_cast<double>(childElem);
+      numElem += childElem;
+    }
+
+    // Corner case: if there are no elements, the gain is zero.
+    if (numElem == 0)
+      return 0.0;
+
+    // Number of points of each class in the unsplit node (the row sums).
+    const arma::Col<size_t> classCounts = arma::sum(counts, 1);
+
+    // Start from the entropy of the unsplit node, -sum_c f_c log2(f_c).
+    double gain = 0.0;
+    for (size_t i = 0; i < classCounts.n_elem; ++i)
+    {
+      const double f = static_cast<double>(classCounts[i]) /
+          static_cast<double>(numElem);
+      if (f > 0.0)
+        gain -= f * std::log2(f);
+    }
+
+    // Now subtract the weighted entropy of each non-empty child.  splitGain
+    // holds sum_c f_c log2(f_c), which is the *negative* entropy of the child,
+    // so it is added rather than subtracted.
+    for (size_t i = 0; i < counts.n_cols; ++i)
+    {
+      if (splitCounts[i] > 0)
+      {
+        double splitGain = 0.0;
+        for (size_t j = 0; j < counts.n_rows; ++j)
+        {
+          const double f = static_cast<double>(counts(j, i)) / splitCounts[i];
+          if (f > 0.0)
+            splitGain += f * std::log2(f);
+        }
+
+        gain += (splitCounts[i] / static_cast<double>(numElem)) * splitGain;
+      }
+    }
+
+    return gain;
+  }
+
+  /**
+   * Return the range of the information gain for the given number of classes.
+   * (That is, the difference between the maximum possible value and the minimum
+   * possible value.)
+   *
+   * @param numClasses Number of classes; should be at least 1 (for 0 classes
+   *     std::log2 yields -infinity).
+   */
+  static double Range(const size_t numClasses)
+  {
+    // The gain is never below 0, and it is at most the entropy of the unsplit
+    // node.  That entropy is largest for an even distribution over the n
+    // classes: -n * (1/n * log2(1/n)) = log2(n).  So, the range is log2(n).
+    return std::log2(static_cast<double>(numClasses));
+  }
+};
+
+} // namespace tree
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 72516a1..7e3fee1 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -6,6 +6,7 @@
*/
#include <mlpack/core.hpp>
#include <mlpack/methods/hoeffding_trees/gini_impurity.hpp>
+#include <mlpack/methods/hoeffding_trees/information_gain.hpp>
#include <mlpack/methods/hoeffding_trees/hoeffding_tree.hpp>
#include <mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp>
#include <mlpack/methods/hoeffding_trees/binary_numeric_split.hpp>
@@ -125,6 +126,107 @@ BOOST_AUTO_TEST_CASE(GiniImpurityRangeTest)
BOOST_REQUIRE_CLOSE(GiniImpurity::Range(1000), 0.999, 1e-5);
}
+BOOST_AUTO_TEST_CASE(InformationGainPerfectSimpleTest)
+{
+  // Make a simple test for information gain with one class.  In this case it
+  // should always be 0.  We'll assemble the count matrix by hand; as expected
+  // by InformationGain::Evaluate(), rows are classes and columns are
+  // categories (children of the split).
+  arma::Mat<size_t> counts(2, 2); // 2 classes, 2 categories.
+
+  counts(0, 0) = 10; // 10 points in category 0 with class 0.
+  counts(1, 0) = 0; // 0 points in category 0 with class 1.
+  counts(0, 1) = 12; // 12 points in category 1 with class 0.
+  counts(1, 1) = 0; // 0 points in category 1 with class 1.
+
+  // Every point has the same class, so the split gets us nothing and there
+  // should be no gain.
+  BOOST_REQUIRE_SMALL(InformationGain::Evaluate(counts), 1e-10);
+}
+
+BOOST_AUTO_TEST_CASE(InformationGainImperfectSimpleTest)
+{
+  // Make a simple test where a split will give us perfect classification.
+  // Rows are classes and columns are categories.
+  arma::Mat<size_t> counts(2, 2); // 2 classes, 2 categories.
+
+  counts(0, 0) = 10; // 10 points in category 0 with class 0.
+  counts(1, 0) = 0; // 0 points in category 0 with class 1.
+  counts(0, 1) = 0; // 0 points in category 1 with class 0.
+  counts(1, 1) = 10; // 10 points in category 1 with class 1.
+
+  // The entropy before the split is -(0.5 log2(0.5) + 0.5 log2(0.5)) = 1.
+  // The entropy of each category after the split is 0.
+  // So the gain should be 1.
+  BOOST_REQUIRE_CLOSE(InformationGain::Evaluate(counts), 1.0, 1e-5);
+}
+
+BOOST_AUTO_TEST_CASE(InformationGainBadSplitTest)
+{
+  // Both categories (columns) contain classes 0 and 1 (rows) in the same 2:1
+  // ratio as the node as a whole, so this split gets us nothing.
+  arma::Mat<size_t> counts(2, 2);
+  counts(0, 0) = 10;
+  counts(1, 0) = 5;
+  counts(0, 1) = 10;
+  counts(1, 1) = 5;
+
+  const double gain = InformationGain::Evaluate(counts);
+  BOOST_REQUIRE_SMALL(gain, 1e-10);
+}
+
+/**
+ * A hand-crafted more difficult test for the information gain, where four
+ * categories and three classes are available.
+ */
+BOOST_AUTO_TEST_CASE(InformationGainThreeClassTest)
+{
+  arma::Mat<size_t> counts(3, 4); // 3 classes (rows), 4 categories (columns).
+
+  counts(0, 0) = 0;
+  counts(1, 0) = 0;
+  counts(2, 0) = 10;
+
+  counts(0, 1) = 5;
+  counts(1, 1) = 5;
+  counts(2, 1) = 0;
+
+  counts(0, 2) = 4;
+  counts(1, 2) = 4;
+  counts(2, 2) = 4;
+
+  counts(0, 3) = 8;
+  counts(1, 3) = 1;
+  counts(2, 3) = 1;
+
+  // The class totals are 17, 10, and 15 out of 42 points, so the entropy of
+  // the unsplit node is 1.5516.  The categories hold 10, 10, 12, and 10
+  // points, with entropies 0, 1, log2(3) = 1.5850, and 0.92193.  So the
+  // information gain is:
+  //   (overall entropy) 1.5516 -
+  //   (category 0) 0.23810 * 0 -
+  //   (category 1) 0.23810 * 1 -
+  //   (category 2) 0.28571 * 1.5850 -
+  //   (category 3) 0.23810 * 0.92193
+  //   = 0.64116649
+  BOOST_REQUIRE_CLOSE(InformationGain::Evaluate(counts), 0.64116649, 1e-5);
+}
+
+BOOST_AUTO_TEST_CASE(InformationGainZeroTest)
+{
+  // An all-zero count matrix means nothing has been seen yet; in that case the
+  // information gain must be zero.
+  arma::Mat<size_t> counts;
+  counts.zeros(10, 10);
+
+  const double gain = InformationGain::Evaluate(counts);
+  BOOST_REQUIRE_SMALL(gain, 1e-10);
+}
+
+/**
+ * Test that the range of information gains is correct for a handful of class
+ * counts.
+ */
+BOOST_AUTO_TEST_CASE(InformationGainRangeTest)
+{
+  // With one class the range is exactly 0; a relative (percentage) comparison
+  // against 0 is degenerate, so check the absolute value instead.
+  BOOST_REQUIRE_SMALL(InformationGain::Range(1), 1e-10);
+  BOOST_REQUIRE_CLOSE(InformationGain::Range(2), 1.0, 1e-5);
+  BOOST_REQUIRE_CLOSE(InformationGain::Range(3), 1.5849625, 1e-5);
+  BOOST_REQUIRE_CLOSE(InformationGain::Range(4), 2, 1e-5);
+  BOOST_REQUIRE_CLOSE(InformationGain::Range(5), 2.32192809, 1e-5);
+  BOOST_REQUIRE_CLOSE(InformationGain::Range(10), 3.32192809, 1e-5);
+  BOOST_REQUIRE_CLOSE(InformationGain::Range(100), 6.64385619, 1e-5);
+  BOOST_REQUIRE_CLOSE(InformationGain::Range(1000), 9.96578428, 1e-5);
+}
+
/**
* Feed the HoeffdingCategoricalSplit class many examples, all from the same
* class, and verify that the majority class is correct.
More information about the mlpack-git
mailing list