[mlpack-git] master: Add batch mode and a test for it. (4233cd4)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:45:51 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 4233cd4c82bf306be948f9171dc37e342acf277d
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Nov 1 19:05:37 2015 +0000
Add batch mode and a test for it.
>---------------------------------------------------------------
4233cd4c82bf306be948f9171dc37e342acf277d
.../hoeffding_trees/hoeffding_tree_impl.hpp | 5 +-
src/mlpack/tests/hoeffding_tree_test.cpp | 96 ++++++++++++++++++++++
2 files changed, 98 insertions(+), 3 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
index 4c802df..f7c48ab 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
@@ -34,7 +34,7 @@ HoeffdingTree<
ownsMappings(true),
numSamples(0),
numClasses(numClasses),
- maxSamples(maxSamples),
+ maxSamples((maxSamples == 0) ? size_t(-1) : maxSamples),
checkInterval(checkInterval),
datasetInfo(&datasetInfo),
ownsInfo(false),
@@ -84,7 +84,7 @@ HoeffdingTree<
ownsMappings(dimensionMappingsIn == NULL),
numSamples(0),
numClasses(numClasses),
- maxSamples(maxSamples),
+ maxSamples((maxSamples == 0) ? size_t(-1) : maxSamples),
checkInterval(checkInterval),
datasetInfo(&datasetInfo),
ownsInfo(false),
@@ -183,7 +183,6 @@ void HoeffdingTree<
const arma::Row<size_t>& labels,
const bool batchTraining)
{
- // Not yet implemented.
if (batchTraining)
{
// Pass all the points through the nodes, and then split only after that.
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 3fcdc9d..cd443e5 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -13,6 +13,8 @@
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
+#include <stack>
+
using namespace std;
using namespace arma;
using namespace mlpack;
@@ -765,4 +767,98 @@ BOOST_AUTO_TEST_CASE(MajorityProbabilityTest)
BOOST_REQUIRE_CLOSE(probability, 0.625, 1e-5);
}
+/**
+ * Count every node (including the root) of the given Hoeffding tree, using an
+ * iterative depth-first traversal.
+ */
+static size_t CountNodes(HoeffdingTree<>& root)
+{
+  size_t count = 0;
+  // Nodes still to be visited; using a stack makes the traversal depth-first.
+  std::stack<HoeffdingTree<>*> pending;
+  pending.push(&root);
+  while (!pending.empty())
+  {
+    HoeffdingTree<>* node = pending.top();
+    pending.pop();
+    ++count;
+    for (size_t i = 0; i < node->NumChildren(); ++i)
+      pending.push(&node->Child(i));
+  }
+  return count;
+}
+
+/**
+ * Make sure that batch training mode outperforms non-batch mode on a dataset
+ * that needs many axis-aligned splits to be classified accurately.
+ */
+BOOST_AUTO_TEST_CASE(BatchTrainingTest)
+{
+  // We need to create a dataset with some amount of complexity, that must be
+  // split in a handful of ways to accurately classify the data.  An expanding
+  // spiral should do the trick here.  We'll make the spiral in two dimensions.
+  // The label changes once per revolution: each class is one spiral turn.
+  const size_t ringLabels[] = { 1, 3, 2, 0, 4 };
+  const size_t numRings = sizeof(ringLabels) / sizeof(ringLabels[0]);
+  const size_t samplesPerRing = 2000;
+  const size_t numPoints = numRings * samplesPerRing;
+
+  arma::mat spiralDataset(2, numPoints);
+  arma::Row<size_t> labels(numPoints);
+  for (size_t i = 0; i < numPoints; ++i)
+  {
+    // One full circle every samplesPerRing samples.  The angle has to be a
+    // fraction of 2 * pi: an integer multiple of 2 * pi is always angle 0,
+    // which would collapse the "spiral" onto a line segment.
+    const double magnitude = 2.0 + (double(i) / 20000.0);
+    const double angle = (2 * M_PI) * double(i % samplesPerRing) /
+        double(samplesPerRing);
+
+    spiralDataset(0, i) = magnitude * cos(angle);
+    spiralDataset(1, i) = magnitude * sin(angle);
+    labels[i] = ringLabels[i / samplesPerRing];
+  }
+
+  // Now shuffle the dataset.
+  arma::uvec indices = arma::shuffle(arma::linspace<arma::uvec>(0,
+      numPoints - 1, numPoints));
+  arma::mat d(2, numPoints);
+  arma::Row<size_t> l(numPoints);
+  for (size_t i = 0; i < numPoints; ++i)
+  {
+    d.col(i) = spiralDataset.col(indices[i]);
+    l[i] = labels[indices[i]];
+  }
+
+  data::DatasetInfo info(2);
+
+  // Now build two decision trees; one in batch mode, and one in streaming mode.
+  // We need to set the confidence pretty high so that the streaming tree isn't
+  // able to have enough samples to build to the same leaves.
+  HoeffdingTree<> batchTree(d, info, l, 5, true, 0.999);
+  HoeffdingTree<> streamTree(d, info, l, 5, false, 0.999);
+
+  // Consequently, the batch tree should have grown more nodes.
+  BOOST_REQUIRE_GT(CountNodes(batchTree), CountNodes(streamTree));
+
+  // Ensure that the performance of the batch tree is better.
+  size_t batchCorrect = 0;
+  size_t streamCorrect = 0;
+  for (size_t i = 0; i < numPoints; ++i)
+  {
+    const size_t streamLabel = streamTree.Classify(spiralDataset.col(i));
+    const size_t batchLabel = batchTree.Classify(spiralDataset.col(i));
+
+    if (streamLabel == labels[i])
+      ++streamCorrect;
+    if (batchLabel == labels[i])
+      ++batchCorrect;
+  }
+
+  // The batch tree must be measurably better than the stream tree; a margin of
+  // 25 points (out of numPoints) rules out a near-tie without depending on the
+  // exact accuracy either tree reaches.
+  BOOST_REQUIRE_GT(batchCorrect, streamCorrect + 25);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list