[mlpack-git] master: A first pass at a better numeric split. (650e752)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:44:08 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 650e752382b2a9ead7edce4adec92562afe19c55
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Oct 7 16:20:45 2015 -0400

    A first pass at a better numeric split.


>---------------------------------------------------------------

650e752382b2a9ead7edce4adec92562afe19c55
 .../hoeffding_trees/binary_numeric_split.hpp       |  72 +++++++++++
 .../hoeffding_trees/binary_numeric_split_impl.hpp  | 142 +++++++++++++++++++++
 2 files changed, 214 insertions(+)

diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
new file mode 100644
index 0000000..d2d7662
--- /dev/null
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
@@ -0,0 +1,72 @@
+/**
+ * @file binary_numeric_split.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of the binary-tree-based numeric splitting procedure
+ * described by Gama, Rocha, and Medas in their KDD 2003 paper.
+ */
+#ifndef __MLPACK_METHODS_HOEFFDING_TREES_BINARY_NUMERIC_SPLIT_HPP
+#define __MLPACK_METHODS_HOEFFDING_TREES_BINARY_NUMERIC_SPLIT_HPP
+
+// This header uses std::multimap, Armadillo types and NumericSplitInfo, so it
+// must pull them in itself rather than rely on the includer.
+#include <mlpack/core.hpp>
+#include <map>
+
+#include "numeric_split_info.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * The BinaryNumericSplit class implements the numeric feature splitting
+ * strategy devised by Gama, Rocha, and Medas in the following paper:
+ *
+ * @code
+ * @inproceedings{gama2003accurate,
+ *    title={Accurate Decision Trees for Mining High-Speed Data Streams},
+ *    author={Gama, J. and Rocha, R. and Medas, P.},
+ *    year={2003},
+ *    booktitle={Proceedings of the Ninth ACM SIGKDD International Conference on
+ *        Knowledge Discovery and Data Mining (KDD '03)},
+ *    pages={523--528}
+ * }
+ * @endcode
+ *
+ * Every observation seen so far is kept, with its label, in a container sorted
+ * by value (a std::multimap).  EvaluateFitnessFunction() scans that container
+ * once to find the best possible binary split, evaluating the fitness function
+ * O(n) times, where n is the number of samples seen so far.  Every split with
+ * this split type produces exactly two children: one for values less than or
+ * equal to the split point, and one for values greater than the split point.
+ * Train() costs one multimap insertion, i.e. O(log n) time.
+ *
+ * @tparam FitnessFunction Class with a static Evaluate() taking a
+ *     (numClasses x 2) matrix of per-side class counts and returning a double
+ *     (higher is better).
+ * @tparam ObservationType Type of the numeric feature values.
+ */
+template<typename FitnessFunction,
+         typename ObservationType = double>
+class BinaryNumericSplit
+{
+ public:
+  //! The object that describes a chosen split.
+  typedef NumericSplitInfo<ObservationType> SplitInfo;
+
+  /**
+   * Create the split object for a problem with the given number of classes.
+   * No observations have been seen yet.
+   *
+   * @param numClasses Number of classes in the problem.
+   */
+  BinaryNumericSplit(const size_t numClasses);
+
+  /**
+   * Record one observation of this feature.
+   *
+   * @param value Value of the feature.
+   * @param label Class label of the observation (must be < numClasses).
+   */
+  void Train(ObservationType value, const size_t label);
+
+  /**
+   * Return the fitness of the best binary split of all observations seen so
+   * far, and cache the corresponding split point for Split().
+   */
+  double EvaluateFitnessFunction() const;
+
+  /**
+   * Split on the best split point (recomputing it if necessary).
+   *
+   * @param childMajorities Set to the majority class of each of the two
+   *     children (values <= split point, then values > split point).
+   * @param splitInfo Set to the information describing the split.
+   */
+  void Split(arma::Col<size_t>& childMajorities, SplitInfo& splitInfo) const;
+
+  //! Return the most common class seen so far.
+  size_t MajorityClass() const;
+  //! Return the fraction of observations belonging to the majority class.
+  double MajorityProbability() const;
+
+  //! Serialize the split object.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
+ private:
+  //! Every observation seen so far (value -> label), kept sorted by value.
+  //! All we need is ordered access.
+  std::multimap<ObservationType, size_t> sortedElements;
+
+  //! Number of observations seen of each class.
+  arma::Col<size_t> classCounts;
+
+  //! Whether bestSplit reflects every observation seen so far.  Mutable
+  //! because the const EvaluateFitnessFunction() refreshes this cache.
+  mutable bool isAccurate;
+  //! Cached best split point: the largest value sent to the first child.
+  //! Mutable for the same reason as isAccurate.
+  mutable ObservationType bestSplit;
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation.
+#include "binary_numeric_split_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
new file mode 100644
index 0000000..15da9b2
--- /dev/null
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
@@ -0,0 +1,142 @@
+/**
+ * @file binary_numeric_split_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the BinaryNumericSplit class.
+ */
+#ifndef __MLPACK_METHODS_HOEFFDING_TREES_BINARY_NUMERIC_SPLIT_IMPL_HPP
+#define __MLPACK_METHODS_HOEFFDING_TREES_BINARY_NUMERIC_SPLIT_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "binary_numeric_split.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Construct the split object with all class counts zeroed.  With no
+ * observations there is nothing to split, so the cached split starts out
+ * marked accurate, at the "no split" sentinel value.
+ */
+template<typename FitnessFunction, typename ObservationType>
+BinaryNumericSplit<FitnessFunction, ObservationType>::BinaryNumericSplit(
+    const size_t numClasses) :
+    classCounts(numClasses),
+    isAccurate(true),
+    // lowest() is the most negative representable value; min() would be the
+    // smallest *positive* value for floating-point types.
+    bestSplit(std::numeric_limits<ObservationType>::lowest())
+{
+  // Zero out class counts.
+  classCounts.zeros();
+}
+
+/**
+ * Record a single observation.  The (value, label) pair is filed in the sorted
+ * container and the count for the label's class is incremented.  Whatever
+ * best split was cached is now stale.
+ */
+template<typename FitnessFunction, typename ObservationType>
+void BinaryNumericSplit<FitnessFunction, ObservationType>::Train(
+    ObservationType value,
+    const size_t label)
+{
+  sortedElements.insert(std::make_pair(value, label));
+  classCounts[label] += 1;
+
+  // The cached split no longer reflects every observation.
+  isAccurate = false;
+}
+
+/**
+ * Scan the sorted observations once and return the fitness of the best binary
+ * split.  The split point is cached in bestSplit: it is the largest value that
+ * goes to the first child.  Observations with equal values always stay on the
+ * same side, since no threshold can separate them.  If no split scores better
+ * than keeping every point together, the fitness of that unsplit configuration
+ * is returned and bestSplit is left at numeric_limits::lowest().
+ */
+template<typename FitnessFunction, typename ObservationType>
+double BinaryNumericSplit<FitnessFunction, ObservationType>::
+    EvaluateFitnessFunction() const
+{
+  typedef typename std::multimap<ObservationType, size_t>::const_iterator
+      ElementIterator;
+
+  // "No split" sentinel.  min() would be the smallest *positive* value for
+  // floating-point types, which would send all non-positive values left.
+  bestSplit = std::numeric_limits<ObservationType>::lowest();
+
+  // Column 0 holds class counts left of the candidate split, column 1 those
+  // right of it.  Initially every point is on the right.
+  arma::Mat<size_t> counts(classCounts.n_elem, 2);
+  counts.col(0).zeros();
+  counts.col(1) = classCounts;
+
+  double bestValue = FitnessFunction::Evaluate(counts);
+
+  for (ElementIterator it = sortedElements.begin();
+      it != sortedElements.end(); ++it)
+  {
+    // Move the point to the left side (column 0) of the split.
+    --counts(it->second, 1);
+    ++counts(it->second, 0);
+
+    // Only evaluate once every copy of the current value is on the left: a
+    // threshold cannot separate equal values.  After the last element all
+    // points are on the left, which is no split at all, so skip that too.
+    ElementIterator next = it;
+    ++next;
+    if (next == sortedElements.end() || !(it->first < next->first))
+      continue;
+
+    const double value = FitnessFunction::Evaluate(counts);
+    if (value > bestValue)
+    {
+      bestValue = value;
+      bestSplit = it->first;
+    }
+  }
+
+  isAccurate = true;
+  return bestValue;
+}
+
+/**
+ * Split on the best split point, recomputing it first if observations have
+ * arrived since the last EvaluateFitnessFunction() call.  childMajorities is
+ * resized to 2.  Element 0 is the majority class of points less than or equal
+ * to the split point; element 1 is the majority class of points greater than
+ * it.  splitInfo receives the split point.
+ */
+template<typename FitnessFunction, typename ObservationType>
+void BinaryNumericSplit<FitnessFunction, ObservationType>::Split(
+    arma::Col<size_t>& childMajorities,
+    SplitInfo& splitInfo) const
+{
+  if (!isAccurate)
+    EvaluateFitnessFunction();
+
+  // Make one child for each side of the split.
+  childMajorities.set_size(2);
+
+  // Column 0: class counts at or below the split point; column 1: above it.
+  arma::Mat<size_t> counts(classCounts.n_elem, 2);
+  counts.col(0).zeros();
+  counts.col(1) = classCounts;
+
+  // The map is sorted by value, so the left side is exactly
+  // [begin, upper_bound(bestSplit)).  This never walks past end(), even for
+  // an empty map or when no split was found.
+  typedef typename std::multimap<ObservationType, size_t>::const_iterator
+      ElementIterator;
+  const ElementIterator leftEnd = sortedElements.upper_bound(bestSplit);
+  for (ElementIterator it = sortedElements.begin(); it != leftEnd; ++it)
+  {
+    --counts(it->second, 1);
+    ++counts(it->second, 0);
+  }
+
+  // Calculate the majority classes of the children.
+  arma::uword maxIndex;
+  counts.unsafe_col(0).max(maxIndex);
+  childMajorities[0] = size_t(maxIndex);
+  counts.unsafe_col(1).max(maxIndex);
+  childMajorities[1] = size_t(maxIndex);
+
+  // Create the according SplitInfo object.
+  arma::vec splitPoints(1);
+  splitPoints[0] = double(bestSplit);
+  splitInfo = SplitInfo(splitPoints);
+}
+
+/**
+ * Return the index of the class with the largest observation count so far.
+ */
+template<typename FitnessFunction, typename ObservationType>
+size_t BinaryNumericSplit<FitnessFunction, ObservationType>::MajorityClass()
+    const
+{
+  arma::uword majority;
+  classCounts.max(majority);
+  return static_cast<size_t>(majority);
+}
+
+/**
+ * Return the fraction of all observations that belong to the majority class.
+ * NOTE(review): before any Train() call this divides zero by zero.
+ */
+template<typename FitnessFunction, typename ObservationType>
+double BinaryNumericSplit<FitnessFunction, ObservationType>::
+    MajorityProbability() const
+{
+  const double majorityCount = double(arma::max(classCounts));
+  const double totalCount = double(arma::accu(classCounts));
+  return majorityCount / totalCount;
+}
+
+/**
+ * Save or load the stored observations and class counts.  The cached best
+ * split is not stored.  When loading, it is marked stale so that the next
+ * Split() call recomputes it from the loaded data.
+ */
+template<typename FitnessFunction, typename ObservationType>
+template<typename Archive>
+void BinaryNumericSplit<FitnessFunction, ObservationType>::Serialize(
+    Archive& ar,
+    const unsigned int /* version */)
+{
+  // Serialize.
+  ar & data::CreateNVP(sortedElements, "sortedElements");
+  ar & data::CreateNVP(classCounts, "classCounts");
+
+  // Any cached split describes the pre-load data; force a recomputation.
+  if (Archive::is_loading::value)
+    isAccurate = false;
+}
+
+
+} // namespace tree
+} // namespace mlpack
+
+#endif



More information about the mlpack-git mailing list