[mlpack-git] master: Test HoeffdingNumericSplit<>::Serialize(). (7534316)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:43:29 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 7534316439bf932d3795ecd38f244a9050d0c2ef
Author: ryan <ryan at ratml.org>
Date:   Thu Oct 1 14:19:50 2015 -0400

    Test HoeffdingNumericSplit<>::Serialize().


>---------------------------------------------------------------

7534316439bf932d3795ecd38f244a9050d0c2ef
 src/mlpack/tests/serialization_test.cpp | 90 +++++++++++++++++++++++++++++++++
 1 file changed, 90 insertions(+)

diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index e1be4ca..f0e3367 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -21,6 +21,7 @@
 #include <mlpack/core/tree/hrectbound.hpp>
 #include <mlpack/core/metrics/mahalanobis_distance.hpp>
 #include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/methods/hoeffding_trees/hoeffding_split.hpp>
 
 using namespace mlpack;
 using namespace mlpack::distribution;
@@ -687,4 +688,93 @@ BOOST_AUTO_TEST_CASE(BinarySpaceTreeOverwriteTest)
   CheckTrees(tree, xmlTree, textTree, binaryTree);
 }
 
+BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitTest)
+{
+  using namespace mlpack::tree;
+
+  HoeffdingNumericSplit<GiniImpurity> split(3);
+  // Train on 200 random samples, which is intended to be enough to bin.
+  for (size_t i = 0; i < 200; ++i)
+    split.Train(mlpack::math::Random(), mlpack::math::RandInt(3));
+  // Targets are built with other constructor arguments; textSplit is trained.
+  HoeffdingNumericSplit<GiniImpurity> xmlSplit(5);
+  HoeffdingNumericSplit<GiniImpurity> textSplit(7);
+  for (size_t i = 0; i < 200; ++i)
+    textSplit.Train(mlpack::math::Random() + 3, 0);
+  HoeffdingNumericSplit<GiniImpurity> binarySplit(2);
+  // Presumably saves 'split' and loads it into the other three objects.
+  SerializeObjectAll(split, xmlSplit, textSplit, binarySplit);
+
+  // Bin counts and fitness values must match the original in every copy.
+  BOOST_REQUIRE_EQUAL(split.Bins(), xmlSplit.Bins());
+  BOOST_REQUIRE_EQUAL(split.Bins(), textSplit.Bins());
+  BOOST_REQUIRE_EQUAL(split.Bins(), binarySplit.Bins());
+
+  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
+      xmlSplit.EvaluateFitnessFunction(), 1e-5);
+  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
+      textSplit.EvaluateFitnessFunction(), 1e-5);
+  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
+      binarySplit.EvaluateFitnessFunction(), 1e-5);
+
+  arma::Col<size_t> children, xmlChildren, textChildren, binaryChildren;
+  NumericSplitInfo<double> splitInfo, xmlSplitInfo, textSplitInfo,
+      binarySplitInfo;
+  // Split each object so that children and split info can be compared.
+  split.Split(children, splitInfo);
+  xmlSplit.Split(xmlChildren, xmlSplitInfo);
+  binarySplit.Split(binaryChildren, binarySplitInfo);
+  textSplit.Split(textChildren, textSplitInfo);
+
+  BOOST_REQUIRE_EQUAL(children.size(), xmlChildren.size());
+  BOOST_REQUIRE_EQUAL(children.size(), textChildren.size());
+  BOOST_REQUIRE_EQUAL(children.size(), binaryChildren.size());
+  for (size_t i = 0; i < children.size(); ++i)
+  {
+    BOOST_REQUIRE_EQUAL(children[i], xmlChildren[i]);
+    BOOST_REQUIRE_EQUAL(children[i], textChildren[i]);
+    BOOST_REQUIRE_EQUAL(children[i], binaryChildren[i]);
+  }
+
+  // All split infos must pick the same direction for 200 random points.
+  for (size_t i = 0; i < 200; ++i)
+  {
+    const double random = mlpack::math::Random() * 1.5;
+    BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(random),
+                        xmlSplitInfo.CalculateDirection(random));
+    BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(random),
+                        textSplitInfo.CalculateDirection(random));
+    BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(random),
+                        binarySplitInfo.CalculateDirection(random));
+  }
+}
+
+BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitBeforeBinningTest)
+{
+  using namespace mlpack::tree;
+
+  HoeffdingNumericSplit<GiniImpurity> split(3);
+  // Train on only 50 samples, which is intended to be too few to bin.
+  for (size_t i = 0; i < 50; ++i)
+    split.Train(mlpack::math::Random(), mlpack::math::RandInt(3));
+  // Targets are built with other constructor arguments; textSplit is trained.
+  HoeffdingNumericSplit<GiniImpurity> xmlSplit(5);
+  HoeffdingNumericSplit<GiniImpurity> textSplit(7);
+  for (size_t i = 0; i < 200; ++i)
+    textSplit.Train(mlpack::math::Random() + 3, 0);
+  HoeffdingNumericSplit<GiniImpurity> binarySplit(2);
+  // Presumably saves 'split' and loads it into the other three objects.
+  SerializeObjectAll(split, xmlSplit, textSplit, binarySplit);
+
+  // Bin counts must match the original in every copy.
+  BOOST_REQUIRE_EQUAL(split.Bins(), xmlSplit.Bins());
+  BOOST_REQUIRE_EQUAL(split.Bins(), textSplit.Bins());
+  BOOST_REQUIRE_EQUAL(split.Bins(), binarySplit.Bins());
+  // The fitness of every object is expected to be (near) zero here.
+  BOOST_REQUIRE_SMALL(split.EvaluateFitnessFunction(), 1e-5);
+  BOOST_REQUIRE_SMALL(textSplit.EvaluateFitnessFunction(), 1e-5);
+  BOOST_REQUIRE_SMALL(xmlSplit.EvaluateFitnessFunction(), 1e-5);
+  BOOST_REQUIRE_SMALL(binarySplit.EvaluateFitnessFunction(), 1e-5);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list