[mlpack-git] master: Refactor for changed EvaluateFitnessFunction(). (278a5f8)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:46:18 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 278a5f85a43cef76d11d5784cf099f1798f9b354
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Nov 12 15:18:38 2015 -0500
Refactor for changed EvaluateFitnessFunction().
>---------------------------------------------------------------
278a5f85a43cef76d11d5784cf099f1798f9b354
.../hoeffding_trees/hoeffding_tree_impl.hpp | 38 +++++++------
src/mlpack/tests/hoeffding_tree_test.cpp | 34 +++++++++---
src/mlpack/tests/serialization_test.cpp | 63 ++++++++++++++++------
3 files changed, 97 insertions(+), 38 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
index c9bf774..5bc1715 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
@@ -336,32 +336,40 @@ size_t HoeffdingTree<
const double epsilon = std::sqrt(rSquared *
std::log(1.0 / (1.0 - successProbability)) / (2 * numSamples));
- arma::vec gains(categoricalSplits.size() + numericSplits.size());
- for (size_t i = 0; i < gains.n_elem; ++i)
+ // Find the best and second best possible splits.
+ double largest = -DBL_MAX;
+ size_t largestIndex = 0;
+ double secondLargest = -DBL_MAX;
+ for (size_t i = 0; i < categoricalSplits.size() + numericSplits.size(); ++i)
{
size_t type = dimensionMappings->at(i).first;
size_t index = dimensionMappings->at(i).second;
+
+ // Some split procedures can split multiple ways, but we only care about the
+ // best two splits that can be done in every dimension.
+ double bestGain;
+ double secondBestGain;
if (type == data::Datatype::categorical)
- gains[i] = categoricalSplits[index].EvaluateFitnessFunction();
+ categoricalSplits[index].EvaluateFitnessFunction(bestGain,
+ secondBestGain);
else if (type == data::Datatype::numeric)
- gains[i] = numericSplits[index].EvaluateFitnessFunction();
- }
+ numericSplits[index].EvaluateFitnessFunction(bestGain, secondBestGain);
- // Now find the largest and second-largest.
- double largest = -DBL_MAX;
- size_t largestIndex = 0;
- double secondLargest = -DBL_MAX;
- for (size_t i = 0; i < gains.n_elem; ++i)
- {
- if (gains[i] > largest)
+ // See if these gains are better than the previous.
+ if (bestGain > largest)
{
secondLargest = largest;
- largest = gains[i];
+ largest = bestGain;
largestIndex = i;
}
- else if (gains[i] > secondLargest)
+ else if (bestGain > secondLargest)
+ {
+ secondLargest = bestGain;
+ }
+
+ if (secondBestGain > secondLargest)
{
- secondLargest = gains[i];
+ secondLargest = secondBestGain;
}
}
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 7e3fee1..584bf27 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -279,7 +279,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitEasyFitnessCheck)
for (size_t i = 0; i < 100; ++i)
split.Train(4, 2);
- BOOST_REQUIRE_GT(split.EvaluateFitnessFunction(), 0.0);
+ double bestGain, secondBestGain;
+ split.EvaluateFitnessFunction(bestGain, secondBestGain);
+ BOOST_REQUIRE_GT(bestGain, 0.0);
+ BOOST_REQUIRE_SMALL(secondBestGain, 1e-10);
}
/**
@@ -291,7 +294,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitNoImprovementFitnessTest)
HoeffdingCategoricalSplit<GiniImpurity> split(2, 2);
// No training has yet happened, so a split would get us nothing.
- BOOST_REQUIRE_SMALL(split.EvaluateFitnessFunction(), 1e-10);
+ double bestGain, secondBestGain;
+ split.EvaluateFitnessFunction(bestGain, secondBestGain);
+ BOOST_REQUIRE_SMALL(bestGain, 1e-10);
+ BOOST_REQUIRE_SMALL(secondBestGain, 1e-10);
split.Train(0, 0);
split.Train(1, 0);
@@ -299,7 +305,9 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitNoImprovementFitnessTest)
split.Train(1, 1);
// Now, a split still gets us only 50% accuracy in each split bin.
- BOOST_REQUIRE_SMALL(split.EvaluateFitnessFunction(), 1e-10);
+ split.EvaluateFitnessFunction(bestGain, secondBestGain);
+ BOOST_REQUIRE_SMALL(bestGain, 1e-10);
+ BOOST_REQUIRE_SMALL(secondBestGain, 1e-10);
}
/**
@@ -567,7 +575,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitFitnessFunctionTest)
for (size_t i = 0; i < 99; ++i)
{
split.Train(mlpack::math::Random(), mlpack::math::RandInt(5));
- BOOST_REQUIRE_SMALL(split.EvaluateFitnessFunction(), 1e-10);
+ double bestGain, secondBestGain;
+ split.EvaluateFitnessFunction(bestGain, secondBestGain);
+ BOOST_REQUIRE_SMALL(bestGain, 1e-10);
+ BOOST_REQUIRE_SMALL(secondBestGain, 1e-10);
}
}
@@ -612,7 +623,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitBimodalTest)
// Now the binning should be complete, and so the impurity should be
// (0.5 * (1 - 0.5)) * 2 = 0.50 (it will be 0 in the two created children).
- BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(), 0.50, 0.03);
+ double bestGain, secondBestGain;
+ split.EvaluateFitnessFunction(bestGain, secondBestGain);
+ BOOST_REQUIRE_CLOSE(bestGain, 0.50, 0.03);
+ BOOST_REQUIRE_SMALL(secondBestGain, 1e-10);
// Make sure that if we do create children, that the correct number of
// children is created, and that the bins end up in the right place.
@@ -647,7 +661,10 @@ BOOST_AUTO_TEST_CASE(BinaryNumericSplitSimpleSplitTest)
// Now ensure the fitness function gives good gain.
// The Gini impurity for the unsplit node is 2 * (0.5^2) = 0.5, and the Gini
// impurity for the children is 0.
- BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(), 0.5, 1e-5);
+ double bestGain, secondBestGain;
+ split.EvaluateFitnessFunction(bestGain, secondBestGain);
+ BOOST_REQUIRE_CLOSE(bestGain, 0.5, 1e-5);
+ BOOST_REQUIRE_GT(bestGain, secondBestGain);
}
// Now, when we ask it to split, ensure that the split value is reasonable.
@@ -684,7 +701,10 @@ BOOST_AUTO_TEST_CASE(BinaryNumericSplitSimpleFourClassSplitTest)
// The same as the previous test, but with four classes: 4 * (0.25 * 0.75) =
// 0.75. We can only split in one place, though, which will give one
// perfect child, giving a gain of 0.75 - (3/4) * 3 * (1/3 * 2/3) = 0.25.
- BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(), 0.25, 1e-5);
+ double bestGain, secondBestGain;
+ split.EvaluateFitnessFunction(bestGain, secondBestGain);
+ BOOST_REQUIRE_CLOSE(bestGain, 0.25, 1e-5);
+ BOOST_REQUIRE_GT(bestGain, secondBestGain);
}
// Now, when we ask it to split, ensure that the split value is reasonable.
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 385331d..bcfddf0 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -720,12 +720,20 @@ BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitTest)
BOOST_REQUIRE_EQUAL(split.Bins(), textSplit.Bins());
BOOST_REQUIRE_EQUAL(split.Bins(), binarySplit.Bins());
- BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
- xmlSplit.EvaluateFitnessFunction(), 1e-5);
- BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
- textSplit.EvaluateFitnessFunction(), 1e-5);
- BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
- binarySplit.EvaluateFitnessFunction(), 1e-5);
+ double bestSplit, secondBestSplit;
+ double baseBestSplit, baseSecondBestSplit;
+ split.EvaluateFitnessFunction(baseBestSplit, baseSecondBestSplit);
+ xmlSplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+ BOOST_REQUIRE_CLOSE(bestSplit, baseBestSplit, 1e-5);
+ BOOST_REQUIRE_SMALL(secondBestSplit, 1e-10);
+
+ textSplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+ BOOST_REQUIRE_CLOSE(bestSplit, baseBestSplit, 1e-5);
+ BOOST_REQUIRE_SMALL(secondBestSplit, 1e-10);
+
+ binarySplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+ BOOST_REQUIRE_CLOSE(bestSplit, baseBestSplit, 1e-5);
+ BOOST_REQUIRE_SMALL(secondBestSplit, 1e-10);
arma::Col<size_t> children, xmlChildren, textChildren, binaryChildren;
NumericSplitInfo<double> splitInfo, xmlSplitInfo, textSplitInfo,
@@ -785,10 +793,24 @@ BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitBeforeBinningTest)
BOOST_REQUIRE_EQUAL(split.Bins(), textSplit.Bins());
BOOST_REQUIRE_EQUAL(split.Bins(), binarySplit.Bins());
- BOOST_REQUIRE_SMALL(split.EvaluateFitnessFunction(), 1e-5);
- BOOST_REQUIRE_SMALL(textSplit.EvaluateFitnessFunction(), 1e-5);
- BOOST_REQUIRE_SMALL(xmlSplit.EvaluateFitnessFunction(), 1e-5);
- BOOST_REQUIRE_SMALL(binarySplit.EvaluateFitnessFunction(), 1e-5);
+ double baseBestSplit, baseSecondBestSplit;
+ double bestSplit, secondBestSplit;
+ split.EvaluateFitnessFunction(baseBestSplit, baseSecondBestSplit);
+ textSplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+
+ BOOST_REQUIRE_SMALL(baseBestSplit, 1e-5);
+ BOOST_REQUIRE_SMALL(baseSecondBestSplit, 1e-5);
+
+ BOOST_REQUIRE_SMALL(bestSplit, 1e-5);
+ BOOST_REQUIRE_SMALL(secondBestSplit, 1e-5);
+
+ xmlSplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+ BOOST_REQUIRE_SMALL(bestSplit, 1e-5);
+ BOOST_REQUIRE_SMALL(secondBestSplit, 1e-5);
+
+ binarySplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+ BOOST_REQUIRE_SMALL(bestSplit, 1e-5);
+ BOOST_REQUIRE_SMALL(secondBestSplit, 1e-5);
}
/**
@@ -814,12 +836,21 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitTest)
BOOST_REQUIRE_EQUAL(split.MajorityClass(), textSplit.MajorityClass());
BOOST_REQUIRE_EQUAL(split.MajorityClass(), binarySplit.MajorityClass());
- BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
- xmlSplit.EvaluateFitnessFunction(), 1e-5);
- BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
- textSplit.EvaluateFitnessFunction(), 1e-5);
- BOOST_REQUIRE_CLOSE(split.EvaluateFitnessFunction(),
- binarySplit.EvaluateFitnessFunction(), 1e-5);
+ double bestSplit, secondBestSplit;
+ double baseBestSplit, baseSecondBestSplit;
+ split.EvaluateFitnessFunction(baseBestSplit, baseSecondBestSplit);
+ xmlSplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+
+ BOOST_REQUIRE_CLOSE(bestSplit, baseBestSplit, 1e-5);
+ BOOST_REQUIRE_SMALL(secondBestSplit, 1e-10);
+
+ textSplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+ BOOST_REQUIRE_CLOSE(bestSplit, baseBestSplit, 1e-5);
+ BOOST_REQUIRE_SMALL(secondBestSplit, 1e-10);
+
+ binarySplit.EvaluateFitnessFunction(bestSplit, secondBestSplit);
+ BOOST_REQUIRE_CLOSE(bestSplit, baseBestSplit, 1e-5);
+ BOOST_REQUIRE_SMALL(secondBestSplit, 1e-10);
arma::Col<size_t> children, xmlChildren, textChildren, binaryChildren;
CategoricalSplitInfo splitInfo(1); // I don't care about this.
More information about the mlpack-git
mailing list