[mlpack-git] master: Test StreamingDecisionTree::Serialize(). (68fe2da)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:43:48 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 68fe2dac616efcfd5d06fd3396f80859d9ddb922
Author: ryan <ryan at ratml.org>
Date: Thu Oct 1 20:05:11 2015 -0400
Test StreamingDecisionTree::Serialize().
>---------------------------------------------------------------
68fe2dac616efcfd5d06fd3396f80859d9ddb922
src/mlpack/tests/serialization_test.cpp | 177 ++++++++++++++++++++++++++++++++
1 file changed, 177 insertions(+)
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index d0f04f3..1c15731 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -22,6 +22,7 @@
#include <mlpack/core/metrics/mahalanobis_distance.hpp>
#include <mlpack/core/tree/binary_space_tree.hpp>
#include <mlpack/methods/hoeffding_trees/hoeffding_split.hpp>
+#include <mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp>
using namespace mlpack;
using namespace mlpack::distribution;
@@ -54,6 +55,7 @@ void TestArmadilloSerialization(MatType& x)
}
catch (archive_exception& e)
{
+ std::cerr << e.what();
success = false;
}
@@ -72,6 +74,7 @@ void TestArmadilloSerialization(MatType& x)
}
catch (archive_exception& e)
{
+ std::cerr << e.what();
success = false;
}
@@ -189,6 +192,7 @@ void SerializeObject(T& t, T& newT)
}
catch (archive_exception& e)
{
+ std::cerr << e.what();
success = false;
}
ofs.close();
@@ -204,6 +208,7 @@ void SerializeObject(T& t, T& newT)
}
catch (archive_exception& e)
{
+ std::cerr << e.what();
success = false;
}
ifs.close();
@@ -234,6 +239,7 @@ void SerializePointerObject(T* t, T*& newT)
}
catch (archive_exception& e)
{
+ std::cerr << e.what();
success = false;
}
ofs.close();
@@ -249,6 +255,7 @@ void SerializePointerObject(T* t, T*& newT)
}
catch (std::exception& e)
{
+ std::cerr << e.what();
success = false;
}
ifs.close();
@@ -885,4 +892,174 @@ BOOST_AUTO_TEST_CASE(HoeffdingSplitTest)
BOOST_REQUIRE_EQUAL(split.SplitCheck(), textSplit.SplitCheck());
}
+/**
+ * Make sure the HoeffdingSplit object serializes correctly after a split has
+ * occurred.  A trained split is saved through XML, text, and binary archives
+ * and loaded into splits that were constructed with different parameters;
+ * each loaded copy must agree with the original on split dimension,
+ * SplitCheck(), majority class, and the direction in which it routes points.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingSplitAfterSplitTest)
+{
+  // Force the split to split.  Dimension 0 has two categories and dimension 1
+  // has only one; the training points below differ only in dimension 0.
+  data::DatasetInfo info;
+  info.MapString("cat0", 0);
+  info.MapString("cat1", 0);
+  info.MapString("cat0", 1);
+
+  HoeffdingSplit<> split(2, 2, info, 0.95, 5000);
+
+  // Feed samples from each class.
+  for (size_t i = 0; i < 500; ++i)
+  {
+    split.Train(arma::Col<size_t>("0 0"), 0);
+    split.Train(arma::Col<size_t>("1 0"), 1);
+  }
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), 2);
+
+  // The following objects are deliberately configured differently from
+  // 'split'; the checks below expect them to match 'split' once loaded.
+  data::DatasetInfo wrongInfo;
+  wrongInfo.MapString("1", 1);
+  HoeffdingSplit<> xmlSplit(3, 7, wrongInfo, 0.1, 10);
+
+  data::DatasetInfo binaryInfo;
+  binaryInfo.MapString("0", 2); // Dimension 2 is categorical.
+  binaryInfo.MapString("1", 2);
+  HoeffdingSplit<> binarySplit(5, 2, binaryInfo, 0.99, 15000);
+
+  // Train for 2 samples.
+  binarySplit.Train(arma::vec("0.3 0.4 1 0.6 0.7"), 0);
+  binarySplit.Train(arma::vec("-0.3 0.0 0 0.7 0.8"), 1);
+
+  HoeffdingSplit<> textSplit(10, 11, wrongInfo, 0.75, 1000);
+
+  SerializeObjectAll(split, xmlSplit, textSplit, binarySplit);
+
+  BOOST_REQUIRE_EQUAL(split.SplitDimension(), xmlSplit.SplitDimension());
+  BOOST_REQUIRE_EQUAL(split.SplitDimension(), binarySplit.SplitDimension());
+  BOOST_REQUIRE_EQUAL(split.SplitDimension(), textSplit.SplitDimension());
+
+  // If splitting has already happened, then SplitCheck() should return 0.
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), 0);
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), xmlSplit.SplitCheck());
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), binarySplit.SplitCheck());
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), textSplit.SplitCheck());
+
+  BOOST_REQUIRE_EQUAL(split.MajorityClass(), xmlSplit.MajorityClass());
+  BOOST_REQUIRE_EQUAL(split.MajorityClass(), binarySplit.MajorityClass());
+  BOOST_REQUIRE_EQUAL(split.MajorityClass(), textSplit.MajorityClass());
+
+  // Compare routing using points with the same dimensionality as 'split'
+  // (it was built with 2 dimensions), one drawn from each training class.
+  const arma::vec classZeroPoint("0 0");
+  const arma::vec classOnePoint("1 0");
+
+  BOOST_REQUIRE_EQUAL(split.CalculateDirection(classZeroPoint),
+      xmlSplit.CalculateDirection(classZeroPoint));
+  BOOST_REQUIRE_EQUAL(split.CalculateDirection(classZeroPoint),
+      binarySplit.CalculateDirection(classZeroPoint));
+  BOOST_REQUIRE_EQUAL(split.CalculateDirection(classZeroPoint),
+      textSplit.CalculateDirection(classZeroPoint));
+
+  BOOST_REQUIRE_EQUAL(split.CalculateDirection(classOnePoint),
+      xmlSplit.CalculateDirection(classOnePoint));
+  BOOST_REQUIRE_EQUAL(split.CalculateDirection(classOnePoint),
+      binarySplit.CalculateDirection(classOnePoint));
+  BOOST_REQUIRE_EQUAL(split.CalculateDirection(classOnePoint),
+      textSplit.CalculateDirection(classOnePoint));
+}
+
+/**
+ * Make sure that a StreamingDecisionTree constructed without any training data
+ * (and therefore with no children) survives a round trip through XML, text,
+ * and binary archives.
+ */
+BOOST_AUTO_TEST_CASE(EmptyStreamingDecisionTreeTest)
+{
+  using namespace mlpack::tree;
+
+  data::DatasetInfo info;
+  StreamingDecisionTree<HoeffdingSplit<>> tree(info, 2, 2);
+  StreamingDecisionTree<HoeffdingSplit<>> xmlTree(info, 3, 3);
+  StreamingDecisionTree<HoeffdingSplit<>> binaryTree(info, 4, 4);
+  StreamingDecisionTree<HoeffdingSplit<>> textTree(info, 5, 5);
+
+  // Pass the targets in the same (XML, text, binary) order used by the other
+  // SerializeObjectAll() calls in this file, so that each object is loaded
+  // from the archive type its name refers to.
+  SerializeObjectAll(tree, xmlTree, textTree, binaryTree);
+
+  BOOST_REQUIRE_EQUAL(tree.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(xmlTree.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(binaryTree.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(textTree.NumChildren(), 0);
+}
+
+/**
+ * Train a streaming decision tree with Hoeffding splits on a small categorical
+ * dataset, serialize it through XML, text, and binary archives, and check that
+ * every loaded tree has the same structure and makes the same predictions on
+ * the training set as the original.
+ */
+BOOST_AUTO_TEST_CASE(StreamingDecisionTreeTest)
+{
+  using namespace mlpack::tree;
+  typedef StreamingDecisionTree<HoeffdingSplit<>> TreeType;
+
+  // Points come in (class 0, class 1) pairs.  In both classes dimension 0 is
+  // a random value from RandInt(4); dimension 1 comes from RandInt(2) for
+  // class 0 and from RandInt(2) + 2 for class 1.
+  const size_t numPoints = 400;
+  arma::mat dataset(2, numPoints);
+  arma::Row<size_t> labels(numPoints);
+  for (size_t pair = 0; pair < numPoints / 2; ++pair)
+  {
+    const size_t zeroCol = 2 * pair;    // Column holding the class 0 point.
+    const size_t oneCol = zeroCol + 1;  // Column holding the class 1 point.
+
+    dataset(0, zeroCol) = mlpack::math::RandInt(4);
+    dataset(1, zeroCol) = mlpack::math::RandInt(2);
+    dataset(0, oneCol) = mlpack::math::RandInt(4);
+    dataset(1, oneCol) = mlpack::math::RandInt(2) + 2;
+
+    labels[zeroCol] = 0;
+    labels[oneCol] = 1;
+  }
+
+  // Register the categories "a" through "d" in both dimensions so that both
+  // features are treated as categorical.
+  data::DatasetInfo info;
+  const char* categoryNames[] = { "a", "b", "c", "d" };
+  for (size_t dim = 0; dim < 2; ++dim)
+    for (size_t c = 0; c < 4; ++c)
+      info.MapString(categoryNames[c], dim);
+
+  TreeType tree(dataset, info, labels, 2);
+
+  // Targets start out with mismatched parameters; the checks below expect
+  // them to match 'tree' once they have been loaded.
+  TreeType xmlTree(info, 1, 1);
+  TreeType binaryTree(info, 5, 6);
+  TreeType textTree(info, 7, 100);
+
+  SerializeObjectAll(tree, xmlTree, textTree, binaryTree);
+
+  arma::Row<size_t> predictions;
+  tree.Classify(dataset, predictions);
+
+  // Compare a single loaded tree against the original: root split, each
+  // child (expected to be a leaf), and predictions on the training set.
+  auto checkTree = [&](TreeType& loaded)
+  {
+    BOOST_REQUIRE_EQUAL(tree.NumChildren(), loaded.NumChildren());
+    BOOST_REQUIRE_EQUAL(tree.Split().SplitDimension(),
+        loaded.Split().SplitDimension());
+
+    for (size_t c = 0; c < tree.NumChildren(); ++c)
+    {
+      BOOST_REQUIRE_EQUAL(tree.Child(c).NumChildren(), 0);
+      BOOST_REQUIRE_EQUAL(loaded.Child(c).NumChildren(), 0);
+
+      BOOST_REQUIRE_EQUAL(tree.Child(c).Split().SplitDimension(),
+          loaded.Child(c).Split().SplitDimension());
+      BOOST_REQUIRE_EQUAL(tree.Child(c).Split().MajorityClass(),
+          loaded.Child(c).Split().MajorityClass());
+    }
+
+    arma::Row<size_t> loadedPredictions;
+    loaded.Classify(dataset, loadedPredictions);
+
+    BOOST_REQUIRE_EQUAL(predictions.n_elem, loadedPredictions.n_elem);
+    for (size_t p = 0; p < predictions.n_elem; ++p)
+      BOOST_REQUIRE_EQUAL(predictions[p], loadedPredictions[p]);
+  };
+
+  checkTree(xmlTree);
+  checkTree(textTree);
+  checkTree(binaryTree);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list