[mlpack-git] master: Refactor tests for new HoeffdingTree API. (5ec6e31)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:45:47 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 5ec6e311bfefed8ac0735f9b4094b6a6721ce8d0
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Nov 1 17:47:12 2015 +0000

    Refactor tests for new HoeffdingTree API.


>---------------------------------------------------------------

5ec6e311bfefed8ac0735f9b4094b6a6721ce8d0
 src/mlpack/tests/hoeffding_tree_test.cpp | 70 ++++++++++++++---------------
 src/mlpack/tests/serialization_test.cpp  | 77 +++++++++++++++-----------------
 2 files changed, 71 insertions(+), 76 deletions(-)

diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 286b81e..3fcdc9d 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -5,7 +5,6 @@
  * Test file for Hoeffding trees.
  */
 #include <mlpack/core.hpp>
-#include <mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp>
 #include <mlpack/methods/hoeffding_trees/gini_impurity.hpp>
 #include <mlpack/methods/hoeffding_trees/hoeffding_tree.hpp>
 #include <mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp>
@@ -207,7 +206,6 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitSplitTest)
   HoeffdingCategoricalSplit<GiniImpurity> split(3, 3); // 3 categories.
 
   // No training is necessary because we can just call CreateChildren().
-  std::vector<StreamingDecisionTree<HoeffdingTree<>>> children;
   data::DatasetInfo info(3);
   info.MapString("hello", 0); // Make dimension 0 categorical.
   HoeffdingCategoricalSplit<GiniImpurity>::SplitInfo splitInfo(3);
@@ -240,7 +238,7 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeNoSplitTest)
   info.MapString("cat1", 2);
   info.MapString("cat2", 2);
 
-  HoeffdingTree<> split(3, 2, info, 0.95, 5000, 1);
+  HoeffdingTree<> split(info, 2, 0.95, 5000, 1);
 
   // Feed it samples.
   for (size_t i = 0; i < 1000; ++i)
@@ -271,18 +269,18 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeEasySplitTest)
   info.MapString("cat1", 0);
   info.MapString("cat0", 1);
 
-  HoeffdingTree<> split(2, 2, info, 0.95, 5000, 1);
+  HoeffdingTree<> tree(info, 2, 0.95, 5000, 5000 /* never check for splits */);
 
   // Feed samples from each class.
   for (size_t i = 0; i < 500; ++i)
   {
-    split.Train(arma::Col<size_t>("0 0"), 0);
-    split.Train(arma::Col<size_t>("1 0"), 1);
+    tree.Train(arma::Col<size_t>("0 0"), 0);
+    tree.Train(arma::Col<size_t>("1 0"), 1);
   }
 
   // Now it should be ready to split.
-  BOOST_REQUIRE_EQUAL(split.SplitCheck(), 2);
-  BOOST_REQUIRE_EQUAL(split.SplitDimension(), 0);
+  BOOST_REQUIRE_EQUAL(tree.SplitCheck(), 2);
+  BOOST_REQUIRE_EQUAL(tree.SplitDimension(), 0);
 }
 
 /**
@@ -299,7 +297,7 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeProbability1SplitTest)
   info.MapString("cat1", 0);
   info.MapString("cat0", 1);
 
-  HoeffdingTree<> split(2, 2, info, 1.0, 12000, 1);
+  HoeffdingTree<> split(info, 2, 1.0, 12000, 1 /* always check for splits */);
 
   // Feed samples from each class.
   for (size_t i = 0; i < 5000; ++i)
@@ -327,7 +325,7 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeAlmostPerfectSplit)
   info.MapString("cat0", 1);
   info.MapString("cat1", 1);
 
-  HoeffdingTree<> split(2, 2, info, 0.95, 5000, 1);
+  HoeffdingTree<> split(info, 2, 0.95, 5000, 5000 /* never check for splits */);
 
   // Feed samples.
   for (size_t i = 0; i < 500; ++i)
@@ -362,7 +360,7 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeEqualSplitTest)
   info.MapString("cat0", 1);
   info.MapString("cat1", 1);
 
-  HoeffdingTree<> split(2, 2, info, 0.95, 5000, 1);
+  HoeffdingTree<> split(info, 2, 0.95, 5000, 1);
 
   // Feed samples.
   for (size_t i = 0; i < 500; ++i)
@@ -375,12 +373,17 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeEqualSplitTest)
   BOOST_REQUIRE_EQUAL(split.SplitCheck(), 0);
 }
 
+// This is used in the next test.
+template<typename FitnessFunction>
+using HoeffdingSizeTNumericSplit = HoeffdingNumericSplit<FitnessFunction,
+    size_t>;
+
 /**
  * Build a decision tree on a dataset with two meaningless dimensions and ensure
  * that it can properly classify all of the training points.  (The dataset is
  * perfectly separable.)
  */
-BOOST_AUTO_TEST_CASE(StreamingDecisionTreeSimpleDatasetTest)
+BOOST_AUTO_TEST_CASE(HoeffdingTreeSimpleDatasetTest)
 {
   DatasetInfo info(3);
   info.MapString("cat0", 0);
@@ -419,20 +422,18 @@ BOOST_AUTO_TEST_CASE(StreamingDecisionTreeSimpleDatasetTest)
 
   // Now train two streaming decision trees; one on the whole dataset, and one
   // on streaming data.
-  StreamingDecisionTree<HoeffdingTree<GiniImpurity,
-      HoeffdingDoubleNumericSplit>, arma::Mat<size_t>>
-      batchTree(dataset, info, labels, 3);
-  StreamingDecisionTree<HoeffdingTree<GiniImpurity,
-      HoeffdingDoubleNumericSplit>, arma::Mat<size_t>>
-      streamTree(info, 3, 3);
+  typedef HoeffdingTree<GiniImpurity, HoeffdingSizeTNumericSplit,
+      HoeffdingCategoricalSplit> TreeType;
+  TreeType batchTree(dataset, info, labels, 3, false);
+  TreeType streamTree(info, 3, 3);
   for (size_t i = 0; i < 9000; ++i)
     streamTree.Train(dataset.col(i), labels[i]);
 
   // Each tree should have a single split.
   BOOST_REQUIRE_EQUAL(batchTree.NumChildren(), 3);
   BOOST_REQUIRE_EQUAL(streamTree.NumChildren(), 3);
-  BOOST_REQUIRE_EQUAL(batchTree.Split().SplitDimension(), 1);
-  BOOST_REQUIRE_EQUAL(streamTree.Split().SplitDimension(), 1);
+  BOOST_REQUIRE_EQUAL(batchTree.SplitDimension(), 1);
+  BOOST_REQUIRE_EQUAL(streamTree.SplitDimension(), 1);
 
   // Now, classify all the points in the dataset.
   arma::Row<size_t> batchLabels(9000);
@@ -593,8 +594,8 @@ BOOST_AUTO_TEST_CASE(BinaryNumericSplitSimpleFourClassSplitTest)
 }
 
 /**
- * Create a StreamingDecisionTree that uses the HoeffdingNumericSplit and make
- * sure it can split meaningfully on the correct dimension.
+ * Create a HoeffdingTree that uses the HoeffdingNumericSplit and make sure it
+ * can split meaningfully on the correct dimension.
  */
 BOOST_AUTO_TEST_CASE(NumericHoeffdingTreeTest)
 {
@@ -622,19 +623,17 @@ BOOST_AUTO_TEST_CASE(NumericHoeffdingTreeTest)
 
   // Now train two streaming decision trees; one on the whole dataset, and one
   // on streaming data.
-  StreamingDecisionTree<HoeffdingTree<GiniImpurity,
-      HoeffdingDoubleNumericSplit>, arma::mat> batchTree(dataset, info, labels,
-      3);
-  StreamingDecisionTree<HoeffdingTree<GiniImpurity,
-      HoeffdingDoubleNumericSplit>, arma::mat> streamTree(info, 3, 3);
+  typedef HoeffdingTree<GiniImpurity, HoeffdingDoubleNumericSplit> TreeType;
+  TreeType batchTree(dataset, info, labels, 3, false);
+  TreeType streamTree(info, 3, 3);
   for (size_t i = 0; i < 9000; ++i)
     streamTree.Train(dataset.col(i), labels[i]);
 
   // Each tree should have at least one split.
   BOOST_REQUIRE_GT(batchTree.NumChildren(), 0);
   BOOST_REQUIRE_GT(streamTree.NumChildren(), 0);
-  BOOST_REQUIRE_EQUAL(batchTree.Split().SplitDimension(), 1);
-  BOOST_REQUIRE_EQUAL(streamTree.Split().SplitDimension(), 1);
+  BOOST_REQUIRE_EQUAL(batchTree.SplitDimension(), 1);
+  BOOST_REQUIRE_EQUAL(streamTree.SplitDimension(), 1);
 
   // Now, classify all the points in the dataset.
   arma::Row<size_t> batchLabels(9000);
@@ -693,18 +692,17 @@ BOOST_AUTO_TEST_CASE(BinaryNumericHoeffdingTreeTest)
 
   // Now train two streaming decision trees; one on the whole dataset, and one
   // on streaming data.
-  StreamingDecisionTree<HoeffdingTree<GiniImpurity, BinaryDoubleNumericSplit>,
-      arma::mat> batchTree(dataset, info, labels, 3);
-  StreamingDecisionTree<HoeffdingTree<GiniImpurity, BinaryDoubleNumericSplit>,
-      arma::mat> streamTree(info, 4, 3);
+  typedef HoeffdingTree<GiniImpurity, BinaryDoubleNumericSplit> TreeType;
+  TreeType batchTree(dataset, info, labels, 3, false);
+  TreeType streamTree(info, 4, 3);
   for (size_t i = 0; i < 9000; ++i)
     streamTree.Train(dataset.col(i), labels[i]);
 
   // Each tree should have at least one split.
   BOOST_REQUIRE_GT(batchTree.NumChildren(), 0);
   BOOST_REQUIRE_GT(streamTree.NumChildren(), 0);
-  BOOST_REQUIRE_EQUAL(batchTree.Split().SplitDimension(), 1);
-  BOOST_REQUIRE_EQUAL(streamTree.Split().SplitDimension(), 1);
+  BOOST_REQUIRE_EQUAL(batchTree.SplitDimension(), 1);
+  BOOST_REQUIRE_EQUAL(streamTree.SplitDimension(), 1);
 
   // Now, classify all the points in the dataset.
   arma::Row<size_t> batchLabels(9000);
@@ -735,7 +733,7 @@ BOOST_AUTO_TEST_CASE(BinaryNumericHoeffdingTreeTest)
 BOOST_AUTO_TEST_CASE(MajorityProbabilityTest)
 {
   data::DatasetInfo info(1);
-  StreamingDecisionTree<HoeffdingTree<>> tree(info, 1, 3);
+  HoeffdingTree<> tree(info, 3);
 
   // Feed the tree a few samples.
   tree.Train(arma::vec("1"), 0);
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 54b6af1..385331d 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -22,7 +22,6 @@
 #include <mlpack/core/metrics/mahalanobis_distance.hpp>
 #include <mlpack/core/tree/binary_space_tree.hpp>
 #include <mlpack/methods/hoeffding_trees/hoeffding_tree.hpp>
-#include <mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp>
 
 using namespace mlpack;
 using namespace mlpack::distribution;
@@ -845,12 +844,12 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitTest)
  * Make sure the HoeffdingTree object serializes correctly before a split has
  * occurred.
  */
-BOOST_AUTO_TEST_CASE(HoeffdingTreeTest)
+BOOST_AUTO_TEST_CASE(HoeffdingTreeBeforeSplitTest)
 {
   data::DatasetInfo info(5);
   info.MapString("0", 2); // Dimension 2 is categorical.
   info.MapString("1", 2);
-  HoeffdingTree<> split(5, 2, info, 0.99, 15000, 1);
+  HoeffdingTree<> split(info, 2, 0.99, 15000, 1);
 
   // Train for 2 samples.
   split.Train(arma::vec("0.3 0.4 1 0.6 0.7"), 0);
@@ -858,7 +857,7 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeTest)
 
   data::DatasetInfo wrongInfo(3);
   wrongInfo.MapString("1", 1);
-  HoeffdingTree<> xmlSplit(3, 7, wrongInfo, 0.1, 10, 1);
+  HoeffdingTree<> xmlSplit(wrongInfo, 7, 0.1, 10, 1);
 
   // Force the binarySplit to split.
   data::DatasetInfo binaryInfo(2);
@@ -866,7 +865,7 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeTest)
   binaryInfo.MapString("cat1", 0);
   binaryInfo.MapString("cat0", 1);
 
-  HoeffdingTree<> binarySplit(2, 2, info, 0.95, 5000, 1);
+  HoeffdingTree<> binarySplit(info, 2, 0.95, 5000, 1);
 
   // Feed samples from each class.
   for (size_t i = 0; i < 500; ++i)
@@ -875,7 +874,7 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeTest)
     binarySplit.Train(arma::Col<size_t>("1 0"), 1);
   }
 
-  HoeffdingTree<> textSplit(3, 11, wrongInfo, 0.75, 1000, 1);
+  HoeffdingTree<> textSplit(wrongInfo, 11, 0.75, 1000, 1);
 
   SerializeObjectAll(split, xmlSplit, textSplit, binarySplit);
 
@@ -904,7 +903,7 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeAfterSplitTest)
   info.MapString("cat1", 0);
   info.MapString("cat0", 1);
 
-  HoeffdingTree<> split(2, 2, info, 0.95, 5000, 1);
+  HoeffdingTree<> split(info, 2, 0.95, 5000, 1);
 
   // Feed samples from each class.
   for (size_t i = 0; i < 500; ++i)
@@ -912,22 +911,23 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeAfterSplitTest)
     split.Train(arma::Col<size_t>("0 0"), 0);
     split.Train(arma::Col<size_t>("1 0"), 1);
   }
-  BOOST_REQUIRE_EQUAL(split.SplitCheck(), 2);
+  // Ensure a split has happened.
+  BOOST_REQUIRE_NE(split.SplitDimension(), size_t(-1));
 
   data::DatasetInfo wrongInfo(3);
   wrongInfo.MapString("1", 1);
-  HoeffdingTree<> xmlSplit(3, 7, wrongInfo, 0.1, 10, 1);
+  HoeffdingTree<> xmlSplit(wrongInfo, 7, 0.1, 10, 1);
 
   data::DatasetInfo binaryInfo(5);
   binaryInfo.MapString("0", 2); // Dimension 2 is categorical.
   binaryInfo.MapString("1", 2);
-  HoeffdingTree<> binarySplit(5, 2, binaryInfo, 0.99, 15000, 1);
+  HoeffdingTree<> binarySplit(binaryInfo, 2, 0.99, 15000, 1);
 
   // Train for 2 samples.
   binarySplit.Train(arma::vec("0.3 0.4 1 0.6 0.7"), 0);
   binarySplit.Train(arma::vec("-0.3 0.0 0 0.7 0.8"), 1);
 
-  HoeffdingTree<> textSplit(3, 11, wrongInfo, 0.75, 1000, 1);
+  HoeffdingTree<> textSplit(wrongInfo, 11, 0.75, 1000, 1);
 
   SerializeObjectAll(split, xmlSplit, textSplit, binarySplit);
 
@@ -953,15 +953,15 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeAfterSplitTest)
       textSplit.CalculateDirection(arma::vec("0.3 0.4 1 0.6 0.7")));
 }
 
-BOOST_AUTO_TEST_CASE(EmptyStreamingDecisionTreeTest)
+BOOST_AUTO_TEST_CASE(EmptyHoeffdingTreeTest)
 {
   using namespace mlpack::tree;
 
   data::DatasetInfo info(6);
-  StreamingDecisionTree<HoeffdingTree<>> tree(info, 2, 2);
-  StreamingDecisionTree<HoeffdingTree<>> xmlTree(info, 3, 3);
-  StreamingDecisionTree<HoeffdingTree<>> binaryTree(info, 4, 4);
-  StreamingDecisionTree<HoeffdingTree<>> textTree(info, 5, 5);
+  HoeffdingTree<> tree(info, 2);
+  HoeffdingTree<> xmlTree(info, 3);
+  HoeffdingTree<> binaryTree(info, 4);
+  HoeffdingTree<> textTree(info, 5);
 
   SerializeObjectAll(tree, xmlTree, binaryTree, textTree);
 
@@ -975,7 +975,7 @@ BOOST_AUTO_TEST_CASE(EmptyStreamingDecisionTreeTest)
  * Build a Hoeffding tree, then save it and make sure other trees can classify
  * as effectively.
  */
-BOOST_AUTO_TEST_CASE(StreamingDecisionTreeTest)
+BOOST_AUTO_TEST_CASE(HoeffdingTreeTest)
 {
   using namespace mlpack::tree;
 
@@ -1001,14 +1001,14 @@ BOOST_AUTO_TEST_CASE(StreamingDecisionTreeTest)
   info.MapString("c", 1);
   info.MapString("d", 1);
 
-  StreamingDecisionTree<HoeffdingTree<>> tree(dataset, info, labels, 2);
+  HoeffdingTree<> tree(dataset, info, labels, 2, false /* no batch mode */);
 
   data::DatasetInfo xmlInfo(1);
-  StreamingDecisionTree<HoeffdingTree<>> xmlTree(xmlInfo, 1, 1);
+  HoeffdingTree<> xmlTree(xmlInfo, 1);
   data::DatasetInfo binaryInfo(5);
-  StreamingDecisionTree<HoeffdingTree<>> binaryTree(binaryInfo, 5, 6);
+  HoeffdingTree<> binaryTree(binaryInfo, 6);
   data::DatasetInfo textInfo(7);
-  StreamingDecisionTree<HoeffdingTree<>> textTree(textInfo, 7, 100);
+  HoeffdingTree<> textTree(textInfo, 100);
 
   SerializeObjectAll(tree, xmlTree, textTree, binaryTree);
 
@@ -1016,12 +1016,9 @@ BOOST_AUTO_TEST_CASE(StreamingDecisionTreeTest)
   BOOST_REQUIRE_EQUAL(tree.NumChildren(), textTree.NumChildren());
   BOOST_REQUIRE_EQUAL(tree.NumChildren(), binaryTree.NumChildren());
 
-  BOOST_REQUIRE_EQUAL(tree.Split().SplitDimension(),
-      xmlTree.Split().SplitDimension());
-  BOOST_REQUIRE_EQUAL(tree.Split().SplitDimension(),
-      textTree.Split().SplitDimension());
-  BOOST_REQUIRE_EQUAL(tree.Split().SplitDimension(),
-      binaryTree.Split().SplitDimension());
+  BOOST_REQUIRE_EQUAL(tree.SplitDimension(), xmlTree.SplitDimension());
+  BOOST_REQUIRE_EQUAL(tree.SplitDimension(), textTree.SplitDimension());
+  BOOST_REQUIRE_EQUAL(tree.SplitDimension(), binaryTree.SplitDimension());
 
   for (size_t i = 0; i < tree.NumChildren(); ++i)
   {
@@ -1030,19 +1027,19 @@ BOOST_AUTO_TEST_CASE(StreamingDecisionTreeTest)
     BOOST_REQUIRE_EQUAL(binaryTree.Child(i).NumChildren(), 0);
     BOOST_REQUIRE_EQUAL(textTree.Child(i).NumChildren(), 0);
 
-    BOOST_REQUIRE_EQUAL(tree.Child(i).Split().SplitDimension(),
-        xmlTree.Child(i).Split().SplitDimension());
-    BOOST_REQUIRE_EQUAL(tree.Child(i).Split().SplitDimension(),
-        textTree.Child(i).Split().SplitDimension());
-    BOOST_REQUIRE_EQUAL(tree.Child(i).Split().SplitDimension(),
-        binaryTree.Child(i).Split().SplitDimension());
-
-    BOOST_REQUIRE_EQUAL(tree.Child(i).Split().MajorityClass(),
-        xmlTree.Child(i).Split().MajorityClass());
-    BOOST_REQUIRE_EQUAL(tree.Child(i).Split().MajorityClass(),
-        textTree.Child(i).Split().MajorityClass());
-    BOOST_REQUIRE_EQUAL(tree.Child(i).Split().MajorityClass(),
-        binaryTree.Child(i).Split().MajorityClass());
+    BOOST_REQUIRE_EQUAL(tree.Child(i).SplitDimension(),
+        xmlTree.Child(i).SplitDimension());
+    BOOST_REQUIRE_EQUAL(tree.Child(i).SplitDimension(),
+        textTree.Child(i).SplitDimension());
+    BOOST_REQUIRE_EQUAL(tree.Child(i).SplitDimension(),
+        binaryTree.Child(i).SplitDimension());
+
+    BOOST_REQUIRE_EQUAL(tree.Child(i).MajorityClass(),
+        xmlTree.Child(i).MajorityClass());
+    BOOST_REQUIRE_EQUAL(tree.Child(i).MajorityClass(),
+        textTree.Child(i).MajorityClass());
+    BOOST_REQUIRE_EQUAL(tree.Child(i).MajorityClass(),
+        binaryTree.Child(i).MajorityClass());
   }
 
   // Check that predictions are the same.



More information about the mlpack-git mailing list