[mlpack-git] master: Test StreamingDecisionTree::Serialize(). (68fe2da)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:43:48 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 68fe2dac616efcfd5d06fd3396f80859d9ddb922
Author: ryan <ryan at ratml.org>
Date:   Thu Oct 1 20:05:11 2015 -0400

    Test StreamingDecisionTree::Serialize().


>---------------------------------------------------------------

68fe2dac616efcfd5d06fd3396f80859d9ddb922
 src/mlpack/tests/serialization_test.cpp | 177 ++++++++++++++++++++++++++++++++
 1 file changed, 177 insertions(+)

diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index d0f04f3..1c15731 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -22,6 +22,7 @@
 #include <mlpack/core/metrics/mahalanobis_distance.hpp>
 #include <mlpack/core/tree/binary_space_tree.hpp>
 #include <mlpack/methods/hoeffding_trees/hoeffding_split.hpp>
+#include <mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp>
 
 using namespace mlpack;
 using namespace mlpack::distribution;
@@ -54,6 +55,7 @@ void TestArmadilloSerialization(MatType& x)
   }
   catch (archive_exception& e)
   {
+    std::cerr << e.what();
     success = false;
   }
 
@@ -72,6 +74,7 @@ void TestArmadilloSerialization(MatType& x)
   }
   catch (archive_exception& e)
   {
+    std::cerr << e.what();
     success = false;
   }
 
@@ -189,6 +192,7 @@ void SerializeObject(T& t, T& newT)
   }
   catch (archive_exception& e)
   {
+    std::cerr << e.what();
     success = false;
   }
   ofs.close();
@@ -204,6 +208,7 @@ void SerializeObject(T& t, T& newT)
   }
   catch (archive_exception& e)
   {
+    std::cerr << e.what();
     success = false;
   }
   ifs.close();
@@ -234,6 +239,7 @@ void SerializePointerObject(T* t, T*& newT)
   }
   catch (archive_exception& e)
   {
+    std::cerr << e.what();
     success = false;
   }
   ofs.close();
@@ -249,6 +255,7 @@ void SerializePointerObject(T* t, T*& newT)
   }
   catch (std::exception& e)
   {
+    std::cerr << e.what();
     success = false;
   }
   ifs.close();
@@ -885,4 +892,174 @@ BOOST_AUTO_TEST_CASE(HoeffdingSplitTest)
   BOOST_REQUIRE_EQUAL(split.SplitCheck(), textSplit.SplitCheck());
 }
 
+/**
+ * Make sure the HoeffdingSplit object serializes correctly after a split has
+ * occurred.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingSplitAfterSplitTest)
+{
+  // Force the split to split.
+  data::DatasetInfo info;
+  info.MapString("cat0", 0);
+  info.MapString("cat1", 0);
+  info.MapString("cat0", 1);
+
+  HoeffdingSplit<> split(2, 2, info, 0.95, 5000);
+
+  // Feed samples from each class.
+  for (size_t i = 0; i < 500; ++i)
+  {
+    split.Train(arma::Col<size_t>("0 0"), 0);
+    split.Train(arma::Col<size_t>("1 0"), 1);
+  }
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), 2);
+
+  data::DatasetInfo wrongInfo;
+  wrongInfo.MapString("1", 1);
+  HoeffdingSplit<> xmlSplit(3, 7, wrongInfo, 0.1, 10);
+
+  data::DatasetInfo binaryInfo;
+  binaryInfo.MapString("0", 2); // Dimension 2 is categorical.
+  binaryInfo.MapString("1", 2);
+  HoeffdingSplit<> binarySplit(5, 2, binaryInfo, 0.99, 15000);
+
+  // Train for 2 samples.
+  binarySplit.Train(arma::vec("0.3 0.4 1 0.6 0.7"), 0);
+  binarySplit.Train(arma::vec("-0.3 0.0 0 0.7 0.8"), 1);
+
+  HoeffdingSplit<> textSplit(10, 11, wrongInfo, 0.75, 1000);
+
+  SerializeObjectAll(split, xmlSplit, textSplit, binarySplit);
+
+  BOOST_REQUIRE_EQUAL(split.SplitDimension(), xmlSplit.SplitDimension());
+  BOOST_REQUIRE_EQUAL(split.SplitDimension(), binarySplit.SplitDimension());
+  BOOST_REQUIRE_EQUAL(split.SplitDimension(), textSplit.SplitDimension());
+
+  // If splitting has already happened, then SplitCheck() should return 0.
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), 0);
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), xmlSplit.SplitCheck());
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), binarySplit.SplitCheck());
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), textSplit.SplitCheck());
+
+  BOOST_REQUIRE_EQUAL(split.MajorityClass(), xmlSplit.MajorityClass());
+  BOOST_REQUIRE_EQUAL(split.MajorityClass(), binarySplit.MajorityClass());
+  BOOST_REQUIRE_EQUAL(split.MajorityClass(), textSplit.MajorityClass());
+
+  BOOST_REQUIRE_EQUAL(split.CalculateDirection(arma::vec("0.3 0.4 1 0.6 0.7")),
+      xmlSplit.CalculateDirection(arma::vec("0.3 0.4 1 0.6 0.7")));
+  BOOST_REQUIRE_EQUAL(split.CalculateDirection(arma::vec("0.3 0.4 1 0.6 0.7")),
+      binarySplit.CalculateDirection(arma::vec("0.3 0.4 1 0.6 0.7")));
+  BOOST_REQUIRE_EQUAL(split.CalculateDirection(arma::vec("0.3 0.4 1 0.6 0.7")),
+      textSplit.CalculateDirection(arma::vec("0.3 0.4 1 0.6 0.7")));
+}
+
+BOOST_AUTO_TEST_CASE(EmptyStreamingDecisionTreeTest)
+{
+  using namespace mlpack::tree;
+
+  data::DatasetInfo info;
+  StreamingDecisionTree<HoeffdingSplit<>> tree(info, 2, 2);
+  StreamingDecisionTree<HoeffdingSplit<>> xmlTree(info, 3, 3);
+  StreamingDecisionTree<HoeffdingSplit<>> binaryTree(info, 4, 4);
+  StreamingDecisionTree<HoeffdingSplit<>> textTree(info, 5, 5);
+
+  SerializeObjectAll(tree, xmlTree, binaryTree, textTree);
+
+  BOOST_REQUIRE_EQUAL(tree.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(xmlTree.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(binaryTree.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(textTree.NumChildren(), 0);
+}
+
+/**
+ * Build a Hoeffding tree, then save it and make sure other trees can classify
+ * as effectively.
+ */
+BOOST_AUTO_TEST_CASE(StreamingDecisionTreeTest)
+{
+  using namespace mlpack::tree;
+
+  arma::mat dataset(2, 400);
+  arma::Row<size_t> labels(400);
+  for (size_t i = 0; i < 200; ++i)
+  {
+    dataset(0, 2 * i) = mlpack::math::RandInt(4);
+    dataset(1, 2 * i) = mlpack::math::RandInt(2);
+    dataset(0, 2 * i + 1) = mlpack::math::RandInt(4);
+    dataset(1, 2 * i + 1) = mlpack::math::RandInt(2) + 2;
+    labels[2 * i] = 0;
+    labels[2 * i + 1] = 1;
+  }
+  // Make the features categorical.
+  data::DatasetInfo info;
+  info.MapString("a", 0);
+  info.MapString("b", 0);
+  info.MapString("c", 0);
+  info.MapString("d", 0);
+  info.MapString("a", 1);
+  info.MapString("b", 1);
+  info.MapString("c", 1);
+  info.MapString("d", 1);
+
+  StreamingDecisionTree<HoeffdingSplit<>> tree(dataset, info, labels, 2);
+
+  StreamingDecisionTree<HoeffdingSplit<>> xmlTree(info, 1, 1);
+  StreamingDecisionTree<HoeffdingSplit<>> binaryTree(info, 5, 6);
+  StreamingDecisionTree<HoeffdingSplit<>> textTree(info, 7, 100);
+
+  SerializeObjectAll(tree, xmlTree, textTree, binaryTree);
+
+  BOOST_REQUIRE_EQUAL(tree.NumChildren(), xmlTree.NumChildren());
+  BOOST_REQUIRE_EQUAL(tree.NumChildren(), textTree.NumChildren());
+  BOOST_REQUIRE_EQUAL(tree.NumChildren(), binaryTree.NumChildren());
+
+  BOOST_REQUIRE_EQUAL(tree.Split().SplitDimension(),
+      xmlTree.Split().SplitDimension());
+  BOOST_REQUIRE_EQUAL(tree.Split().SplitDimension(),
+      textTree.Split().SplitDimension());
+  BOOST_REQUIRE_EQUAL(tree.Split().SplitDimension(),
+      binaryTree.Split().SplitDimension());
+
+  for (size_t i = 0; i < tree.NumChildren(); ++i)
+  {
+    BOOST_REQUIRE_EQUAL(tree.Child(i).NumChildren(), 0);
+    BOOST_REQUIRE_EQUAL(xmlTree.Child(i).NumChildren(), 0);
+    BOOST_REQUIRE_EQUAL(binaryTree.Child(i).NumChildren(), 0);
+    BOOST_REQUIRE_EQUAL(textTree.Child(i).NumChildren(), 0);
+
+    BOOST_REQUIRE_EQUAL(tree.Child(i).Split().SplitDimension(),
+        xmlTree.Child(i).Split().SplitDimension());
+    BOOST_REQUIRE_EQUAL(tree.Child(i).Split().SplitDimension(),
+        textTree.Child(i).Split().SplitDimension());
+    BOOST_REQUIRE_EQUAL(tree.Child(i).Split().SplitDimension(),
+        binaryTree.Child(i).Split().SplitDimension());
+
+    BOOST_REQUIRE_EQUAL(tree.Child(i).Split().MajorityClass(),
+        xmlTree.Child(i).Split().MajorityClass());
+    BOOST_REQUIRE_EQUAL(tree.Child(i).Split().MajorityClass(),
+        textTree.Child(i).Split().MajorityClass());
+    BOOST_REQUIRE_EQUAL(tree.Child(i).Split().MajorityClass(),
+        binaryTree.Child(i).Split().MajorityClass());
+  }
+
+  // Check that predictions are the same.
+  arma::Row<size_t> predictions, xmlPredictions, binaryPredictions, 
+      textPredictions;
+  tree.Classify(dataset, predictions);
+  xmlTree.Classify(dataset, xmlPredictions);
+  binaryTree.Classify(dataset, binaryPredictions);
+  textTree.Classify(dataset, textPredictions);
+
+  BOOST_REQUIRE_EQUAL(predictions.n_elem, xmlPredictions.n_elem);
+  BOOST_REQUIRE_EQUAL(predictions.n_elem, textPredictions.n_elem);
+  BOOST_REQUIRE_EQUAL(predictions.n_elem, binaryPredictions.n_elem);
+
+  for (size_t i = 0; i < predictions.n_elem; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(predictions[i], xmlPredictions[i]);
+    BOOST_REQUIRE_EQUAL(predictions[i], textPredictions[i]);
+    BOOST_REQUIRE_EQUAL(predictions[i], binaryPredictions[i]);
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list