[mlpack-git] master: Refactor for changed EvaluateFitnessFunction(). (278a5f8)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:46:18 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 278a5f85a43cef76d11d5784cf099f1798f9b354
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Nov 12 15:18:38 2015 -0500

    Refactor for changed EvaluateFitnessFunction().


>---------------------------------------------------------------

278a5f85a43cef76d11d5784cf099f1798f9b354
 .../hoeffding_trees/hoeffding_tree_impl.hpp        | 38 +++++++------
 src/mlpack/tests/hoeffding_tree_test.cpp           | 34 +++++++++---
 src/mlpack/tests/serialization_test.cpp            | 63 ++++++++++++++++------
 3 files changed, 97 insertions(+), 38 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
index c9bf774..5bc1715 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
@@ -336,32 +336,40 @@ size_t HoeffdingTree<
   const double epsilon = std::sqrt(rSquared *
       std::log(1.0 / (1.0 - successProbability)) / (2 * numSamples));
 
-  arma::vec gains(categoricalSplits.size() + numericSplits.size());
-  for (size_t i = 0; i < gains.n_elem; ++i)
+  // Find the best and second best possible splits.
+  double largest = -DBL_MAX;
+  size_t largestIndex = 0;
+  double secondLargest = -DBL_MAX;
+  for (size_t i = 0; i < categoricalSplits.size() + numericSplits.size(); ++i)
   {
     size_t type = dimensionMappings->at(i).first;
     size_t index = dimensionMappings->at(i).second;
+
+    // Some split procedures can split multiple ways, but we only care about the
+    // best two splits that can be done in every dimension.
+    double bestGain;
+    double secondBestGain;
     if (type == data::Datatype::categorical)
-      gains[i] = categoricalSplits[index].EvaluateFitnessFunction();
+      categoricalSplits[index].EvaluateFitnessFunction(bestGain,
+          secondBestGain);
     else if (type == data::Datatype::numeric)
-      gains[i] = numericSplits[index].EvaluateFitnessFunction();
-  }
+      numericSplits[index].EvaluateFitnessFunction(bestGain, secondBestGain);
 
-  // Now find the largest and second-largest.
-  double largest = -DBL_MAX;
-  size_t largestIndex = 0;
-  double secondLargest = -DBL_MAX;
-  for (size_t i = 0; i < gains.n_elem; ++i)
-  {
-    if (gains[i] > largest)
+    // See if these gains are better than the previous.
+    if (bestGain > largest)
     {
       secondLargest = largest;
-      largest = gains[i];
+      largest = bestGain;
       largestIndex = i;
     }
-    else if (gains[i] > secondLargest)
+    else if (bestGain > secondLargest)
+    {
+      secondLargest = bestGain;
+    }
+
+    if (secondBestGain > secondLargest)
     {
-      secondLargest = gains[i];
+      secondLargest = secondBestGain;
     }
   }
 
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 7e3fee1..584bf27 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -279,7 +279,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitEasyFitnessCheck)
   for (size_t i = 0; i < 100; ++i)
     split.Train(4, 2);
 
-  BOOST_REQUIRE_GT(split.EvaluateFitnessFunction(), 0.0);
+  double bestGain, secondBestGain;
+  split.EvaluateFitnessFunction(bestGain, secondBestGain);
+  BOOST_REQUIRE_GT(bestGain, 0.0);
+  BOOST_REQUIRE_SMALL(secondBestGain, 1e-10);
 }
 
 /**
@@ -291,7 +294,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitNoImprovementFitnessTest)
   HoeffdingCategoricalSplit<GiniImpurity> split(2, 2);
 
   // No training has yet happened, so a split would get us nothing.
-  BOOST_REQUIRE_SMALL(split.EvaluateFitnessFunction(), 1e-10);
+  double bestGain, secondBestGain;
+  split.EvaluateFitnessFunction(bestGain, secondBestGain);
+  BOOST_REQUIRE_SMALL(bestGain, 1e-10);
+  BOOST_REQUIRE_SMALL(secondBestGain, 1e-10);
 
   split.Train(0, 0);
   split.Train(1, 0);
@@ -299,7 +305,9 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitNoImprovementFitnessTest)
   split.Train(1, 1);
 
   // Now, a split still gets us only 50% accuracy in each split bin.
-  BOOST_REQUIRE_SMALL(split.EvaluateFitnessFunction(), 1e-10);
+  split.EvaluateFitnessFunction(bestGain, secondBestGain);
+  BOOST_REQUIRE_SMALL(bestGain, 1e-10);
+  BOOST_REQUIRE_SMALL(secondBestGain, 1e-10);
 }
 
 /**
@@ -567,7 +575,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitFitnessFunctionTest)
   for (size_t i = 0; i < 99; ++i)
   {
     split.Train(mlpack::math::Random(), mlpack::math::RandInt(5));
-    BOOST_REQUIRE_SMALL(split.EvaluateFitnessFunction(), 1e-10);
+    double bestGain, secondBestGain;
+    split.EvaluateFitnessFunction(bestGain, secondBestGain);
+    BOOST_REQUIRE_SMALL(bestGain, 1e-10);
+    BOOST_REQUIRE_SMALL(secondBestGain, 1e-10);
   }
 }
 
@@ -612,7 +623,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitBimodalTest)
 
   // Now the binning should be complete, and so the impurity should be
   // (0.5 * (1 - 0.5)) * 2 = 0.50 (it will be 0 in the two created children).
-  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(), 0.50, 0.03);
+  double bestGain, secondBestGain;
+  split.EvaluateFitnessFunction(bestGain, secondBestGain);
+  BOOST_REQUIRE_CLOSE(bestGain, 0.50, 0.03);
+  BOOST_REQUIRE_SMALL(secondBestGain, 1e-10);
 
   // Make sure that if we do create children, that the correct number of
   // children is created, and that the bins end up in the right place.
@@ -647,7 +661,10 @@ BOOST_AUTO_TEST_CASE(BinaryNumericSplitSimpleSplitTest)
     // Now ensure the fitness function gives good gain.
     // The Gini impurity for the unsplit node is 2 * (0.5^2) = 0.5, and the Gini
     // impurity for the children is 0.
-    BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(), 0.5, 1e-5);
+    double bestGain, secondBestGain;
+    split.EvaluateFitnessFunction(bestGain, secondBestGain);
+    BOOST_REQUIRE_CLOSE(bestGain, 0.5, 1e-5);
+    BOOST_REQUIRE_GT(bestGain, secondBestGain);
   }
 
   // Now, when we ask it to split, ensure that the split value is reasonable.
@@ -684,7 +701,10 @@ BOOST_AUTO_TEST_CASE(BinaryNumericSplitSimpleFourClassSplitTest)
     // The same as the previous test, but with four classes: 4 * (0.25 * 0.75) =
     // 0.75.  We can only split in one place, though, which will give one
     // perfect child, giving a gain of 0.75 - 3 * (1/3 * 2/3) = 0.25.
-    BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(), 0.25, 1e-5);
+    double bestGain, secondBestGain;
+    split.EvaluateFitnessFunction(bestGain, secondBestGain);
+    BOOST_REQUIRE_CLOSE(bestGain, 0.25, 1e-5);
+    BOOST_REQUIRE_GT(bestGain, secondBestGain);
   }
 
   // Now, when we ask it to split, ensure that the split value is reasonable.
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 385331d..bcfddf0 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -720,12 +720,20 @@ BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitTest)
   BOOST_REQUIRE_EQUAL(split.Bins(), textSplit.Bins());
   BOOST_REQUIRE_EQUAL(split.Bins(), binarySplit.Bins());
 
-  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
-      xmlSplit.EvaluateFitnessFunction(), 1e-5);
-  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
-      textSplit.EvaluateFitnessFunction(), 1e-5);
-  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
-      binarySplit.EvaluateFitnessFunction(), 1e-5);
+  double bestSplit, secondBestSplit;
+  double baseBestSplit, baseSecondBestSplit;
+  split.EvaluateFitnessFunction(baseBestSplit, baseSecondBestSplit);
+  xmlSplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+  BOOST_REQUIRE_CLOSE(bestSplit, baseBestSplit, 1e-5);
+  BOOST_REQUIRE_SMALL(secondBestSplit, 1e-10);
+
+  textSplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+  BOOST_REQUIRE_CLOSE(bestSplit, baseBestSplit, 1e-5);
+  BOOST_REQUIRE_SMALL(secondBestSplit, 1e-10);
+
+  binarySplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+  BOOST_REQUIRE_CLOSE(bestSplit, baseBestSplit, 1e-5);
+  BOOST_REQUIRE_SMALL(secondBestSplit, 1e-10);
 
   arma::Col<size_t> children, xmlChildren, textChildren, binaryChildren;
   NumericSplitInfo<double> splitInfo, xmlSplitInfo, textSplitInfo,
@@ -785,10 +793,24 @@ BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitBeforeBinningTest)
   BOOST_REQUIRE_EQUAL(split.Bins(), textSplit.Bins());
   BOOST_REQUIRE_EQUAL(split.Bins(), binarySplit.Bins());
 
-  BOOST_REQUIRE_SMALL(split.EvaluateFitnessFunction(), 1e-5);
-  BOOST_REQUIRE_SMALL(textSplit.EvaluateFitnessFunction(), 1e-5);
-  BOOST_REQUIRE_SMALL(xmlSplit.EvaluateFitnessFunction(), 1e-5);
-  BOOST_REQUIRE_SMALL(binarySplit.EvaluateFitnessFunction(), 1e-5);
+  double baseBestSplit, baseSecondBestSplit;
+  double bestSplit, secondBestSplit;
+  split.EvaluateFitnessFunction(baseBestSplit, baseSecondBestSplit);
+  textSplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+
+  BOOST_REQUIRE_SMALL(baseBestSplit, 1e-5);
+  BOOST_REQUIRE_SMALL(baseSecondBestSplit, 1e-5);
+
+  BOOST_REQUIRE_SMALL(bestSplit, 1e-5);
+  BOOST_REQUIRE_SMALL(secondBestSplit, 1e-5);
+
+  xmlSplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+  BOOST_REQUIRE_SMALL(bestSplit, 1e-5);
+  BOOST_REQUIRE_SMALL(secondBestSplit, 1e-5);
+
+  binarySplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+  BOOST_REQUIRE_SMALL(bestSplit, 1e-5);
+  BOOST_REQUIRE_SMALL(secondBestSplit, 1e-5);
 }
 
 /**
@@ -814,12 +836,21 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitTest)
   BOOST_REQUIRE_EQUAL(split.MajorityClass(), textSplit.MajorityClass());
   BOOST_REQUIRE_EQUAL(split.MajorityClass(), binarySplit.MajorityClass());
 
-  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
-                      xmlSplit.EvaluateFitnessFunction(), 1e-5);
-  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
-                      textSplit.EvaluateFitnessFunction(), 1e-5);
-  BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
-                      binarySplit.EvaluateFitnessFunction(), 1e-5);
+  double bestSplit, secondBestSplit;
+  double baseBestSplit, baseSecondBestSplit;
+  split.EvaluateFitnessFunction(baseBestSplit, baseSecondBestSplit);
+  xmlSplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+
+  BOOST_REQUIRE_CLOSE(bestSplit, baseBestSplit, 1e-5);
+  BOOST_REQUIRE_SMALL(secondBestSplit, 1e-10);
+
+  textSplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+  BOOST_REQUIRE_CLOSE(bestSplit, baseBestSplit, 1e-5);
+  BOOST_REQUIRE_SMALL(secondBestSplit, 1e-10);
+
+  binarySplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+  BOOST_REQUIRE_CLOSE(bestSplit, baseBestSplit, 1e-5);
+  BOOST_REQUIRE_SMALL(secondBestSplit, 1e-10);
 
   arma::Col<size_t> children, xmlChildren, textChildren, binaryChildren;
   CategoricalSplitInfo splitInfo(1); // I don't care about this.



More information about the mlpack-git mailing list