[mlpack-git] master: Fix (kind of) serialization for HoeffdingSplit. (f416d2d)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:43:33 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit f416d2d4a6f2a30645c2f3a9e3376f78ef0a4231
Author: ryan <ryan at ratml.org>
Date: Thu Oct 1 15:58:50 2015 -0400
Fix (kind of) serialization for HoeffdingSplit.
>---------------------------------------------------------------
f416d2d4a6f2a30645c2f3a9e3376f78ef0a4231
.../hoeffding_trees/hoeffding_split_impl.hpp | 47 ++++++++++++++++++--
src/mlpack/tests/serialization_test.cpp | 51 ++++++++++++++++++++++
2 files changed, 95 insertions(+), 3 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
index da73f26..9642bb2 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
@@ -329,14 +329,55 @@ void HoeffdingSplit<
if (splitDimension == size_t(-1))
{
// We have not yet split. So we have to serialize the splits.
- ar & CreateNVP(numericSplits, "numericSplits");
- ar & CreateNVP(categoricalSplits, "categoricalSplits");
-
ar & CreateNVP(numSamples, "numSamples");
ar & CreateNVP(numClasses, "numClasses");
ar & CreateNVP(maxSamples, "maxSamples");
ar & CreateNVP(successProbability, "successProbability");
+ // This is hackish for now...
+ if (Archive::is_loading::value)
+ {
+ size_t numNumeric;
+ ar & CreateNVP(numNumeric, "numNumericSplits");
+ numericSplits.resize(numNumeric, NumericSplitType(numClasses));
+ for (size_t i = 0; i < numNumeric; ++i)
+ {
+ std::ostringstream name;
+ name << "numericSplit" << i;
+ ar & CreateNVP(numericSplits[i], name.str());
+ }
+
+ size_t numCategorical;
+ ar & CreateNVP(numCategorical, "numCategoricalSplits");
+ categoricalSplits.resize(numCategorical, CategoricalSplitType(1, 1));
+ for (size_t i = 0; i < numCategorical; ++i)
+ {
+ std::ostringstream name;
+ name << "categoricalSplit" << i;
+ ar & CreateNVP(categoricalSplits[i], name.str());
+ }
+ }
+ else
+ {
+ size_t splits = numericSplits.size();
+ ar & CreateNVP(splits, "numNumericSplits");
+ for (size_t i = 0; i < numericSplits.size(); ++i)
+ {
+ std::ostringstream name;
+ name << "numericSplit" << i;
+ ar & CreateNVP(numericSplits[i], name.str());
+ }
+
+ splits = categoricalSplits.size();
+ ar & CreateNVP(splits, "numCategoricalSplits");
+ for (size_t i = 0; i < categoricalSplits.size(); ++i)
+ {
+ std::ostringstream name;
+ name << "categoricalSplit" << i;
+ ar & CreateNVP(categoricalSplits[i], name.str());
+ }
+ }
+
if (Archive::is_loading::value)
{
// Clear things we don't need.
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 7db223d..d0f04f3 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -834,4 +834,55 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitTest)
}
}
+/**
+ * Make sure the HoeffdingSplit object serializes correctly before a split has
+ * occurred.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingSplitTest)
+{
+ data::DatasetInfo info;
+ info.MapString("0", 2); // Dimension 2 is categorical.
+ info.MapString("1", 2);
+ HoeffdingSplit<> split(5, 2, info, 0.99, 15000);
+
+ // Train for 2 samples.
+ split.Train(arma::vec("0.3 0.4 1 0.6 0.7"), 0);
+ split.Train(arma::vec("-0.3 0.0 0 0.7 0.8"), 1);
+
+ data::DatasetInfo wrongInfo;
+ wrongInfo.MapString("1", 1);
+ HoeffdingSplit<> xmlSplit(3, 7, wrongInfo, 0.1, 10);
+
+ // Force the binarySplit to split. NOTE(review): binaryInfo is never used below.
+ data::DatasetInfo binaryInfo;
+ binaryInfo.MapString("cat0", 0);
+ binaryInfo.MapString("cat1", 0);
+ binaryInfo.MapString("cat0", 1);
+
+ HoeffdingSplit<> binarySplit(2, 2, info, 0.95, 5000);
+
+ // Feed samples from each class.
+ for (size_t i = 0; i < 500; ++i)
+ {
+ binarySplit.Train(arma::Col<size_t>("0 0"), 0);
+ binarySplit.Train(arma::Col<size_t>("1 0"), 1);
+ }
+
+ HoeffdingSplit<> textSplit(10, 11, wrongInfo, 0.75, 1000);
+
+ SerializeObjectAll(split, xmlSplit, textSplit, binarySplit);
+
+ BOOST_REQUIRE_EQUAL(split.SplitDimension(), xmlSplit.SplitDimension());
+ BOOST_REQUIRE_EQUAL(split.SplitDimension(), binarySplit.SplitDimension());
+ BOOST_REQUIRE_EQUAL(split.SplitDimension(), textSplit.SplitDimension());
+
+ BOOST_REQUIRE_EQUAL(split.MajorityClass(), xmlSplit.MajorityClass());
+ BOOST_REQUIRE_EQUAL(split.MajorityClass(), binarySplit.MajorityClass());
+ BOOST_REQUIRE_EQUAL(split.MajorityClass(), textSplit.MajorityClass());
+
+ BOOST_REQUIRE_EQUAL(split.SplitCheck(), xmlSplit.SplitCheck());
+ BOOST_REQUIRE_EQUAL(split.SplitCheck(), binarySplit.SplitCheck());
+ BOOST_REQUIRE_EQUAL(split.SplitCheck(), textSplit.SplitCheck());
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list