[mlpack-git] master: Add information gain and tests. (66592fa)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:46:07 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 66592fa4d7aac09839175775a094670d128ae5d0
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Nov 12 11:23:40 2015 -0500

    Add information gain and tests.


>---------------------------------------------------------------

66592fa4d7aac09839175775a094670d128ae5d0
 src/mlpack/methods/hoeffding_trees/CMakeLists.txt  |   1 +
 .../methods/hoeffding_trees/information_gain.hpp   |  91 ++++++++++++++++++
 src/mlpack/tests/hoeffding_tree_test.cpp           | 102 +++++++++++++++++++++
 3 files changed, 194 insertions(+)

diff --git a/src/mlpack/methods/hoeffding_trees/CMakeLists.txt b/src/mlpack/methods/hoeffding_trees/CMakeLists.txt
index f457856..ecb0d8e 100644
--- a/src/mlpack/methods/hoeffding_trees/CMakeLists.txt
+++ b/src/mlpack/methods/hoeffding_trees/CMakeLists.txt
@@ -12,6 +12,7 @@ set(SOURCES
   hoeffding_numeric_split_impl.hpp
   hoeffding_tree.hpp
   hoeffding_tree_impl.hpp
+  information_gain.hpp
   numeric_split_info.hpp
   typedef.hpp
 )
diff --git a/src/mlpack/methods/hoeffding_trees/information_gain.hpp b/src/mlpack/methods/hoeffding_trees/information_gain.hpp
new file mode 100644
index 0000000..39e017a
--- /dev/null
+++ b/src/mlpack/methods/hoeffding_trees/information_gain.hpp
@@ -0,0 +1,91 @@
+/**
+ * @file information_gain.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of information gain, which can be used in place of Gini
+ * impurity.
+ */
+#ifndef __MLPACK_METHODS_HOEFFDING_TREES_INFORMATION_GAIN_HPP
+#define __MLPACK_METHODS_HOEFFDING_TREES_INFORMATION_GAIN_HPP
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * The information gain of a proposed split: the entropy (in bits) of the
+ * unsplit node minus the weighted average entropy of the proposed children.
+ * This can be used in place of Gini impurity.
+ *
+ * NOTE(review): this header does not include Armadillo or <cmath> itself; it
+ * relies on the includer (e.g. mlpack/core.hpp) having done so.
+ */
+class InformationGain
+{
+ public:
+  /**
+   * Given the sufficient statistics of a proposed split, calculate the
+   * information gain if that split were to be used.  Each column of 'counts'
+   * corresponds to one child node of the proposed split and holds the number
+   * of points of each class that fell into that child, so the size of 'counts'
+   * is classes x children.
+   *
+   * If 'counts' contains no points at all, the gain is 0.
+   *
+   * @param counts Matrix of sufficient statistics (classes x children).
+   * @return Information gain of the proposed split, in bits.
+   */
+  static double Evaluate(const arma::Mat<size_t>& counts)
+  {
+    // Number of points that fall into each proposed child (column sums), and
+    // the total number of points in the unsplit node.
+    const arma::Row<size_t> splitCounts = arma::sum(counts, 0);
+    const size_t numElem = arma::accu(splitCounts);
+
+    // Corner case: if there are no elements, the gain is zero.
+    if (numElem == 0)
+      return 0.0;
+
+    // Number of points of each class in the unsplit node (row sums).
+    const arma::Col<size_t> classCounts = arma::sum(counts, 1);
+
+    // Start with the entropy of the unsplit node: -sum_c f_c * log2(f_c).
+    double gain = 0.0;
+    for (size_t i = 0; i < classCounts.n_elem; ++i)
+    {
+      const double f = static_cast<double>(classCounts[i]) /
+          static_cast<double>(numElem);
+      if (f > 0.0) // By convention 0 * log2(0) contributes 0.
+        gain -= f * std::log2(f);
+    }
+
+    // Subtract the entropy of each child, weighted by the fraction of all
+    // points that the child receives.  Empty children contribute nothing.
+    for (size_t i = 0; i < counts.n_cols; ++i)
+    {
+      if (splitCounts[i] == 0)
+        continue;
+
+      // splitGain is the *negated* entropy of child i (sum_c f_c log2(f_c)),
+      // so adding it below subtracts that child's entropy.
+      double splitGain = 0.0;
+      for (size_t j = 0; j < counts.n_rows; ++j)
+      {
+        const double f = static_cast<double>(counts(j, i)) /
+            static_cast<double>(splitCounts[i]);
+        if (f > 0.0)
+          splitGain += f * std::log2(f);
+      }
+
+      gain += (static_cast<double>(splitCounts[i]) /
+          static_cast<double>(numElem)) * splitGain;
+    }
+
+    return gain;
+  }
+
+  /**
+   * Return the range of the information gain for the given number of classes.
+   * (That is, the difference between the maximum possible value and the minimum
+   * possible value.)
+   *
+   * The smallest possible gain is 0 (the split tells us nothing).  The largest
+   * gain is reached when the classes are evenly distributed in the unsplit
+   * node and every child is pure.  The parent entropy is then
+   * -n * (1/n * log2(1/n)) = log2(n), and the children's entropy is 0, so the
+   * range is log2(n).
+   *
+   * @param numClasses Number of classes.
+   * @return log2(numClasses), or 0 when there are fewer than two classes.
+   */
+  static double Range(const size_t numClasses)
+  {
+    // With zero or one class no split can gain anything.  This also avoids
+    // returning log2(0) = -inf for numClasses == 0.
+    if (numClasses <= 1)
+      return 0.0;
+
+    return std::log2(static_cast<double>(numClasses));
+  }
+};
+
+} // namespace tree
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 72516a1..7e3fee1 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -6,6 +6,7 @@
  */
 #include <mlpack/core.hpp>
 #include <mlpack/methods/hoeffding_trees/gini_impurity.hpp>
+#include <mlpack/methods/hoeffding_trees/information_gain.hpp>
 #include <mlpack/methods/hoeffding_trees/hoeffding_tree.hpp>
 #include <mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp>
 #include <mlpack/methods/hoeffding_trees/binary_numeric_split.hpp>
@@ -125,6 +126,107 @@ BOOST_AUTO_TEST_CASE(GiniImpurityRangeTest)
   BOOST_REQUIRE_CLOSE(GiniImpurity::Range(1000), 0.999, 1e-5);
 }
 
+BOOST_AUTO_TEST_CASE(InformationGainPerfectSimpleTest)
+{
+  // Make a simple test for information gain where only one class is present.
+  // The unsplit node is already pure, so no split can gain anything and the
+  // result should always be 0.  We'll assemble the count matrix by hand; rows
+  // are classes and columns are categories (children of the split).
+  arma::Mat<size_t> counts(2, 2); // 2 classes, 2 categories.
+
+  counts(0, 0) = 10; // 10 points in category 0 with class 0.
+  counts(1, 0) = 0; // 0 points in category 0 with class 1.
+  counts(0, 1) = 12; // 12 points in category 1 with class 0.
+  counts(1, 1) = 0; // 0 points in category 1 with class 1.
+
+  // Every point has class 0, so the split gets us nothing.
+  BOOST_REQUIRE_SMALL(InformationGain::Evaluate(counts), 1e-10);
+}
+
+BOOST_AUTO_TEST_CASE(InformationGainImperfectSimpleTest)
+{
+  // Make a simple test where a split will give us perfect classification.
+  // (Despite the test's name, the split evaluated here is a perfect one.)
+  // Rows are classes and columns are categories (children of the split).
+  arma::Mat<size_t> counts(2, 2); // 2 classes, 2 categories.
+
+  counts(0, 0) = 10; // 10 points in category 0 with class 0.
+  counts(1, 0) = 0; // 0 points in category 0 with class 1.
+  counts(0, 1) = 0; // 0 points in category 1 with class 0.
+  counts(1, 1) = 10; // 10 points in category 1 with class 1.
+
+  // The entropy before the split is -(0.5 log2(0.5) + 0.5 log2(0.5)) = 1.
+  // Each child is pure after the split, so its entropy is 0.
+  // So the gain should be 1.
+  BOOST_REQUIRE_CLOSE(InformationGain::Evaluate(counts), 1.0, 1e-5);
+}
+
+BOOST_AUTO_TEST_CASE(InformationGainBadSplitTest)
+{
+  // Both categories contain classes 0 and 1 in the same 2:1 ratio as the
+  // unsplit node.  Splitting on the category therefore tells us nothing new,
+  // and the gain should be zero.  Rows are classes, columns are categories.
+  arma::Mat<size_t> counts(2, 2);
+
+  // Class 0: 10 points in each category.
+  counts(0, 0) = 10;
+  counts(0, 1) = 10;
+
+  // Class 1: 5 points in each category.
+  counts(1, 0) = 5;
+  counts(1, 1) = 5;
+
+  const double gain = InformationGain::Evaluate(counts);
+  BOOST_REQUIRE_SMALL(gain, 1e-10);
+}
+
+/**
+ * A hand-crafted more difficult test for the information gain, where four
+ * categories and three classes are available.
+ */
+BOOST_AUTO_TEST_CASE(InformationGainThreeClassTest)
+{
+  arma::Mat<size_t> counts(3, 4); // 3 classes (rows), 4 categories (columns).
+
+  counts(0, 0) = 0;
+  counts(1, 0) = 0;
+  counts(2, 0) = 10;
+
+  counts(0, 1) = 5;
+  counts(1, 1) = 5;
+  counts(2, 1) = 0;
+
+  counts(0, 2) = 4;
+  counts(1, 2) = 4;
+  counts(2, 2) = 4;
+
+  counts(0, 3) = 8;
+  counts(1, 3) = 1;
+  counts(2, 3) = 1;
+
+  // The class totals of the unsplit node are 17, 10 and 15 (of 42 points), so
+  // its entropy is 1.5516.  Each category's entropy, weighted by the fraction
+  // of points it holds, is:
+  // (category 0)  0.23810 * 0
+  // (category 1)  0.23810 * 1
+  // (category 2)  0.28571 * 1.5850
+  // (category 3)  0.23810 * 0.92193
+  // The gain is the parent entropy minus the sum of these weighted terms:
+  //   1.5516 - 0 - 0.23810 - 0.45285 - 0.21951 = 0.64116649
+  // (last value computed at full precision).
+  BOOST_REQUIRE_CLOSE(InformationGain::Evaluate(counts), 0.64116649, 1e-5);
+}
+
+BOOST_AUTO_TEST_CASE(InformationGainZeroTest)
+{
+  // A 10 x 10 count matrix in which no points have been seen yet.  With
+  // nothing observed, the information gain must be zero.
+  arma::Mat<size_t> counts(10, 10);
+  counts.zeros();
+
+  const double gain = InformationGain::Evaluate(counts);
+  BOOST_REQUIRE_SMALL(gain, 1e-10);
+}
+
+/**
+ * Test that the range of information gains is correct for a handful of class
+ * sizes.  For n classes the range should be log2(n).
+ */
+BOOST_AUTO_TEST_CASE(InformationGainRangeTest)
+{
+  // A single class can never yield any gain.  BOOST_REQUIRE_CLOSE uses a
+  // relative (percentage) tolerance, which is not meaningful against an
+  // expected value of zero, so use an absolute check here instead.
+  BOOST_REQUIRE_SMALL(InformationGain::Range(1), 1e-10);
+
+  BOOST_REQUIRE_CLOSE(InformationGain::Range(2), 1.0, 1e-5);
+  BOOST_REQUIRE_CLOSE(InformationGain::Range(3), 1.5849625, 1e-5);
+  BOOST_REQUIRE_CLOSE(InformationGain::Range(4), 2.0, 1e-5);
+  BOOST_REQUIRE_CLOSE(InformationGain::Range(5), 2.32192809, 1e-5);
+  BOOST_REQUIRE_CLOSE(InformationGain::Range(10), 3.32192809, 1e-5);
+  BOOST_REQUIRE_CLOSE(InformationGain::Range(100), 6.64385619, 1e-5);
+  BOOST_REQUIRE_CLOSE(InformationGain::Range(1000), 9.96578428, 1e-5);
+}
+
 /**
  * Feed the HoeffdingCategoricalSplit class many examples, all from the same
  * class, and verify that the majority class is correct.



More information about the mlpack-git mailing list