[mlpack-git] master: Add test for binary numeric split. (08478f6)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:44:17 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 08478f604e4faa66670c13ef99c8faff7fcd0f41
Author: ryan <ryan at ratml.org>
Date: Thu Oct 8 15:44:31 2015 -0400
Add test for binary numeric split.
>---------------------------------------------------------------
08478f604e4faa66670c13ef99c8faff7fcd0f41
src/mlpack/tests/hoeffding_tree_test.cpp | 140 +++++++++++++++++++++++++++++++
1 file changed, 140 insertions(+)
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 6d32826..7ffe083 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -592,4 +592,144 @@ BOOST_AUTO_TEST_CASE(BinaryNumericSplitSimpleFourClassSplitTest)
BOOST_REQUIRE_EQUAL(childMajorities.n_elem, 2);
}
+/**
+ * Create a StreamingDecisionTree that uses the HoeffdingNumericSplit and make
+ * sure it can split meaningfully on the correct dimension.
+ *
+ * The three classes are generated so that only dimension 1 separates them
+ * (offsets -1.0 / 0.0 / +1.0 on top of math::Random(), which is assumed to
+ * draw from [0, 1)); dimensions 0 and 2 overlap between classes.  Both a tree
+ * trained in batch and a tree trained one point at a time must therefore
+ * choose dimension 1 for their root split.
+ */
+BOOST_AUTO_TEST_CASE(NumericHoeffdingTreeTest)
+{
+  // Number of generated points.  Must be a multiple of 3, because each loop
+  // iteration below emits exactly one point of each class.
+  const size_t numPoints = 9000;
+
+  // Generate data.
+  arma::mat dataset(3, numPoints);
+  arma::Row<size_t> labels(numPoints);
+  data::DatasetInfo info; // All features are numeric.
+  for (size_t i = 0; i < numPoints; i += 3)
+  {
+    dataset(0, i) = mlpack::math::Random();
+    dataset(1, i) = mlpack::math::Random();
+    dataset(2, i) = mlpack::math::Random();
+    labels[i] = 0;
+
+    dataset(0, i + 1) = mlpack::math::Random();
+    dataset(1, i + 1) = mlpack::math::Random() - 1.0;
+    dataset(2, i + 1) = mlpack::math::Random() + 0.5;
+    labels[i + 1] = 2;
+
+    dataset(0, i + 2) = mlpack::math::Random();
+    dataset(1, i + 2) = mlpack::math::Random() + 1.0;
+    dataset(2, i + 2) = mlpack::math::Random() + 0.8;
+    labels[i + 2] = 1;
+  }
+
+  // Now train two streaming decision trees; one on the whole dataset, and one
+  // on streaming data (one point at a time).
+  StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
+      HoeffdingNumericSplit<GiniImpurity>>, arma::mat>
+      batchTree(dataset, info, labels, 3);
+  StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
+      HoeffdingNumericSplit<GiniImpurity>>, arma::mat>
+      streamTree(info, 3, 3);
+  for (size_t i = 0; i < numPoints; ++i)
+    streamTree.Train(dataset.col(i), labels[i]);
+
+  // Each tree should have at least one split, and the root split must be on
+  // dimension 1: it is the only dimension whose per-class ranges are disjoint.
+  BOOST_REQUIRE_GT(batchTree.NumChildren(), 0);
+  BOOST_REQUIRE_GT(streamTree.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(batchTree.Split().SplitDimension(), 1);
+  BOOST_REQUIRE_EQUAL(streamTree.Split().SplitDimension(), 1);
+
+  // Now, classify all the points in the dataset.  The batch tree goes through
+  // the whole-matrix Classify() overload and the streaming tree through the
+  // single-point overload.  Each tree writes into its own row, so the accuracy
+  // counts below are attributed to the correct tree.
+  arma::Row<size_t> batchLabels(numPoints);
+  arma::Row<size_t> streamLabels(numPoints);
+
+  batchTree.Classify(dataset, batchLabels);
+  for (size_t i = 0; i < numPoints; ++i)
+    streamLabels[i] = streamTree.Classify(dataset.col(i));
+
+  size_t streamCorrect = 0;
+  size_t batchCorrect = 0;
+  for (size_t i = 0; i < numPoints; ++i)
+  {
+    if (labels[i] == streamLabels[i])
+      ++streamCorrect;
+    if (labels[i] == batchLabels[i])
+      ++batchCorrect;
+  }
+
+  // Two-thirds accuracy (6000 of 9000 points) shouldn't be too much to ask...
+  BOOST_REQUIRE_GT(streamCorrect, numPoints * 2 / 3);
+  BOOST_REQUIRE_GT(batchCorrect, numPoints * 2 / 3);
+}
+
+/**
+ * The same as NumericHoeffdingTreeTest, but using BinaryNumericSplit as the
+ * numeric split type, and with an additional categorical fourth feature.
+ *
+ * The categorical feature always takes the single mapped value "0", so it
+ * carries no information about the label.  As in NumericHoeffdingTreeTest,
+ * only dimension 1 separates the three classes, so both trees must choose it
+ * for their root split.
+ */
+BOOST_AUTO_TEST_CASE(BinaryNumericHoeffdingTreeTest)
+{
+  // Number of generated points.  Must be a multiple of 3, because each loop
+  // iteration below emits exactly one point of each class.
+  const size_t numPoints = 9000;
+
+  // Generate data.
+  arma::mat dataset(4, numPoints);
+  arma::Row<size_t> labels(numPoints);
+  data::DatasetInfo info; // All features are numeric, except the fourth.
+  info.MapString("0", 3); // Makes dimension 3 categorical with one value.
+  for (size_t i = 0; i < numPoints; i += 3)
+  {
+    dataset(0, i) = mlpack::math::Random();
+    dataset(1, i) = mlpack::math::Random();
+    dataset(2, i) = mlpack::math::Random();
+    dataset(3, i) = 0.0;
+    labels[i] = 0;
+
+    dataset(0, i + 1) = mlpack::math::Random();
+    dataset(1, i + 1) = mlpack::math::Random() - 1.0;
+    dataset(2, i + 1) = mlpack::math::Random() + 0.5;
+    dataset(3, i + 1) = 0.0;
+    labels[i + 1] = 2;
+
+    dataset(0, i + 2) = mlpack::math::Random();
+    dataset(1, i + 2) = mlpack::math::Random() + 1.0;
+    dataset(2, i + 2) = mlpack::math::Random() + 0.8;
+    dataset(3, i + 2) = 0.0;
+    labels[i + 2] = 1;
+  }
+
+  // Now train two streaming decision trees; one on the whole dataset, and one
+  // on streaming data (one point at a time).
+  StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
+      BinaryNumericSplit<GiniImpurity>>, arma::mat>
+      batchTree(dataset, info, labels, 3);
+  StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
+      BinaryNumericSplit<GiniImpurity>>, arma::mat>
+      streamTree(info, 4, 3);
+  for (size_t i = 0; i < numPoints; ++i)
+    streamTree.Train(dataset.col(i), labels[i]);
+
+  // Each tree should have at least one split, and the root split must be on
+  // dimension 1: it is the only dimension whose per-class ranges are disjoint.
+  BOOST_REQUIRE_GT(batchTree.NumChildren(), 0);
+  BOOST_REQUIRE_GT(streamTree.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(batchTree.Split().SplitDimension(), 1);
+  BOOST_REQUIRE_EQUAL(streamTree.Split().SplitDimension(), 1);
+
+  // Now, classify all the points in the dataset.  The batch tree goes through
+  // the whole-matrix Classify() overload and the streaming tree through the
+  // single-point overload.  Each tree writes into its own row, so the accuracy
+  // counts below are attributed to the correct tree.
+  arma::Row<size_t> batchLabels(numPoints);
+  arma::Row<size_t> streamLabels(numPoints);
+
+  batchTree.Classify(dataset, batchLabels);
+  for (size_t i = 0; i < numPoints; ++i)
+    streamLabels[i] = streamTree.Classify(dataset.col(i));
+
+  size_t streamCorrect = 0;
+  size_t batchCorrect = 0;
+  for (size_t i = 0; i < numPoints; ++i)
+  {
+    if (labels[i] == streamLabels[i])
+      ++streamCorrect;
+    if (labels[i] == batchLabels[i])
+      ++batchCorrect;
+  }
+
+  // Require a pretty high accuracy: 95% (8550 of 9000 points).
+  BOOST_REQUIRE_GT(streamCorrect, numPoints * 95 / 100);
+  BOOST_REQUIRE_GT(batchCorrect, numPoints * 95 / 100);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list