[mlpack-git] master: Refactor to allow passing parameters of splits to children. (b69fc07)
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:46:32 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit b69fc07d75ff34fde9ab344200b226de260f48d2
Author: ryan <ryan at ratml.org>
Date: Mon Nov 23 13:40:48 2015 -0500
Refactor to allow passing parameters of splits to children.
>---------------------------------------------------------------
b69fc07d75ff34fde9ab344200b226de260f48d2
.../hoeffding_trees/binary_numeric_split.hpp | 8 ++++
.../hoeffding_trees/binary_numeric_split_impl.hpp | 12 +++++
.../hoeffding_categorical_split.hpp | 10 ++++
.../hoeffding_categorical_split_impl.hpp | 10 ++++
.../hoeffding_trees/hoeffding_numeric_split.hpp | 7 +++
.../hoeffding_numeric_split_impl.hpp | 15 ++++++
.../methods/hoeffding_trees/hoeffding_tree.hpp | 10 +++-
.../hoeffding_trees/hoeffding_tree_impl.hpp | 55 ++++++++++++++++++----
.../hoeffding_trees/hoeffding_tree_main.cpp | 1 +
9 files changed, 117 insertions(+), 11 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
index faa0f77..5552686 100644
--- a/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
@@ -53,6 +53,14 @@ class BinaryNumericSplit
BinaryNumericSplit(const size_t numClasses);
/**
+ * Create the BinaryNumericSplit object with the given number of classes,
+ * using information from the given other split for other parameters. In this
+ * case, there are no other parameters, but this function is required by the
+ * HoeffdingTree class.
+ */
+ BinaryNumericSplit(const size_t numClasses, const BinaryNumericSplit& other);
+
+ /**
* Train on the given value with the given label.
*
* @param value The value to train on.
diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
index a0a174a..eca9222 100644
--- a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
@@ -25,6 +25,18 @@ BinaryNumericSplit<FitnessFunction, ObservationType>::BinaryNumericSplit(
}
template<typename FitnessFunction, typename ObservationType>
+BinaryNumericSplit<FitnessFunction, ObservationType>::BinaryNumericSplit(
+ const size_t numClasses,
+ const BinaryNumericSplit& /* other */) :
+ classCounts(numClasses),
+ bestSplit(std::numeric_limits<ObservationType>::min()),
+ isAccurate(true)
+{
+ // Zero out class counts.
+ classCounts.zeros();
+}
+
+template<typename FitnessFunction, typename ObservationType>
void BinaryNumericSplit<FitnessFunction, ObservationType>::Train(
ObservationType value,
const size_t label)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
index 291fb75..f72ffe0 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
@@ -53,6 +53,16 @@ class HoeffdingCategoricalSplit
const size_t numClasses);
/**
+ * Create the HoeffdingCategoricalSplit given a number of categories for this
+ * dimension and a number of classes and another HoeffdingCategoricalSplit to
+ * take parameters from. In this particular case, there are no parameters to
+ * take, but this constructor is required by the HoeffdingTree class.
+ */
+ HoeffdingCategoricalSplit(const size_t numCategories,
+ const size_t numClasses,
+ const HoeffdingCategoricalSplit& other);
+
+ /**
* Train on the given value with the given label.
*
* @param value Value to train on.
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
index c008f17..c6d6521 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
@@ -23,6 +23,16 @@ HoeffdingCategoricalSplit<FitnessFunction>::HoeffdingCategoricalSplit(
}
template<typename FitnessFunction>
+HoeffdingCategoricalSplit<FitnessFunction>::HoeffdingCategoricalSplit(
+ const size_t numCategories,
+ const size_t numClasses,
+ const HoeffdingCategoricalSplit& /* other */) :
+ sufficientStatistics(numClasses, numCategories)
+{
+ sufficientStatistics.zeros();
+}
+
+template<typename FitnessFunction>
template<typename eT>
void HoeffdingCategoricalSplit<FitnessFunction>::Train(eT value,
const size_t label)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
index fb6c4c1..4a5e390 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
@@ -65,6 +65,13 @@ class HoeffdingNumericSplit
const size_t observationsBeforeBinning = 100);
/**
+ * Create the HoeffdingNumericSplit class, using the parameters from the given
+ * other split object.
+ */
+ HoeffdingNumericSplit(const size_t numClasses,
+ const HoeffdingNumericSplit& other);
+
+ /**
* Train the HoeffdingNumericSplit on the given observed value (remember that
* this object only cares about the information for a single feature, not an
* entire point).
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
index bcd9f92..6175d89 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
@@ -29,6 +29,21 @@ HoeffdingNumericSplit<FitnessFunction, ObservationType>::HoeffdingNumericSplit(
}
template<typename FitnessFunction, typename ObservationType>
+HoeffdingNumericSplit<FitnessFunction, ObservationType>::HoeffdingNumericSplit(
+ const size_t numClasses,
+ const HoeffdingNumericSplit& other) :
+ observations(other.observationsBeforeBinning - 1),
+ labels(other.observationsBeforeBinning - 1),
+ bins(other.bins),
+ observationsBeforeBinning(other.observationsBeforeBinning),
+ samplesSeen(0),
+ sufficientStatistics(arma::zeros<arma::Mat<size_t>>(numClasses, bins))
+{
+ observations.zeros();
+ labels.zeros();
+}
+
+template<typename FitnessFunction, typename ObservationType>
void HoeffdingNumericSplit<FitnessFunction, ObservationType>::Train(
ObservationType value,
const size_t label)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
index 8e53b50..1d0f81c 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
@@ -86,7 +86,11 @@ class HoeffdingTree
const double successProbability = 0.95,
const size_t maxSamples = 0,
const size_t checkInterval = 100,
- const size_t minSamples = 100);
+ const size_t minSamples = 100,
+ const CategoricalSplitType<FitnessFunction>& categoricalSplitIn
+ = CategoricalSplitType<FitnessFunction>(0, 0),
+ const NumericSplitType<FitnessFunction>& numericSplitIn =
+ NumericSplitType<FitnessFunction>(0));
/**
* Construct the Hoeffding tree with the given parameters, but training on no
@@ -113,6 +117,10 @@ class HoeffdingTree
const size_t maxSamples = 0,
const size_t checkInterval = 100,
const size_t minSamples = 100,
+ const CategoricalSplitType<FitnessFunction>& categoricalSplitIn
+ = CategoricalSplitType<FitnessFunction>(0, 0),
+ const NumericSplitType<FitnessFunction>& numericSplitIn =
+ NumericSplitType<FitnessFunction>(0),
std::unordered_map<size_t, std::pair<size_t, size_t>>*
dimensionMappings = NULL);
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
index ce0164c..24438d6 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
@@ -29,7 +29,10 @@ HoeffdingTree<
const double successProbability,
const size_t maxSamples,
const size_t checkInterval,
- const size_t minSamples) :
+ const size_t minSamples,
+ const CategoricalSplitType<FitnessFunction>&
+ categoricalSplitIn,
+ const NumericSplitType<FitnessFunction>& numericSplitIn) :
dimensionMappings(new std::unordered_map<size_t,
std::pair<size_t, size_t>>()),
ownsMappings(true),
@@ -51,13 +54,14 @@ HoeffdingTree<
if (datasetInfo.Type(i) == data::Datatype::categorical)
{
categoricalSplits.push_back(CategoricalSplitType<FitnessFunction>(
- datasetInfo.NumMappings(i), numClasses));
+ datasetInfo.NumMappings(i), numClasses, categoricalSplitIn));
(*dimensionMappings)[i] = std::make_pair(data::Datatype::categorical,
categoricalSplits.size() - 1);
}
else
{
- numericSplits.push_back(NumericSplitType<FitnessFunction>(numClasses));
+ numericSplits.push_back(NumericSplitType<FitnessFunction>(numClasses,
+ numericSplitIn));
(*dimensionMappings)[i] = std::make_pair(data::Datatype::numeric,
numericSplits.size() - 1);
}
@@ -80,6 +84,9 @@ HoeffdingTree<
const size_t maxSamples,
const size_t checkInterval,
const size_t minSamples,
+ const CategoricalSplitType<FitnessFunction>&
+ categoricalSplitIn,
+ const NumericSplitType<FitnessFunction>& numericSplitIn,
std::unordered_map<size_t, std::pair<size_t, size_t>>*
dimensionMappingsIn) :
dimensionMappings((dimensionMappingsIn != NULL) ? dimensionMappingsIn :
@@ -105,13 +112,14 @@ HoeffdingTree<
if (datasetInfo.Type(i) == data::Datatype::categorical)
{
categoricalSplits.push_back(CategoricalSplitType<FitnessFunction>(
- datasetInfo.NumMappings(i), numClasses));
+ datasetInfo.NumMappings(i), numClasses, categoricalSplitIn));
(*dimensionMappings)[i] = std::make_pair(data::Datatype::categorical,
categoricalSplits.size() - 1);
}
else
{
- numericSplits.push_back(NumericSplitType<FitnessFunction>(numClasses));
+ numericSplits.push_back(NumericSplitType<FitnessFunction>(numClasses,
+ numericSplitIn));
(*dimensionMappings)[i] = std::make_pair(data::Datatype::numeric,
numericSplits.size() - 1);
}
@@ -124,11 +132,12 @@ HoeffdingTree<
if (datasetInfo.Type(i) == data::Datatype::categorical)
{
categoricalSplits.push_back(CategoricalSplitType<FitnessFunction>(
- datasetInfo.NumMappings(i), numClasses));
+ datasetInfo.NumMappings(i), numClasses, categoricalSplitIn));
}
else
{
- numericSplits.push_back(NumericSplitType<FitnessFunction>(numClasses));
+ numericSplits.push_back(NumericSplitType<FitnessFunction>(numClasses,
+ numericSplitIn));
}
}
}
@@ -541,9 +550,35 @@ void HoeffdingTree<
// We already know what the splitDimension will be.
for (size_t i = 0; i < childMajorities.n_elem; ++i)
{
- children.push_back(new HoeffdingTree(*datasetInfo, numClasses,
- successProbability, maxSamples, checkInterval, minSamples,
- dimensionMappings));
+ // We need to also give our split objects to the new children, so that
+ // parameters for the splits can be passed down. But if we have no
+ // categorical or numeric features, we can't pass anything but the
+ // defaults...
+ if (categoricalSplits.size() == 0)
+ {
+ // Pass a default categorical split.
+ children.push_back(new HoeffdingTree(*datasetInfo, numClasses,
+ successProbability, maxSamples, checkInterval, minSamples,
+ CategoricalSplitType<FitnessFunction>(0, numClasses),
+ numericSplits[0], dimensionMappings));
+ }
+ else if (numericSplits.size() == 0)
+ {
+ // Pass a default numeric split.
+ children.push_back(new HoeffdingTree(*datasetInfo, numClasses,
+ successProbability, maxSamples, checkInterval, minSamples,
+ categoricalSplits[0], NumericSplitType<FitnessFunction>(numClasses),
+ dimensionMappings));
+ }
+ else
+ {
+ // Pass both splits that we already have.
+ children.push_back(new HoeffdingTree(*datasetInfo, numClasses,
+ successProbability, maxSamples, checkInterval, minSamples,
+ categoricalSplits[0], numericSplits[0], dimensionMappings));
+
+ }
+
children[i]->MajorityClass() = childMajorities[i];
}
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
index 6b6634a..dd6caef 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
@@ -114,6 +114,7 @@ void PerformActions()
const string labelsFile = CLI::GetParam<string>("labels_file");
const double confidence = CLI::GetParam<double>("confidence");
const size_t maxSamples = (size_t) CLI::GetParam<int>("max_samples");
+ const size_t minSamples = (size_t) CLI::GetParam<size_t>("min_samples");
const string inputModelFile = CLI::GetParam<string>("input_model_file");
const string outputModelFile = CLI::GetParam<string>("output_model_file");
const string testFile = CLI::GetParam<string>("test_file");
More information about the mlpack-git mailing list