[mlpack-git] master: Test HoeffdingNumericSplit<>::Serialize(). (7534316)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:43:29 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 7534316439bf932d3795ecd38f244a9050d0c2ef
Author: ryan <ryan at ratml.org>
Date: Thu Oct 1 14:19:50 2015 -0400
Test HoeffdingNumericSplit<>::Serialize().
>---------------------------------------------------------------
7534316439bf932d3795ecd38f244a9050d0c2ef
src/mlpack/tests/serialization_test.cpp | 90 +++++++++++++++++++++++++++++++++
1 file changed, 90 insertions(+)
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index e1be4ca..f0e3367 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -21,6 +21,7 @@
#include <mlpack/core/tree/hrectbound.hpp>
#include <mlpack/core/metrics/mahalanobis_distance.hpp>
#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/methods/hoeffding_trees/hoeffding_split.hpp>
using namespace mlpack;
using namespace mlpack::distribution;
@@ -687,4 +688,93 @@ BOOST_AUTO_TEST_CASE(BinarySpaceTreeOverwriteTest)
CheckTrees(tree, xmlTree, textTree, binaryTree);
}
+BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitTest)
+{
+  using namespace mlpack::tree;
+
+  // Source object: 200 random observations, which is enough for the split to
+  // build its bins before it is serialized.
+  HoeffdingNumericSplit<GiniImpurity> split(3);
+  for (size_t obs = 0; obs < 200; ++obs)
+    split.Train(mlpack::math::Random(), mlpack::math::RandInt(3));
+
+  // Destination objects are constructed with different arguments (and the
+  // text one has also been trained on unrelated data), so a successful round
+  // trip has to overwrite their state completely.
+  HoeffdingNumericSplit<GiniImpurity> xmlSplit(5);
+  HoeffdingNumericSplit<GiniImpurity> textSplit(7);
+  for (size_t obs = 0; obs < 200; ++obs)
+    textSplit.Train(mlpack::math::Random() + 3, 0);
+  HoeffdingNumericSplit<GiniImpurity> binarySplit(2);
+
+  SerializeObjectAll(split, xmlSplit, textSplit, binarySplit);
+
+  // The three deserialized copies (XML, text, binary) are checked uniformly.
+  HoeffdingNumericSplit<GiniImpurity>* const copies[3] =
+      { &xmlSplit, &textSplit, &binarySplit };
+
+  // Bin count and fitness must survive every archive format.
+  for (size_t c = 0; c < 3; ++c)
+  {
+    BOOST_REQUIRE_EQUAL(split.Bins(), copies[c]->Bins());
+    BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
+        copies[c]->EvaluateFitnessFunction(), 1e-5);
+  }
+
+  // Split the source and each copy; the per-child results must match exactly.
+  arma::Col<size_t> children;
+  NumericSplitInfo<double> splitInfo;
+  split.Split(children, splitInfo);
+
+  arma::Col<size_t> copyChildren[3];
+  NumericSplitInfo<double> copySplitInfo[3];
+  for (size_t c = 0; c < 3; ++c)
+  {
+    copies[c]->Split(copyChildren[c], copySplitInfo[c]);
+
+    BOOST_REQUIRE_EQUAL(children.size(), copyChildren[c].size());
+    for (size_t i = 0; i < children.size(); ++i)
+      BOOST_REQUIRE_EQUAL(children[i], copyChildren[c][i]);
+  }
+
+  // Probe the resulting split rule at random points and make sure every copy
+  // routes each point in the same direction as the source.
+  for (size_t trial = 0; trial < 200; ++trial)
+  {
+    const double point = mlpack::math::Random() * 1.5;
+    for (size_t c = 0; c < 3; ++c)
+    {
+      BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(point),
+          copySplitInfo[c].CalculateDirection(point));
+    }
+  }
+}
+
+BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitBeforeBinningTest)
+{
+  using namespace mlpack::tree;
+
+  HoeffdingNumericSplit<GiniImpurity> split(3);
+  // Train, but not until it bins: the split is still buffering its
+  // observations when it gets serialized.
+  for (size_t i = 0; i < 50; ++i)
+    split.Train(mlpack::math::Random(), mlpack::math::RandInt(3));
+
+  // Destination objects are constructed with different arguments (and the
+  // text one has seen unrelated data), so deserialization must overwrite them.
+  HoeffdingNumericSplit<GiniImpurity> xmlSplit(5);
+  HoeffdingNumericSplit<GiniImpurity> textSplit(7);
+  for (size_t i = 0; i < 200; ++i)
+    textSplit.Train(mlpack::math::Random() + 3, 0);
+  HoeffdingNumericSplit<GiniImpurity> binarySplit(2);
+
+  SerializeObjectAll(split, xmlSplit, textSplit, binarySplit);
+
+  // Ensure that everything is the same.
+  BOOST_REQUIRE_EQUAL(split.Bins(), xmlSplit.Bins());
+  BOOST_REQUIRE_EQUAL(split.Bins(), textSplit.Bins());
+  BOOST_REQUIRE_EQUAL(split.Bins(), binarySplit.Bins());
+
+  // None of the objects has binned yet, so their fitness is expected to be
+  // (near) zero.
+  BOOST_REQUIRE_SMALL(split.EvaluateFitnessFunction(), 1e-5);
+  BOOST_REQUIRE_SMALL(textSplit.EvaluateFitnessFunction(), 1e-5);
+  BOOST_REQUIRE_SMALL(xmlSplit.EvaluateFitnessFunction(), 1e-5);
+  BOOST_REQUIRE_SMALL(binarySplit.EvaluateFitnessFunction(), 1e-5);
+
+  // The checks above would still pass if the buffered observations were lost
+  // during the round trip, because the fitness of an unbinned split is
+  // expected to be zero anyway.  To actually exercise the buffer, feed the
+  // same extra points to every object, bringing each to 200 observations
+  // (enough to bin, as in HoeffdingNumericSplitTest).  If the buffers were
+  // restored faithfully, all four objects now bin identically.
+  for (size_t i = 0; i < 150; ++i)
+  {
+    const double value = mlpack::math::Random();
+    const int label = mlpack::math::RandInt(3);
+    split.Train(value, label);
+    xmlSplit.Train(value, label);
+    textSplit.Train(value, label);
+    binarySplit.Train(value, label);
+  }
+
+  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
+      xmlSplit.EvaluateFitnessFunction(), 1e-5);
+  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
+      textSplit.EvaluateFitnessFunction(), 1e-5);
+  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
+      binarySplit.EvaluateFitnessFunction(), 1e-5);
+
+  arma::Col<size_t> children, xmlChildren, textChildren, binaryChildren;
+  NumericSplitInfo<double> splitInfo, xmlSplitInfo, textSplitInfo,
+      binarySplitInfo;
+
+  split.Split(children, splitInfo);
+  xmlSplit.Split(xmlChildren, xmlSplitInfo);
+  textSplit.Split(textChildren, textSplitInfo);
+  binarySplit.Split(binaryChildren, binarySplitInfo);
+
+  BOOST_REQUIRE_EQUAL(children.size(), xmlChildren.size());
+  BOOST_REQUIRE_EQUAL(children.size(), textChildren.size());
+  BOOST_REQUIRE_EQUAL(children.size(), binaryChildren.size());
+  for (size_t i = 0; i < children.size(); ++i)
+  {
+    BOOST_REQUIRE_EQUAL(children[i], xmlChildren[i]);
+    BOOST_REQUIRE_EQUAL(children[i], textChildren[i]);
+    BOOST_REQUIRE_EQUAL(children[i], binaryChildren[i]);
+  }
+
+  // The bin boundaries depend on the buffered observations, so the copies
+  // must also route arbitrary points the same way as the source.
+  for (size_t i = 0; i < 200; ++i)
+  {
+    const double random = mlpack::math::Random() * 1.5;
+    BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(random),
+        xmlSplitInfo.CalculateDirection(random));
+    BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(random),
+        textSplitInfo.CalculateDirection(random));
+    BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(random),
+        binarySplitInfo.CalculateDirection(random));
+  }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list