[mlpack-git] master: Implement the Gini impurity. (8d33d0a)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:41:55 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 8d33d0a3765167b0917f6d4fa0c2c509f5aac981
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Sep 21 12:29:36 2015 +0000

    Implement the Gini impurity.


>---------------------------------------------------------------

8d33d0a3765167b0917f6d4fa0c2c509f5aac981
 .../methods/hoeffding_trees/gini_impurity.hpp      | 74 ++++++++++++++++++++++
 1 file changed, 74 insertions(+)

diff --git a/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp b/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp
new file mode 100644
index 0000000..39114d9
--- /dev/null
+++ b/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp
@@ -0,0 +1,74 @@
+/**
+ * @file gini_impurity.hpp
+ * @author Ryan Curtin
+ *
+ * The GiniImpurity class, which is a fitness function (FitnessFunction) for
+ * streaming decision trees.
+ */
+#ifndef __MLPACK_METHODS_HOEFFDING_TREES_GINI_INDEX_HPP
+#define __MLPACK_METHODS_HOEFFDING_TREES_GINI_INDEX_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace tree {
+
+class GiniImpurity // Scores a candidate split by its Gini impurity gain.
+{
+ public:
+  static double Evaluate(const arma::Mat<size_t>& counts)
+  {
+    // Returns the Gini impurity of the un-split node minus the count-weighted
+    // impurity of each split.  Each column of 'counts' is one split and each
+    // row one class; first tally the elements per split and in total.
+    size_t numElem = 0;
+    arma::vec splitCounts(counts.n_cols);
+    for (size_t i = 0; i < counts.n_cols; ++i)
+    {
+      splitCounts[i] = arma::accu(counts.col(i));
+      numElem += splitCounts[i];
+    }
+
+    // Corner case: no elements at all, so return zero (and never divide by 0).
+    if (numElem == 0)
+      return 0.0;
+
+    arma::Col<size_t> classCounts = arma::sum(counts, 1); // Per-row totals.
+
+    // Gini impurity of the un-split node: sum over classes of f * (1 - f).
+    double impurity = 0.0;
+    if (numElem > 0) // Always true here; the zero case returned above.
+    {
+      for (size_t i = 0; i < classCounts.n_elem; ++i)
+      {
+        const double f = ((double) classCounts[i] / (double) numElem);
+        impurity += f * (1.0 - f);
+      }
+    }
+
+    // Now subtract the impurity of each non-empty split node, weighted by the
+    // fraction of all elements that the split contains.
+    for (size_t i = 0; i < counts.n_cols; ++i)
+    {
+      if (splitCounts[i] > 0) // Empty splits contribute nothing.
+      {
+        double splitImpurity = 0.0;
+        for (size_t j = 0; j < counts.n_rows; ++j)
+        {
+          const double f = ((double) counts(j, i) / (double) splitCounts[i]);
+          splitImpurity += f * (1.0 - f);
+        }
+
+        impurity -= ((double) splitCounts[i] / (double) numElem) *
+            splitImpurity;
+      }
+    }
+
+    return impurity;
+  }
+};
+
+} // namespace tree
+} // namespace mlpack
+
+#endif



More information about the mlpack-git mailing list