[mlpack-git] master: A first pass at a better numeric split. (650e752)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:44:08 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 650e752382b2a9ead7edce4adec92562afe19c55
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Oct 7 16:20:45 2015 -0400
A first pass at a better numeric split.
>---------------------------------------------------------------
650e752382b2a9ead7edce4adec92562afe19c55
.../hoeffding_trees/binary_numeric_split.hpp | 72 +++++++++++
.../hoeffding_trees/binary_numeric_split_impl.hpp | 142 +++++++++++++++++++++
2 files changed, 214 insertions(+)
diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
new file mode 100644
index 0000000..d2d7662
--- /dev/null
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
@@ -0,0 +1,72 @@
+/**
+ * @file binary_numeric_split.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of the binary-tree-based numeric splitting procedure
+ * described by Gama, Rocha, and Medas in their KDD 2003 paper.
+ */
+#ifndef __MLPACK_METHODS_HOEFFDING_SPLIT_BINARY_NUMERIC_SPLIT_HPP
+#define __MLPACK_METHODS_HOEFFDING_SPLIT_BINARY_NUMERIC_SPLIT_HPP
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * The BinaryNumericSplit class implements the numeric feature splitting
+ * strategy devised by Gama, Rocha, and Medas in the following paper:
+ *
+ * @code
+ * @inproceedings{gama2003accurate,
+ *   title={Accurate Decision Trees for Mining High-Speed Data Streams},
+ *   author={Gama, J. and Rocha, R. and Medas, P.},
+ *   year={2003},
+ *   booktitle={Proceedings of the Ninth ACM SIGKDD International Conference on
+ *       Knowledge Discovery and Data Mining (KDD '03)},
+ *   pages={523--528}
+ * }
+ * @endcode
+ *
+ * This splitting procedure keeps every (value, label) pair it has seen in
+ * sorted order (here, in a std::multimap keyed on the observed value).
+ * EvaluateFitnessFunction() then finds the best split point with a single
+ * linear pass over the n samples seen so far, calling
+ * FitnessFunction::Evaluate() once per sample.  A split of this type always
+ * produces exactly two children, one for each side of the split point.
+ * Train() costs one insertion into the ordered container, i.e. O(log n).
+ *
+ * @tparam FitnessFunction Type providing a static Evaluate() function that is
+ *     given a (numClasses x 2) matrix of per-class counts.
+ * @tparam ObservationType Type of the observed feature values.
+ */
+template<typename FitnessFunction,
+         typename ObservationType = double>
+class BinaryNumericSplit
+{
+ public:
+  //! The type of object that describes a split made by this class.
+  typedef NumericSplitInfo<ObservationType> SplitInfo;
+
+  /**
+   * Create the split object with all class counts set to zero.
+   *
+   * @param numClasses Number of classes in the dataset.
+   */
+  BinaryNumericSplit(const size_t numClasses);
+
+  /**
+   * Record one observation: store it in sorted order, increment the count of
+   * its class, and mark any cached split as out of date.
+   *
+   * @param value Observed value of the feature.
+   * @param label Class of the observation; used to index the class counts.
+   */
+  void Train(ObservationType value, const size_t label);
+
+  /**
+   * Scan every stored observation and return the best value of the fitness
+   * function found.  The split point achieving it is cached for Split().
+   */
+  double EvaluateFitnessFunction() const;
+
+  /**
+   * Perform the split at the cached best split point (recomputing it first if
+   * it is out of date).
+   *
+   * @param childMajorities Resized to two elements and filled with the
+   *     majority class of each child.
+   * @param splitInfo Overwritten with a SplitInfo built from the split point.
+   */
+  void Split(arma::Col<size_t>& childMajorities, SplitInfo& splitInfo) const;
+
+  //! Return the class with the largest count among the samples seen so far.
+  size_t MajorityClass() const;
+  //! Return the largest class count divided by the total number of samples.
+  double MajorityProbability() const;
+
+  //! Serialize the stored observations and the class counts.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
+ private:
+  //! Every (value, label) pair seen so far.  All we need is ordered access.
+  std::multimap<ObservationType, size_t> sortedElements;
+
+  //! Number of samples seen for each class.
+  arma::Col<size_t> classCounts;
+
+  // The two members below are a cache that the const member function
+  // EvaluateFitnessFunction() refreshes, so they must be mutable.
+
+  //! Whether bestSplit reflects every sample passed to Train() so far.
+  mutable bool isAccurate;
+  //! Split point found by the most recent EvaluateFitnessFunction() call.
+  mutable ObservationType bestSplit;
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation.
+#include "binary_numeric_split_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
new file mode 100644
index 0000000..15da9b2
--- /dev/null
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
@@ -0,0 +1,142 @@
+/**
+ * @file binary_numeric_split_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the BinaryNumericSplit class.
+ */
+#ifndef __MLPACK_METHODS_HOEFFDING_TREES_BINARY_NUMERIC_SPLIT_IMPL_HPP
+#define __MLPACK_METHODS_HOEFFDING_TREES_BINARY_NUMERIC_SPLIT_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "binary_numeric_split.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Create the split object with all class counts set to zero.
+ *
+ * @param numClasses Number of classes; this is the length of the class count
+ *     vector.
+ */
+template<typename FitnessFunction, typename ObservationType>
+BinaryNumericSplit<FitnessFunction, ObservationType>::BinaryNumericSplit(
+    const size_t numClasses) :
+    classCounts(numClasses),
+    isAccurate(true),
+    // Use lowest(), not min(): for floating-point types min() is the smallest
+    // *positive* value, which is not below negative observations.
+    bestSplit(std::numeric_limits<ObservationType>::lowest())
+{
+  // Zero out class counts.
+  classCounts.zeros();
+}
+
+/**
+ * Record a single observation.
+ *
+ * @param value Observed value of the feature.
+ * @param label Class of the observation; used to index the class counts.
+ */
+template<typename FitnessFunction, typename ObservationType>
+void BinaryNumericSplit<FitnessFunction, ObservationType>::Train(
+    ObservationType value,
+    const size_t label)
+{
+  // Store the observation in sorted position and count it for its class.
+  sortedElements.insert(std::make_pair(value, label));
+  classCounts[label] += 1;
+
+  // A split computed before this point arrived can no longer be trusted.
+  isAccurate = false;
+}
+
+/**
+ * Find the best binary split over all observations seen so far.
+ *
+ * Starting with every sample in column 1 of the count matrix, the samples are
+ * visited in ascending order of value and moved one at a time into column 0;
+ * the fitness function is evaluated after each move.  The value of the sample
+ * that produced the strictly largest fitness is cached in bestSplit, and the
+ * cache is marked accurate.
+ *
+ * This is a const member function that writes the cached members bestSplit
+ * and isAccurate, so those must be declared mutable in the class.
+ *
+ * @return The largest fitness value found, which is the fitness of the
+ *     unsplit counts if no candidate split is strictly better.
+ */
+template<typename FitnessFunction, typename ObservationType>
+double BinaryNumericSplit<FitnessFunction, ObservationType>::
+    EvaluateFitnessFunction() const
+{
+  // Unfortunately, we have to iterate over the map.  Start from a sentinel
+  // that is below every representable observation; lowest() is required here
+  // because min() is a small positive number for floating-point types.
+  bestSplit = std::numeric_limits<ObservationType>::lowest();
+
+  // Initialize the sufficient statistics: one row per class, with column 0
+  // empty and column 1 holding every sample.
+  arma::Mat<size_t> counts(classCounts.n_elem, 2);
+  counts.col(0).zeros();
+  counts.col(1) = classCounts;
+
+  // Baseline: the fitness with nothing moved into column 0.
+  double bestValue = FitnessFunction::Evaluate(counts);
+
+  // 'typename' is needed because the iterator type depends on a template
+  // parameter.
+  for (typename std::multimap<ObservationType, size_t>::const_iterator it =
+      sortedElements.begin(); it != sortedElements.end(); ++it)
+  {
+    // Move this point from column 1 to column 0.
+    --counts((*it).second, 1);
+    ++counts((*it).second, 0);
+
+    // TODO: skip ahead if the next value is the same.
+
+    const double value = FitnessFunction::Evaluate(counts);
+    if (value > bestValue)
+    {
+      bestValue = value;
+      bestSplit = (*it).first;
+    }
+  }
+
+  isAccurate = true;
+  return bestValue;
+}
+
+/**
+ * Split at the best split point and report the majority class of each child.
+ *
+ * If samples have been added since the split point was last computed, it is
+ * recomputed first.  Child 0 receives every sample whose value is less than or
+ * equal to the split point; child 1 receives the rest.
+ *
+ * @param childMajorities Resized to two elements; element i is set to the
+ *     class with the largest count in child i.
+ * @param splitInfo Overwritten with a SplitInfo constructed from a
+ *     one-element vector holding the split point.
+ */
+template<typename FitnessFunction, typename ObservationType>
+void BinaryNumericSplit<FitnessFunction, ObservationType>::Split(
+    arma::Col<size_t>& childMajorities,
+    SplitInfo& splitInfo) const
+{
+  // Refresh the cached split point if it is stale.
+  if (!isAccurate)
+    EvaluateFitnessFunction();
+
+  // Make one child for each side of the split.
+  childMajorities.set_size(2);
+
+  // Per-class counts of each child; everything starts in child 1.
+  arma::Mat<size_t> counts(classCounts.n_elem, 2);
+  counts.col(0).zeros();
+  counts.col(1) = classCounts;
+
+  // The multimap is ordered by value, so we can stop at the first sample
+  // beyond the split point.  The key (first) is the observed value; the
+  // mapped value (second) is the class label.  The end() check is required
+  // for the case where every sample lies at or below the split point.
+  for (typename std::multimap<ObservationType, size_t>::const_iterator it =
+      sortedElements.begin();
+      it != sortedElements.end() && (*it).first <= bestSplit; ++it)
+  {
+    // Move the point to the correct side of the split.
+    --counts((*it).second, 1);
+    ++counts((*it).second, 0);
+  }
+
+  // Calculate the majority classes of the children.
+  arma::uword maxIndex;
+  counts.unsafe_col(0).max(maxIndex);
+  childMajorities[0] = size_t(maxIndex);
+  counts.unsafe_col(1).max(maxIndex);
+  childMajorities[1] = size_t(maxIndex);
+
+  // Create the according SplitInfo object.
+  arma::vec splitPoints(1);
+  splitPoints[0] = double(bestSplit);
+  splitInfo = SplitInfo(splitPoints);
+}
+
+/**
+ * Return the index of the largest entry of the class count vector.
+ */
+template<typename FitnessFunction, typename ObservationType>
+size_t BinaryNumericSplit<FitnessFunction, ObservationType>::MajorityClass()
+    const
+{
+  // max() reports the position of the largest count through its argument; the
+  // count itself is not needed here.
+  arma::uword majorityIndex;
+  classCounts.max(majorityIndex);
+
+  return static_cast<size_t>(majorityIndex);
+}
+
+/**
+ * Return the largest class count divided by the sum of all class counts.
+ */
+template<typename FitnessFunction, typename ObservationType>
+double BinaryNumericSplit<FitnessFunction, ObservationType>::
+    MajorityProbability() const
+{
+  // Convert both counts to double before dividing so that the division is not
+  // an integer division.
+  const double majorityCount = double(arma::max(classCounts));
+  const double totalCount = double(arma::accu(classCounts));
+
+  return majorityCount / totalCount;
+}
+
+/**
+ * Serialize the stored observations and the class counts.
+ *
+ * The cached split point and its validity flag are not part of the archive.
+ * The cache is therefore invalidated afterwards, so that an object that has
+ * just been loaded recomputes the split point instead of trusting whatever it
+ * held before loading.
+ *
+ * @param ar Archive to read from or write to.
+ */
+template<typename FitnessFunction, typename ObservationType>
+template<typename Archive>
+void BinaryNumericSplit<FitnessFunction, ObservationType>::Serialize(
+    Archive& ar,
+    const unsigned int /* version */)
+{
+  // Serialize.
+  ar & data::CreateNVP(sortedElements, "sortedElements");
+  ar & data::CreateNVP(classCounts, "classCounts");
+
+  // bestSplit is not archived, so it may not match the observations now held.
+  // When saving this only costs one extra recomputation later.
+  isAccurate = false;
+}
+
+
+} // namespace tree
+} // namespace mlpack
+
+#endif
More information about the mlpack-git
mailing list