[mlpack-git] master: Add test for HoeffdingCategoricalSplit. (5c8603b)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:43:31 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 5c8603bbb34376a1def9b6995239a78111d16cb5
Author: ryan <ryan at ratml.org>
Date:   Thu Oct 1 14:31:44 2015 -0400

    Add test for HoeffdingCategoricalSplit.


>---------------------------------------------------------------

5c8603bbb34376a1def9b6995239a78111d16cb5
 src/mlpack/tests/serialization_test.cpp | 57 +++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index f0e3367..7db223d 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -688,6 +688,10 @@ BOOST_AUTO_TEST_CASE(BinarySpaceTreeOverwriteTest)
   CheckTrees(tree, xmlTree, textTree, binaryTree);
 }
 
+/**
+ * Test serialization of the HoeffdingNumericSplit object after binning has
+ * occurred.
+ */
 BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitTest)
 {
   using namespace mlpack::tree;
@@ -749,6 +753,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitTest)
   }
 }
 
+/**
+ * Make sure the HoeffdingNumericSplit object serializes successfully before
+ * binning has occurred.
+ */
 BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitBeforeBinningTest)
 {
   using namespace mlpack::tree;
@@ -777,4 +785,53 @@ BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitBeforeBinningTest)
   BOOST_REQUIRE_SMALL(binarySplit.EvaluateFitnessFunction(), 1e-5);
 }
 
+/**
+ * Check that a HoeffdingCategoricalSplit round-trips through serialization.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitTest)
+{
+  using namespace mlpack::tree;
+  // The split to serialize, trained on 50 random points.
+  HoeffdingCategoricalSplit<GiniImpurity> split(10, 3);
+  for (size_t i = 0; i < 50; ++i)
+    split.Train(mlpack::math::RandInt(10), mlpack::math::RandInt(3));
+  // Targets start out different from `split`; loading must overwrite them.
+  HoeffdingCategoricalSplit<GiniImpurity> xmlSplit(3, 7);
+  HoeffdingCategoricalSplit<GiniImpurity> binarySplit(4, 11);
+  HoeffdingCategoricalSplit<GiniImpurity> textSplit(2, 2);
+  for (size_t i = 0; i < 10; ++i)
+    textSplit.Train(mlpack::math::RandInt(2), mlpack::math::RandInt(2));
+  // NOTE(review): presumably saves `split` and loads it into each target.
+  SerializeObjectAll(split, xmlSplit, textSplit, binarySplit);
+
+  BOOST_REQUIRE_EQUAL(split.MajorityClass(), xmlSplit.MajorityClass());
+  BOOST_REQUIRE_EQUAL(split.MajorityClass(), textSplit.MajorityClass());
+  BOOST_REQUIRE_EQUAL(split.MajorityClass(), binarySplit.MajorityClass());
+
+  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
+                      xmlSplit.EvaluateFitnessFunction(), 1e-5);
+  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
+                      textSplit.EvaluateFitnessFunction(), 1e-5);
+  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
+                      binarySplit.EvaluateFitnessFunction(), 1e-5);
+  // Splitting each copy must yield identical child vectors.
+  arma::Col<size_t> children, xmlChildren, textChildren, binaryChildren;
+  CategoricalSplitInfo splitInfo(1); // Required by Split(); never checked.
+
+  split.Split(children, splitInfo);
+  xmlSplit.Split(xmlChildren, splitInfo);
+  binarySplit.Split(binaryChildren, splitInfo);
+  textSplit.Split(textChildren, splitInfo);
+
+  BOOST_REQUIRE_EQUAL(children.size(), xmlChildren.size());
+  BOOST_REQUIRE_EQUAL(children.size(), textChildren.size());
+  BOOST_REQUIRE_EQUAL(children.size(), binaryChildren.size());
+  for (size_t i = 0; i < children.size(); ++i)
+  {
+    BOOST_REQUIRE_EQUAL(children[i], xmlChildren[i]);
+    BOOST_REQUIRE_EQUAL(children[i], textChildren[i]);
+    BOOST_REQUIRE_EQUAL(children[i], binaryChildren[i]);
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list