[mlpack-git] master: Fix (kind of) serialization for HoeffdingSplit. (f416d2d)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:43:33 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit f416d2d4a6f2a30645c2f3a9e3376f78ef0a4231
Author: ryan <ryan at ratml.org>
Date:   Thu Oct 1 15:58:50 2015 -0400

    Fix (kind of) serialization for HoeffdingSplit.


>---------------------------------------------------------------

f416d2d4a6f2a30645c2f3a9e3376f78ef0a4231
 .../hoeffding_trees/hoeffding_split_impl.hpp       | 47 ++++++++++++++++++--
 src/mlpack/tests/serialization_test.cpp            | 51 ++++++++++++++++++++++
 2 files changed, 95 insertions(+), 3 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
index da73f26..9642bb2 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
@@ -329,14 +329,55 @@ void HoeffdingSplit<
   if (splitDimension == size_t(-1))
   {
     // We have not yet split.  So we have to serialize the splits.
-    ar & CreateNVP(numericSplits, "numericSplits");
-    ar & CreateNVP(categoricalSplits, "categoricalSplits");
-
     ar & CreateNVP(numSamples, "numSamples");
     ar & CreateNVP(numClasses, "numClasses");
     ar & CreateNVP(maxSamples, "maxSamples");
     ar & CreateNVP(successProbability, "successProbability");
 
+    // This is hackish for now...
+    if (Archive::is_loading::value)
+    {
+      size_t numNumeric;
+      ar & CreateNVP(numNumeric, "numNumericSplits");
+      numericSplits.resize(numNumeric, NumericSplitType(numClasses));
+      for (size_t i = 0; i < numNumeric; ++i)
+      {
+        std::ostringstream name;
+        name << "numericSplit" << i;
+        ar & CreateNVP(numericSplits[i], name.str());
+      }
+
+      size_t numCategorical;
+      ar & CreateNVP(numCategorical, "numCategoricalSplits");
+      categoricalSplits.resize(numCategorical, CategoricalSplitType(1, 1));
+      for (size_t i = 0; i < numCategorical; ++i)
+      {
+        std::ostringstream name;
+        name << "categoricalSplit" << i;
+        ar & CreateNVP(categoricalSplits[i], name.str());
+      }
+    }
+    else
+    {
+      size_t splits = numericSplits.size();
+      ar & CreateNVP(splits, "numNumericSplits");
+      for (size_t i = 0; i < numericSplits.size(); ++i)
+      {
+        std::ostringstream name;
+        name << "numericSplit" << i;
+        ar & CreateNVP(numericSplits[i], name.str());
+      }
+
+      splits = categoricalSplits.size();
+      ar & CreateNVP(splits, "numCategoricalSplits");
+      for (size_t i = 0; i < categoricalSplits.size(); ++i)
+      {
+        std::ostringstream name;
+        name << "categoricalSplit" << i;
+        ar & CreateNVP(categoricalSplits[i], name.str());
+      }
+    }
+
     if (Archive::is_loading::value)
     {
       // Clear things we don't need.
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 7db223d..d0f04f3 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -834,4 +834,55 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitTest)
   }
 }
 
+/**
+ * Make sure the HoeffdingSplit object serializes correctly before a split has
+ * occurred.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingSplitTest)
+{
+  data::DatasetInfo info;
+  info.MapString("0", 2); // Dimension 2 is categorical.
+  info.MapString("1", 2);
+  HoeffdingSplit<> split(5, 2, info, 0.99, 15000);
+
+  // Train for 2 samples.
+  split.Train(arma::vec("0.3 0.4 1 0.6 0.7"), 0);
+  split.Train(arma::vec("-0.3 0.0 0 0.7 0.8"), 1);
+
+  data::DatasetInfo wrongInfo;
+  wrongInfo.MapString("1", 1);
+  HoeffdingSplit<> xmlSplit(3, 7, wrongInfo, 0.1, 10);
+
+  // Force binarySplit to split; binaryInfo makes both dimensions categorical.
+  data::DatasetInfo binaryInfo;
+  binaryInfo.MapString("cat0", 0);
+  binaryInfo.MapString("cat1", 0);
+  binaryInfo.MapString("cat0", 1);
+
+  HoeffdingSplit<> binarySplit(2, 2, binaryInfo, 0.95, 5000);
+
+  // Feed samples from each class (dimension 0 separates them perfectly).
+  for (size_t i = 0; i < 500; ++i)
+  {
+    binarySplit.Train(arma::Col<size_t>("0 0"), 0);
+    binarySplit.Train(arma::Col<size_t>("1 0"), 1);
+  }
+
+  HoeffdingSplit<> textSplit(10, 11, wrongInfo, 0.75, 1000);
+
+  SerializeObjectAll(split, xmlSplit, textSplit, binarySplit);
+
+  BOOST_REQUIRE_EQUAL(split.SplitDimension(), xmlSplit.SplitDimension());
+  BOOST_REQUIRE_EQUAL(split.SplitDimension(), binarySplit.SplitDimension());
+  BOOST_REQUIRE_EQUAL(split.SplitDimension(), textSplit.SplitDimension());
+
+  BOOST_REQUIRE_EQUAL(split.MajorityClass(), xmlSplit.MajorityClass());
+  BOOST_REQUIRE_EQUAL(split.MajorityClass(), binarySplit.MajorityClass());
+  BOOST_REQUIRE_EQUAL(split.MajorityClass(), textSplit.MajorityClass());
+
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), xmlSplit.SplitCheck());
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), binarySplit.SplitCheck());
+  BOOST_REQUIRE_EQUAL(split.SplitCheck(), textSplit.SplitCheck());
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list