[mlpack-git] master: Add tests for Gini impurity. (6ffeae2)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:41:57 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 6ffeae208ac8f53a0b8732b7af4548fe762a421a
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Sep 21 12:29:54 2015 +0000
Add tests for Gini impurity.
>---------------------------------------------------------------
6ffeae208ac8f53a0b8732b7af4548fe762a421a
src/mlpack/methods/CMakeLists.txt | 1 +
src/mlpack/tests/CMakeLists.txt | 1 +
src/mlpack/tests/hoeffding_tree_test.cpp | 100 +++++++++++++++++++++++++++++++
3 files changed, 102 insertions(+)
diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt
index e90adc0..be5dc80 100644
--- a/src/mlpack/methods/CMakeLists.txt
+++ b/src/mlpack/methods/CMakeLists.txt
@@ -10,6 +10,7 @@ set(DIRS
fastmks
gmm
hmm
+ hoeffding_trees
kernel_pca
kmeans
mean_shift
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index fa65a7d..bdc4d0e 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -20,6 +20,7 @@ add_executable(mlpack_test
feedforward_network_test.cpp
gmm_test.cpp
hmm_test.cpp
+ hoeffding_tree_test.cpp
kernel_test.cpp
kernel_pca_test.cpp
kernel_traits_test.cpp
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
new file mode 100644
index 0000000..2659f36
--- /dev/null
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -0,0 +1,100 @@
+/**
+ * @file hoeffding_tree_test.cpp
+ * @author Ryan Curtin
+ *
+ * Test file for Hoeffding trees.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp>
+#include <mlpack/methods/hoeffding_trees/gini_impurity.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace std;
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::data;
+using namespace mlpack::tree;
+
+BOOST_AUTO_TEST_SUITE(HoeffdingTreeTest);
+
+BOOST_AUTO_TEST_CASE(GiniImpurityPerfectSimpleTest)
+{
+ // Make a simple test for Gini impurity where every point has class 0. In this
+ // case it should always be 0. We'll assemble the count matrix by hand.
+ arma::Mat<size_t> counts(2, 2); // 2 categories, 2 classes.
+
+ counts(0, 0) = 10; // 10 points in category 0 with class 0.
+ counts(0, 1) = 0; // 0 points in category 0 with class 1.
+ counts(1, 0) = 12; // 12 points in category 1 with class 0.
+ counts(1, 1) = 0; // 0 points in category 1 with class 1.
+
+ // The data is already pure, so the split gets us nothing: expect no gain.
+ BOOST_REQUIRE_SMALL(GiniImpurity::Evaluate(counts), 1e-10);
+}
+
+BOOST_AUTO_TEST_CASE(GiniImpurityImperfectSimpleTest)
+{
+ // Make a simple test where a split will give us perfect classification.
+ arma::Mat<size_t> counts(2, 2); // 2 categories, 2 classes.
+
+ counts(0, 0) = 10; // 10 points in category 0 with class 0.
+ counts(1, 0) = 0; // 0 points in category 1 with class 0.
+ counts(0, 1) = 0; // 0 points in category 0 with class 1.
+ counts(1, 1) = 10; // 10 points in category 1 with class 1.
+
+ // The impurity before the split should be 1 - (0.5^2 + 0.5^2) = 0.5.
+ // The impurity after the split should be 0.
+ // So the gain should be 0.5.
+ BOOST_REQUIRE_CLOSE(GiniImpurity::Evaluate(counts), 0.5, 1e-5);
+}
+
+BOOST_AUTO_TEST_CASE(GiniImpurityBadSplitTest)
+{
+ // A split gets us nothing here: it leaves the class proportions unchanged.
+ arma::Mat<size_t> counts(2, 2);
+ counts(0, 0) = 10;
+ counts(0, 1) = 10;
+ counts(1, 0) = 5;
+ counts(1, 1) = 5;
+
+ BOOST_REQUIRE_SMALL(GiniImpurity::Evaluate(counts), 1e-10);
+}
+
+/**
+ * A hand-crafted harder Gini impurity test: four categories, three classes.
+ */
+BOOST_AUTO_TEST_CASE(GiniImpurityThreeClassTest)
+{
+ arma::Mat<size_t> counts(4, 3);
+
+ counts(0, 0) = 0;
+ counts(0, 1) = 0;
+ counts(0, 2) = 10;
+
+ counts(1, 0) = 5;
+ counts(1, 1) = 5;
+ counts(1, 2) = 0;
+
+ counts(2, 0) = 4;
+ counts(2, 1) = 4;
+ counts(2, 2) = 4;
+
+ counts(3, 0) = 8;
+ counts(3, 1) = 1;
+ counts(3, 2) = 1;
+
+ // NOTE(review): no assertion yet. By hand (rows = categories) the gain is
+ // roughly 0.65193 - 0.39048 = 0.26145; confirm before requiring it.
+}
+
+BOOST_AUTO_TEST_CASE(GiniImpurityZeroTest)
+{
+ // With all counts zero (nothing seen), the Gini impurity should be zero.
+ arma::Mat<size_t> counts = arma::zeros<arma::Mat<size_t>>(10, 10);
+
+ BOOST_REQUIRE_SMALL(GiniImpurity::Evaluate(counts), 1e-10);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list