[mlpack-git] master: Add a way to get the probability at each node. (5268e55)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:43:50 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 5268e554528ac265417ab0f56244c2bac9449feb
Author: ryan <ryan at ratml.org>
Date: Thu Oct 1 21:08:05 2015 -0400
Add a way to get the probability at each node.
>---------------------------------------------------------------
5268e554528ac265417ab0f56244c2bac9449feb
CMakeLists.txt | 2 +-
.../hoeffding_categorical_split.hpp | 1 +
.../hoeffding_categorical_split_impl.hpp | 10 +++++++-
.../hoeffding_trees/hoeffding_numeric_split.hpp | 1 +
.../hoeffding_numeric_split_impl.hpp | 28 +++++++++++++++++++-
.../methods/hoeffding_trees/hoeffding_split.hpp | 5 ++++
.../hoeffding_trees/hoeffding_split_impl.hpp | 24 +++++++++++++++++
.../hoeffding_trees/streaming_decision_tree.hpp | 7 +++++
.../streaming_decision_tree_impl.hpp | 30 ++++++++++++++++++++++
9 files changed, 105 insertions(+), 3 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8eddc98..995e3a8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -64,7 +64,7 @@ endif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
# Debugging CFLAGS. Turn optimizations off; turn debugging symbols on.
if(DEBUG)
add_definitions(-DDEBUG)
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0 -ftemplate-backtrace-limit=0")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c99 -g -O0")
else()
add_definitions(-DARMA_NO_DEBUG)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
index 22aeb5a..847e10c 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
@@ -50,6 +50,7 @@ class HoeffdingCategoricalSplit
void Split(arma::Col<size_t>& childMajorities, SplitInfo& splitInfo);
size_t MajorityClass() const;
+ double MajorityProbability() const;
//! Serialize the categorical split.
template<typename Archive>
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
index 37fd8f7..e4a5471 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
@@ -62,7 +62,7 @@ template<typename FitnessFunction>
size_t HoeffdingCategoricalSplit<FitnessFunction>::MajorityClass() const
{
// Calculate the class that we have seen the most of.
- arma::Col<size_t> classCounts = sum(sufficientStatistics, 1);
+ arma::Col<size_t> classCounts = arma::sum(sufficientStatistics, 1);
arma::uword maxIndex;
classCounts.max(maxIndex);
@@ -70,6 +70,14 @@ size_t HoeffdingCategoricalSplit<FitnessFunction>::MajorityClass() const
return size_t(maxIndex);
}
+// Return the ratio of the largest per-row total of the sufficient statistics
+// to the grand total; MajorityClass() uses the same per-row totals to pick the
+// class that has been seen the most.
+template<typename FitnessFunction>
+double HoeffdingCategoricalSplit<FitnessFunction>::MajorityProbability() const
+{
+  // Collapse the sufficient statistics along dimension 1: one total per row.
+  const arma::Col<size_t> perClass = arma::sum(sufficientStatistics, 1);
+
+  const double largest = double(perClass.max());
+  const double total = double(arma::accu(perClass));
+
+  return largest / total;
+}
+
} // namespace tree
} // namespace mlpack
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
index d99045e..2e233e6 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
@@ -60,6 +60,7 @@ class HoeffdingNumericSplit
void Split(arma::Col<size_t>& childMajorities, SplitInfo& splitInfo) const;
size_t MajorityClass() const;
+ double MajorityProbability() const;
size_t Bins() const { return bins; }
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
index fa156a6..da426f1 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
@@ -130,7 +130,7 @@ size_t HoeffdingNumericSplit<FitnessFunction, ObservationType>::
{
// We've calculated the bins, so we can just sum over the sufficient
// statistics.
- arma::Col<size_t> classCounts = sum(sufficientStatistics, 1);
+ arma::Col<size_t> classCounts = arma::sum(sufficientStatistics, 1);
arma::uword maxIndex;
classCounts.max(maxIndex);
@@ -139,6 +139,32 @@ size_t HoeffdingNumericSplit<FitnessFunction, ObservationType>::
}
template<typename FitnessFunction, typename ObservationType>
+double HoeffdingNumericSplit<FitnessFunction, ObservationType>::
+    MajorityProbability() const
+{
+  // Returns (largest per-class count) / (total count).  How the per-class
+  // counts are gathered depends on whether the bins have been determined yet.
+  arma::Col<size_t> classCounts;
+
+  if (samplesSeen < observationsBeforeBinning)
+  {
+    // If we haven't yet determined the bins, we must tally the labels of the
+    // samples seen so far by hand.
+    classCounts.zeros(sufficientStatistics.n_rows);
+    for (size_t i = 0; i < samplesSeen; ++i)
+      classCounts[labels[i]]++;
+  }
+  else
+  {
+    // We've calculated the bins, so we can just sum over the sufficient
+    // statistics.
+    classCounts = arma::sum(sufficientStatistics, 1);
+  }
+
+  // Use accu() for the total in both cases, matching the categorical split.
+  // NOTE(review): if every count is zero this is 0.0 / 0.0; confirm callers
+  // never ask for the probability before any sample has been seen.
+  return double(classCounts.max()) / double(arma::accu(classCounts));
+}
+
+template<typename FitnessFunction, typename ObservationType>
template<typename Archive>
void HoeffdingNumericSplit<FitnessFunction, ObservationType>::Serialize(
Archive& ar,
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
index c499154..139099b 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
@@ -55,6 +55,10 @@ class HoeffdingSplit
template<typename VecType>
size_t Classify(const VecType& point) const;
+ template<typename VecType>
+ void Classify(const VecType& point, size_t& prediction, double& probability)
+ const;
+
template<typename StreamingDecisionTreeType>
void CreateChildren(std::vector<StreamingDecisionTreeType>& children);
@@ -81,6 +85,7 @@ class HoeffdingSplit
// And we need to keep some information for after we have split.
size_t splitDimension;
size_t majorityClass;
+ double majorityProbability;
typename CategoricalSplitType::SplitInfo categoricalSplit; // In case it's categorical.
typename NumericSplitType::SplitInfo numericSplit; // In case it's numeric.
};
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
index 8a9ca73..e32d251 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
@@ -108,9 +108,15 @@ void HoeffdingSplit<
// Grab majority class from splits.
if (categoricalSplits.size() > 0)
+ {
majorityClass = categoricalSplits[0].MajorityClass();
+ majorityProbability = categoricalSplits[0].MajorityProbability();
+ }
else
+ {
majorityClass = numericSplits[0].MajorityClass();
+ majorityProbability = numericSplits[0].MajorityProbability();
+ }
}
else
{
@@ -263,6 +269,24 @@ template<
typename NumericSplitType,
typename CategoricalSplitType
>
+template<typename VecType>
+void HoeffdingSplit<
+ FitnessFunction,
+ NumericSplitType,
+ CategoricalSplitType
+>::Classify(const VecType& /* point */,
+ size_t& prediction,
+ double& probability) const
+{
+ prediction = majorityClass;
+ probability = majorityProbability;
+}
+
+template<
+ typename FitnessFunction,
+ typename NumericSplitType,
+ typename CategoricalSplitType
+>
template<typename StreamingDecisionTreeType>
void HoeffdingSplit<
FitnessFunction,
diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp
index c902439..89856e5 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp
@@ -52,8 +52,15 @@ class StreamingDecisionTree
template<typename VecType>
size_t Classify(const VecType& data);
+ template<typename VecType>
+ void Classify(const VecType& data, size_t& prediction, double& probability);
+
void Classify(const MatType& data, arma::Row<size_t>& predictions);
+ void Classify(const MatType& data,
+ arma::Row<size_t>& predictions,
+ arma::rowvec& probabilities);
+
size_t& MajorityClass() { return split.MajorityClass(); }
// How do we encode the actual split itself?
diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
index 31e9857..8697a05 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
@@ -105,6 +105,24 @@ size_t StreamingDecisionTree<SplitType, MatType>::Classify(const VecType& data)
}
template<typename SplitType, typename MatType>
+template<typename VecType>
+void StreamingDecisionTree<SplitType, MatType>::Classify(
+    const VecType& data,
+    size_t& prediction,
+    double& probability)
+{
+  // A node without children answers directly from its own split.
+  if (children.size() == 0)
+  {
+    split.Classify(data, prediction, probability);
+    return;
+  }
+
+  // Otherwise pass the point on to the child the split selects for it.
+  const size_t childIndex = split.CalculateDirection(data);
+  children[childIndex].Classify(data, prediction, probability);
+}
+
+template<typename SplitType, typename MatType>
void StreamingDecisionTree<SplitType, MatType>::Classify(
const MatType& data,
arma::Row<size_t>& predictions)
@@ -114,6 +132,18 @@ void StreamingDecisionTree<SplitType, MatType>::Classify(
predictions[i] = Classify(data.col(i));
}
+template<typename SplitType, typename MatType>
+void StreamingDecisionTree<SplitType, MatType>::Classify(
+    const MatType& data,
+    arma::Row<size_t>& predictions,
+    arma::rowvec& probabilities)
+{
+  // Produce one prediction/probability pair per column of the data matrix.
+  const size_t numPoints = data.n_cols;
+  predictions.set_size(numPoints);
+  probabilities.set_size(numPoints);
+
+  for (size_t point = 0; point < numPoints; ++point)
+  {
+    Classify(data.col(point), predictions[point], probabilities[point]);
+  }
+}
+
} // namespace tree
} // namespace mlpack
More information about the mlpack-git
mailing list