[mlpack-git] master: Mild refactoring; Split() instead of CreateChildren(). (17a7986)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:43:07 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 17a7986e12ae9790447440272436442287fc903b
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Sep 30 13:52:42 2015 -0400
Mild refactoring; Split() instead of CreateChildren().
>---------------------------------------------------------------
17a7986e12ae9790447440272436442287fc903b
.../hoeffding_categorical_split.hpp | 6 +----
.../hoeffding_categorical_split_impl.hpp | 15 ++++++-----
.../hoeffding_trees/hoeffding_numeric_split.hpp | 9 +++----
.../hoeffding_numeric_split_impl.hpp | 21 +++++----------
.../hoeffding_trees/hoeffding_split_impl.hpp | 31 ++++++++++------------
.../methods/hoeffding_trees/numeric_split_info.hpp | 2 +-
.../hoeffding_trees/streaming_decision_tree.hpp | 6 +++--
.../streaming_decision_tree_impl.hpp | 10 ++++---
.../streaming_decision_tree_main.cpp | 5 +++-
src/mlpack/tests/hoeffding_tree_test.cpp | 13 +++++----
10 files changed, 54 insertions(+), 64 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
index 2615ae0..6c2edcf 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
@@ -47,11 +47,7 @@ class HoeffdingCategoricalSplit
double EvaluateFitnessFunction() const;
- template<typename StreamingDecisionTreeType>
- void CreateChildren(std::vector<StreamingDecisionTreeType>& children,
- const data::DatasetInfo& datasetInfo,
- const size_t dimensionality,
- SplitInfo& splitInfo);
+ void Split(arma::Col<size_t>& childMajorities, SplitInfo& splitInfo);
size_t MajorityClass() const;
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
index e86428b..37fd8f7 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
@@ -41,17 +41,18 @@ double HoeffdingCategoricalSplit<FitnessFunction>::EvaluateFitnessFunction()
}
template<typename FitnessFunction>
-template<typename StreamingDecisionTreeType>
-void HoeffdingCategoricalSplit<FitnessFunction>::CreateChildren(
- std::vector<StreamingDecisionTreeType>& children,
- const data::DatasetInfo& datasetInfo,
- const size_t dimensionality,
+void HoeffdingCategoricalSplit<FitnessFunction>::Split(
+ arma::Col<size_t>& childMajorities,
SplitInfo& splitInfo)
{
// We'll make one child for each category.
+ childMajorities.set_size(sufficientStatistics.n_cols);
for (size_t i = 0; i < sufficientStatistics.n_cols; ++i)
- children.push_back(StreamingDecisionTreeType(datasetInfo, dimensionality,
- sufficientStatistics.n_rows));
+ {
+ arma::uword maxIndex;
+ sufficientStatistics.col(i).max(maxIndex);
+ childMajorities[i] = size_t(maxIndex);
+ }
// Create the according SplitInfo object.
splitInfo = SplitInfo(sufficientStatistics.n_cols);
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
index 567fd76..e4ba4f0 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
@@ -55,12 +55,9 @@ class HoeffdingNumericSplit
double EvaluateFitnessFunction() const;
- // Does nothing for now.
- template<typename StreamingDecisionTreeType>
- void CreateChildren(std::vector<StreamingDecisionTreeType>& children,
- const data::DatasetInfo& datasetInfo,
- const size_t dimensionality,
- SplitInfo& splitInfo);
+ // Return the majority class of each child to be created, if a split on this
+ // dimension was performed. Also create the split object.
+ void Split(arma::Col<size_t>& childMajorities, SplitInfo& splitInfo) const;
size_t MajorityClass() const;
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
index 81269b7..f6a1284 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
@@ -94,23 +94,16 @@ double HoeffdingNumericSplit<FitnessFunction, ObservationType>::
}
template<typename FitnessFunction, typename ObservationType>
-template<typename StreamingDecisionTreeType>
-void HoeffdingNumericSplit<FitnessFunction, ObservationType>::CreateChildren(
- std::vector<StreamingDecisionTreeType>& children,
- const data::DatasetInfo& datasetInfo,
- const size_t dimensionality,
- SplitInfo& splitInfo)
+void HoeffdingNumericSplit<FitnessFunction, ObservationType>::Split(
+ arma::Col<size_t>& childMajorities,
+ SplitInfo& splitInfo) const
{
- // We'll make one child for each bin.
+ childMajorities.set_size(sufficientStatistics.n_cols);
for (size_t i = 0; i < sufficientStatistics.n_cols; ++i)
{
- // We need to set the majority class for the child, too.
- children.push_back(StreamingDecisionTreeType(datasetInfo, dimensionality,
- sufficientStatistics.n_rows));
-
- arma::uword majorityClass;
- sufficientStatistics.col(i).max(majorityClass);
- children[i].MajorityClass() = majorityClass;
+ arma::uword maxIndex;
+ sufficientStatistics.col(i).max(maxIndex);
+ childMajorities[i] = size_t(maxIndex);
}
// Create the SplitInfo object.
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
index aad4ef2..0c2b60f 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
@@ -209,29 +209,26 @@ void HoeffdingSplit<
CategoricalSplitType
>::CreateChildren(std::vector<StreamingDecisionTreeType>& children)
{
- // We already know what the splitDimension will be.
- size_t numericSplitIndex = 0;
- size_t categoricalSplitIndex = 0;
- for (size_t i = 0; i < splitDimension; ++i)
- {
- if (datasetInfo.Type(i) == data::Datatype::numeric)
- ++numericSplitIndex;
- if (datasetInfo.Type(i) == data::Datatype::categorical)
- ++categoricalSplitIndex;
- }
-
// Create the children.
+ arma::Col<size_t> childMajorities;
if (dimensionMappings[splitDimension].first == data::Datatype::categorical)
{
- categoricalSplits[dimensionMappings[splitDimension].second].CreateChildren(
- children, datasetInfo, numericSplits.size() + categoricalSplits.size(),
- categoricalSplit);
+ categoricalSplits[dimensionMappings[splitDimension].second].Split(
+ childMajorities, categoricalSplit);
}
else if (dimensionMappings[splitDimension].first == data::Datatype::numeric)
{
- numericSplits[dimensionMappings[splitDimension].second].CreateChildren(
- children, datasetInfo, numericSplits.size() + categoricalSplits.size(),
- numericSplit);
+ numericSplits[dimensionMappings[splitDimension].second].Split(
+ childMajorities, numericSplit);
+ }
+
+ // We already know what the splitDimension will be.
+ const size_t dimensionality = numericSplits.size() + categoricalSplits.size();
+ for (size_t i = 0; i < childMajorities.n_elem; ++i)
+ {
+ children.push_back(StreamingDecisionTreeType(datasetInfo, dimensionality,
+ classCounts.n_elem, successProbability, numSamples));
+ children[i].MajorityClass() = childMajorities[i];
}
}
diff --git a/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp b/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
index ce58e58..1d751dc 100644
--- a/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
+++ b/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
@@ -17,7 +17,7 @@ class NumericSplitInfo
{
public:
NumericSplitInfo() { /* Nothing to do. */ }
- NumericSplitInfo(arma::Col<ObservationType>& splitPoints) :
+ NumericSplitInfo(const arma::Col<ObservationType>& splitPoints) :
splitPoints(splitPoints) { /* Nothing to do. */ }
template<typename eT>
diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp
index ff60354..71b4315 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp
@@ -23,12 +23,14 @@ class StreamingDecisionTree
const data::DatasetInfo& datasetInfo,
const arma::Row<size_t>& labels,
const size_t numClasses,
- const double confidence = 0.95);
+ const double confidence = 0.95,
+ const size_t numSamples = 5000);
StreamingDecisionTree(const data::DatasetInfo& datasetInfo,
const size_t dimensionality,
const size_t numClasses,
- const double confidence = 0.95);
+ const double confidence = 0.95,
+ const size_t numSamples = 5000);
StreamingDecisionTree(const StreamingDecisionTree& other);
diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
index 4026927..0c0a3b5 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
@@ -19,8 +19,9 @@ StreamingDecisionTree<SplitType, MatType>::StreamingDecisionTree(
const data::DatasetInfo& datasetInfo,
const arma::Row<size_t>& labels,
const size_t numClasses,
- const double confidence) :
- split(data.n_rows, numClasses, datasetInfo, confidence, 1500)
+ const double confidence,
+ const size_t numSamples) :
+ split(data.n_rows, numClasses, datasetInfo, confidence, numSamples)
{
Train(data, labels);
}
@@ -30,8 +31,9 @@ StreamingDecisionTree<SplitType, MatType>::StreamingDecisionTree(
const data::DatasetInfo& datasetInfo,
const size_t dimensionality,
const size_t numClasses,
- const double confidence) :
- split(dimensionality, numClasses, datasetInfo, confidence, 1500)
+ const double confidence,
+ const size_t numSamples) :
+ split(dimensionality, numClasses, datasetInfo, confidence, numSamples)
{
// No training. Anything else to do...?
}
diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
index 4c2ad7e..949614d 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
@@ -19,6 +19,8 @@ PARAM_STRING("labels_file", "Labels for training dataset.", "l", "");
PARAM_DOUBLE("confidence", "Confidence before splitting (between 0 and 1).",
"c", 0.95);
+PARAM_INT("max_samples", "Maximum number of samples before splitting.", "m",
+ 5000);
int main(int argc, char** argv)
@@ -28,6 +30,7 @@ int main(int argc, char** argv)
const string trainingFile = CLI::GetParam<string>("training_file");
const string labelsFile = CLI::GetParam<string>("labels_file");
const double confidence = CLI::GetParam<double>("confidence");
+ const size_t maxSamples = (size_t) CLI::GetParam<int>("max_samples");
arma::mat trainingSet;
DatasetInfo datasetInfo;
@@ -42,7 +45,7 @@ int main(int argc, char** argv)
// Now create the decision tree.
StreamingDecisionTree<HoeffdingSplit<>> tree(trainingSet, datasetInfo, labels,
- max(labels) + 1, confidence);
+ max(labels) + 1, confidence, maxSamples);
// Great. Good job team.
std::stack<StreamingDecisionTree<HoeffdingSplit<>>*> stack;
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index 2a58399..6d628c5 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -212,9 +212,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitSplitTest)
HoeffdingCategoricalSplit<GiniImpurity>::SplitInfo splitInfo(3);
// Create the children.
- split.CreateChildren(children, info, 3, splitInfo);
+ arma::Col<size_t> childMajorities;
+ split.Split(childMajorities, splitInfo);
- BOOST_REQUIRE_EQUAL(children.size(), 3);
+ BOOST_REQUIRE_EQUAL(childMajorities.n_elem, 3);
BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(0), 0);
BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(1), 1);
BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(2), 2);
@@ -509,12 +510,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingNumericSplitBimodalTest)
// Make sure that if we do create children, that the correct number of
// children is created, and that the bins end up in the right place.
- std::vector<StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
- HoeffdingNumericSplit<double>>>> children;
- data::DatasetInfo datasetInfo; // All numeric features -- no change necessary.
NumericSplitInfo<> info;
- split.CreateChildren(children, datasetInfo, 1, info);
- BOOST_REQUIRE_EQUAL(children.size(), 2);
+ arma::Col<size_t> childMajorities;
+ split.Split(childMajorities, info);
+ BOOST_REQUIRE_EQUAL(childMajorities.n_elem, 2);
// Now check the split info.
for (size_t i = 0; i < 10; ++i)
More information about the mlpack-git mailing list