[mlpack-git] master: Fix BatchTrainingTest by splitting into training and test set. (a3a4656)

gitdub at mlpack.org gitdub at mlpack.org
Mon Mar 7 12:42:55 EST 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/a069a5643c9aefe5361058759448d2ef2f7a4a36...a3a46561da67b99e38536f8b5824df9603a29f53

>---------------------------------------------------------------

commit a3a46561da67b99e38536f8b5824df9603a29f53
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Mar 7 12:42:06 2016 -0500

    Fix BatchTrainingTest by splitting into training and test set.
    
    Also, allow the batch tree to be only as good as the streaming tree, and boost
    the required confidence so that the streaming tree will create fewer nodes (and
    therefore hopefully have lower test error).


>---------------------------------------------------------------

a3a46561da67b99e38536f8b5824df9603a29f53
 src/mlpack/tests/hoeffding_tree_test.cpp | 51 ++++++++++++--------------------
 1 file changed, 19 insertions(+), 32 deletions(-)

diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 367f712..ff5e147 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -902,9 +902,9 @@ BOOST_AUTO_TEST_CASE(BatchTrainingTest)
   arma::mat spiralDataset(2, 10000);
   for (size_t i = 0; i < 10000; ++i)
   {
-    // One circle every 2000 samples.
-    const double magnitude = 2.0 + (double(i) / 20000.0);
-    const double angle = (i % 20000) * (2 * M_PI);
+    // One circle every 20000 samples.  Plus some noise.
+    const double magnitude = 2.0 + (double(i) / 20000.0) + 0.5 * math::Random();
+    const double angle = (i % 20000) * (2 * M_PI) + math::Random();
 
     const double x = magnitude * cos(angle);
     const double y = magnitude * sin(angle);
@@ -936,52 +936,39 @@ BOOST_AUTO_TEST_CASE(BatchTrainingTest)
     l[i] = labels[indices[i]];
   }
 
+  // Split into a training set and a test set.
+  arma::mat trainingData = d.cols(0, 4999);
+  arma::mat testData = d.cols(5000, 9999);
+  arma::Row<size_t> trainingLabels = l.subvec(0, 4999);
+  arma::Row<size_t> testLabels = l.subvec(5000, 9999);
+
   data::DatasetInfo info(2);
 
   // Now build two decision trees; one in batch mode, and one in streaming mode.
   // We need to set the confidence pretty high so that the streaming tree isn't
   // able to have enough samples to build to the same leaves.
-  HoeffdingTree<> batchTree(d, info, l, 5, true, 0.999);
-  HoeffdingTree<> streamTree(d, info, l, 5, false, 0.999);
-
-  size_t batchNodes = 0, streamNodes = 0;
-  std::stack<HoeffdingTree<>*> queue;
-  queue.push(&batchTree);
-  while (!queue.empty())
-  {
-    ++batchNodes;
-    HoeffdingTree<>* node = queue.top();
-    queue.pop();
-    for (size_t i = 0; i < node->NumChildren(); ++i)
-      queue.push(&node->Child(i));
-  }
-  queue.push(&streamTree);
-  while (!queue.empty())
-  {
-    ++streamNodes;
-    HoeffdingTree<>* node = queue.top();
-    queue.pop();
-    for (size_t i = 0; i < node->NumChildren(); ++i)
-      queue.push(&node->Child(i));
-  }
+  HoeffdingTree<> batchTree(trainingData, info, trainingLabels, 5, true,
+      0.99999999);
+  HoeffdingTree<> streamTree(trainingLabels, info, trainingLabels, 5, false,
+      0.99999999);
 
   // Ensure that the performance of the batch tree is better.
   size_t batchCorrect = 0;
   size_t streamCorrect = 0;
-  for (size_t i = 0; i < 10000; ++i)
+  for (size_t i = 0; i < 5000; ++i)
   {
-    size_t streamLabel = streamTree.Classify(spiralDataset.col(i));
-    size_t batchLabel = batchTree.Classify(spiralDataset.col(i));
+    size_t streamLabel = streamTree.Classify(testData.col(i));
+    size_t batchLabel = batchTree.Classify(testData.col(i));
 
-    if (streamLabel == labels[i])
+    if (streamLabel == testLabels[i])
       ++streamCorrect;
-    if (batchLabel == labels[i])
+    if (batchLabel == testLabels[i])
       ++batchCorrect;
   }
 
   // The batch tree must be a bit better than the stream tree.  But not too
   // much, since the accuracy is already going to be very high.
-  BOOST_REQUIRE_GT(batchCorrect, streamCorrect);
+  BOOST_REQUIRE_GE(batchCorrect, streamCorrect);
 }
 
 // Make sure that changing the confidence properly propagates to all leaves.




More information about the mlpack-git mailing list