[mlpack-git] master: Add Range() to GiniImpurity, and test HoeffdingSplit. (924d086)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:42:18 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 924d0869393cbb5a4c4014ac4fba0b493a5ab223
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Sep 22 11:00:49 2015 -0700
Add Range() to GiniImpurity, and test HoeffdingSplit.
>---------------------------------------------------------------
924d0869393cbb5a4c4014ac4fba0b493a5ab223
.../methods/hoeffding_trees/gini_impurity.hpp | 13 +++++++++
.../methods/hoeffding_trees/hoeffding_split.hpp | 2 +-
.../hoeffding_trees/hoeffding_split_impl.hpp | 18 ++++++++----
src/mlpack/tests/hoeffding_tree_test.cpp | 34 ++++++++++++++++++++++
4 files changed, 60 insertions(+), 7 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp b/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp
index 39114d9..6c295fe 100644
--- a/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp
+++ b/src/mlpack/methods/hoeffding_trees/gini_impurity.hpp
@@ -66,6 +66,19 @@ class GiniImpurity
return impurity;
}
+
+ /**
+ * Return the range of the Gini impurity for the given number of classes.
+ * (That is, the difference between the maximum possible value and the minimum
+ * possible value.)
+ *
+ * @param numClasses Number of classes in the problem.
+ * @return 1 - 1/numClasses; 0 for the degenerate input numClasses == 0.
+ */
+ static double Range(const size_t numClasses)
+ {
+ // The best possible case is that only one class exists, which gives a Gini
+ // impurity of 0. The worst possible case is that the classes are evenly
+ // distributed, which gives n * (1/n * (1 - 1/n)) = 1 - 1/n.
+ //
+ // Guard the degenerate numClasses == 0 case: 1.0 / 0.0 is +inf, which would
+ // make the returned range -inf and poison any bound computed from it.
+ if (numClasses == 0)
+ return 0.0;
+
+ return 1.0 - (1.0 / double(numClasses));
+ }
};
} // namespace tree
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
index a60f06c..716e164 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
@@ -32,7 +32,7 @@ class HoeffdingSplit
void Train(const VecType& point, const size_t label);
// 0 if split should not happen; number of splits otherwise.
- size_t SplitCheck() const;
+ size_t SplitCheck();
// Return index that we should go towards.
template<typename VecType>
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
index c0d3788..35fb6d6 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
@@ -25,7 +25,8 @@ HoeffdingSplit<
numClasses(numClasses),
datasetInfo(datasetInfo),
successProbability(successProbability),
- categoricalSplit(0)
+ categoricalSplit(0),
+ splitDimension(size_t(-1))
{
for (size_t i = 0; i < dimensionality; ++i)
{
@@ -73,10 +74,10 @@ size_t HoeffdingSplit<
FitnessFunction,
NumericSplitType,
CategoricalSplitType
->::SplitCheck() const
+>::SplitCheck()
{
// Do nothing if we've already split.
- if (splitDimension == size_t(-1))
+ if (splitDimension != size_t(-1))
return 0;
// Check the fitness of each dimension. Then we'll use a Hoeffding bound
@@ -117,11 +118,12 @@ size_t HoeffdingSplit<
{
// Split!
splitDimension = largestIndex;
- if (datasetInfo[largestIndex].Type == data::Datatype::categorical)
+ if (datasetInfo.Type(largestIndex) == data::Datatype::categorical)
{
// I don't know if this should be here.
- majorityClass = categoricalSplit[largestIndex].MajorityClass();
- return datasetInfo[largestIndex].NumMappings();
+ std::cout << "split: largest index " << largestIndex << ".\n";
+ majorityClass = categoricalSplits[largestIndex].MajorityClass();
+ return datasetInfo.NumMappings(largestIndex);
}
else
{
@@ -129,6 +131,10 @@ size_t HoeffdingSplit<
return 0; // I have no idea what to do yet.
}
}
+ else
+ {
+ return 0; // Don't split.
+ }
}
template<
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 58d34d9..064b21c 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -197,4 +197,38 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitSplitTest)
BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(2), 2);
}
+/**
+ * If we feed the HoeffdingSplit a ton of points of the same class, it should
+ * not suggest that we split.
+ */
+BOOST_AUTO_TEST_CASE(HoeffdingSplitNoSplitTest)
+{
+ // Make all dimensions categorical: dimension 0 gets 4 categories, dimension 1
+ // gets 3, and dimension 2 gets 2.
+ data::DatasetInfo info;
+ info.MapString("cat1", 0);
+ info.MapString("cat2", 0);
+ info.MapString("cat3", 0);
+ info.MapString("cat4", 0);
+ info.MapString("cat1", 1);
+ info.MapString("cat2", 1);
+ info.MapString("cat3", 1);
+ info.MapString("cat1", 2);
+ info.MapString("cat2", 2);
+
+ HoeffdingSplit<> split(3, 2, info, 0.95);
+
+ // The point buffer is loop-invariant; reuse it instead of reallocating it on
+ // every iteration.
+ arma::Col<size_t> testPoint(3);
+
+ // Feed it samples.
+ for (size_t i = 0; i < 1000; ++i)
+ {
+ // Draw a category for each dimension. The upper bound passed to
+ // mlpack::math::RandInt() is exclusive, so these match the category counts
+ // mapped above.
+ testPoint(0) = mlpack::math::RandInt(0, 4);
+ testPoint(1) = mlpack::math::RandInt(0, 3);
+ testPoint(2) = mlpack::math::RandInt(0, 2);
+ split.Train(testPoint, 0); // Always label 0.
+
+ // Compare against size_t(0) so the check does not mix SplitCheck()'s
+ // unsigned size_t result with a signed int literal.
+ BOOST_REQUIRE_EQUAL(split.SplitCheck(), size_t(0));
+ }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list