[mlpack-git] master: Add test for binary numeric split. (08478f6)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:44:17 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 08478f604e4faa66670c13ef99c8faff7fcd0f41
Author: ryan <ryan at ratml.org>
Date:   Thu Oct 8 15:44:31 2015 -0400

    Add test for binary numeric split.


>---------------------------------------------------------------

08478f604e4faa66670c13ef99c8faff7fcd0f41
 src/mlpack/tests/hoeffding_tree_test.cpp | 140 +++++++++++++++++++++++++++++++
 1 file changed, 140 insertions(+)

diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 6d32826..7ffe083 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -592,4 +592,144 @@ BOOST_AUTO_TEST_CASE(BinaryNumericSplitSimpleFourClassSplitTest)
   BOOST_REQUIRE_EQUAL(childMajorities.n_elem, 2);
 }
 
+/**
+ * Train a StreamingDecisionTree that uses HoeffdingNumericSplit both in batch
+ * and point-by-point, and make sure each splits on the correct dimension.
+ */
+BOOST_AUTO_TEST_CASE(NumericHoeffdingTreeTest)
+{
+  // Generate data; the classes are shifted along dimension 1 by 0, -1 and +1.
+  arma::mat dataset(3, 9000);
+  arma::Row<size_t> labels(9000);
+  data::DatasetInfo info; // All features are numeric.
+  for (size_t i = 0; i < 9000; i += 3)
+  {
+    dataset(0, i) = mlpack::math::Random();
+    dataset(1, i) = mlpack::math::Random();
+    dataset(2, i) = mlpack::math::Random();
+    labels[i] = 0;
+
+    dataset(0, i + 1) = mlpack::math::Random();
+    dataset(1, i + 1) = mlpack::math::Random() - 1.0;
+    dataset(2, i + 1) = mlpack::math::Random() + 0.5;
+    labels[i + 1] = 2;
+
+    dataset(0, i + 2) = mlpack::math::Random();
+    dataset(1, i + 2) = mlpack::math::Random() + 1.0;
+    dataset(2, i + 2) = mlpack::math::Random() + 0.8;
+    labels[i + 2] = 1;
+  }
+
+  // Train two trees: batchTree from the whole dataset at once, and streamTree
+  // by feeding it one point at a time.
+  StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
+      HoeffdingNumericSplit<GiniImpurity>>, arma::mat>
+      batchTree(dataset, info, labels, 3);
+  StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
+      HoeffdingNumericSplit<GiniImpurity>>, arma::mat>
+      streamTree(info, 3, 3);
+  for (size_t i = 0; i < 9000; ++i)
+    streamTree.Train(dataset.col(i), labels[i]);
+
+  // Each tree should have split at the root, and on dimension 1.
+  BOOST_REQUIRE_GT(batchTree.NumChildren(), 0);
+  BOOST_REQUIRE_GT(streamTree.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(batchTree.Split().SplitDimension(), 1);
+  BOOST_REQUIRE_EQUAL(streamTree.Split().SplitDimension(), 1);
+
+  // Classify all points; each tree's predictions go into its own label row.
+  arma::Row<size_t> batchLabels(9000);
+  arma::Row<size_t> streamLabels(9000);
+
+  streamTree.Classify(dataset, streamLabels);
+  for (size_t i = 0; i < 9000; ++i)
+    batchLabels[i] = batchTree.Classify(dataset.col(i));
+
+  size_t streamCorrect = 0;
+  size_t batchCorrect = 0;
+  for (size_t i = 0; i < 9000; ++i)
+  {
+    if (labels[i] == streamLabels[i])
+      ++streamCorrect;
+    if (labels[i] == batchLabels[i])
+      ++batchCorrect;
+  }
+
+  // 66% accuracy (6000 of 9000 points) shouldn't be too much to ask...
+  BOOST_REQUIRE_GT(streamCorrect, 6000);
+  BOOST_REQUIRE_GT(batchCorrect, 6000);
+}
+
+/**
+ * The same as the previous test, but with BinaryNumericSplit, and with an extra
+ * constant categorical fourth dimension that the root should not split on.
+ */
+BOOST_AUTO_TEST_CASE(BinaryNumericHoeffdingTreeTest)
+{
+  // Generate data; the classes are shifted along dimension 1 by 0, -1 and +1.
+  arma::mat dataset(4, 9000);
+  arma::Row<size_t> labels(9000);
+  data::DatasetInfo info; // All features are numeric, except the fourth.
+  info.MapString("0", 3);
+  for (size_t i = 0; i < 9000; i += 3)
+  {
+    dataset(0, i) = mlpack::math::Random();
+    dataset(1, i) = mlpack::math::Random();
+    dataset(2, i) = mlpack::math::Random();
+    dataset(3, i) = 0.0;
+    labels[i] = 0;
+
+    dataset(0, i + 1) = mlpack::math::Random();
+    dataset(1, i + 1) = mlpack::math::Random() - 1.0;
+    dataset(2, i + 1) = mlpack::math::Random() + 0.5;
+    dataset(3, i + 1) = 0.0;
+    labels[i + 1] = 2;
+
+    dataset(0, i + 2) = mlpack::math::Random();
+    dataset(1, i + 2) = mlpack::math::Random() + 1.0;
+    dataset(2, i + 2) = mlpack::math::Random() + 0.8;
+    dataset(3, i + 2) = 0.0;
+    labels[i + 2] = 1;
+  }
+
+  // Train two trees: batchTree from the whole dataset at once, and streamTree
+  // by feeding it one point at a time.
+  StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
+      BinaryNumericSplit<GiniImpurity>>, arma::mat>
+      batchTree(dataset, info, labels, 3);
+  StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
+      BinaryNumericSplit<GiniImpurity>>, arma::mat>
+      streamTree(info, 4, 3);
+  for (size_t i = 0; i < 9000; ++i)
+    streamTree.Train(dataset.col(i), labels[i]);
+
+  // Each tree should have split at the root, and on dimension 1.
+  BOOST_REQUIRE_GT(batchTree.NumChildren(), 0);
+  BOOST_REQUIRE_GT(streamTree.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(batchTree.Split().SplitDimension(), 1);
+  BOOST_REQUIRE_EQUAL(streamTree.Split().SplitDimension(), 1);
+
+  // Classify all points; each tree's predictions go into its own label row.
+  arma::Row<size_t> batchLabels(9000);
+  arma::Row<size_t> streamLabels(9000);
+
+  streamTree.Classify(dataset, streamLabels);
+  for (size_t i = 0; i < 9000; ++i)
+    batchLabels[i] = batchTree.Classify(dataset.col(i));
+
+  size_t streamCorrect = 0;
+  size_t batchCorrect = 0;
+  for (size_t i = 0; i < 9000; ++i)
+  {
+    if (labels[i] == streamLabels[i])
+      ++streamCorrect;
+    if (labels[i] == batchLabels[i])
+      ++batchCorrect;
+  }
+
+  // Require a pretty high accuracy: 95% (8550 of 9000 points).
+  BOOST_REQUIRE_GT(streamCorrect, 8550);
+  BOOST_REQUIRE_GT(batchCorrect, 8550);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list