[mlpack-git] master: Clarify where splitting actually occurs; fix bugs. (9a4fb1d)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:46:01 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 9a4fb1d85d8997487f5dd8c94cdae2215d1e82e8
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Nov 5 14:34:18 2015 -0800

    Clarify where splitting actually occurs; fix bugs.


>---------------------------------------------------------------

9a4fb1d85d8997487f5dd8c94cdae2215d1e82e8
 .../hoeffding_trees/binary_numeric_split.hpp       |  4 +-
 .../hoeffding_trees/binary_numeric_split_impl.hpp  | 45 +++++++++++++++-------
 2 files changed, 34 insertions(+), 15 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
index 07f8d7b..f83c95e 100644
--- a/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split.hpp
@@ -8,6 +8,8 @@
 #ifndef __MLPACK_METHODS_HOEFFDING_SPLIT_BINARY_NUMERIC_SPLIT_HPP
 #define __MLPACK_METHODS_HOEFFDING_SPLIT_BINARY_NUMERIC_SPLIT_HPP
 
+#include "binary_numeric_split_info.hpp"
+
 namespace mlpack {
 namespace tree {
 
@@ -41,7 +43,7 @@ class BinaryNumericSplit
 {
  public:
   //! The splitting information required by the BinaryNumericSplit.
-  typedef NumericSplitInfo<ObservationType> SplitInfo;
+  typedef BinaryNumericSplitInfo<ObservationType> SplitInfo;
 
   /**
    * Create the BinaryNumericSplit object with the given number of classes.
diff --git a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
index f3e587d..468bfdd 100644
--- a/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/binary_numeric_split_impl.hpp
@@ -51,20 +51,29 @@ double BinaryNumericSplit<FitnessFunction, ObservationType>::
 
   double bestValue = FitnessFunction::Evaluate(counts);
 
+  // Initialize to the first observation, so we don't calculate gain on the
+  // first iteration (it will be 0).
+  ObservationType lastObservation = (*sortedElements.begin()).first;
   for (typename std::multimap<ObservationType, size_t>::const_iterator it =
       sortedElements.begin(); it != sortedElements.end(); ++it)
   {
+    // If this value is the same as the last, or if this is the first value,
+    // don't calculate the gain.
+    if ((*it).first != lastObservation)
+    {
+      lastObservation = (*it).first;
+
+      const double value = FitnessFunction::Evaluate(counts);
+      if (value > bestValue)
+      {
+        bestValue = value;
+        bestSplit = (*it).first;
+      }
+    }
+
     // Move the point to the right side of the split.
     --counts((*it).second, 1);
     ++counts((*it).second, 0);
-
-    // TODO: skip ahead if the next value is the same.
-    const double value = FitnessFunction::Evaluate(counts);
-    if (value > bestValue)
-    {
-      bestValue = value;
-      bestSplit = (*it).first;
-    }
   }
 
   isAccurate = true;
@@ -86,12 +95,22 @@ void BinaryNumericSplit<FitnessFunction, ObservationType>::Split(
   counts.col(0).zeros();
   counts.col(1) = classCounts;
 
+  double min = DBL_MAX;
+  double max = -DBL_MAX;
   for (typename std::multimap<ObservationType, size_t>::const_iterator it =
-      sortedElements.begin(); (*it).first < bestSplit; ++it)
+      sortedElements.begin();// (*it).first < bestSplit; ++it)
+      it != sortedElements.end(); ++it)
   {
     // Move the point to the correct side of the split.
-    --counts((*it).second, 1);
-    ++counts((*it).second, 0);
+    if ((*it).first < bestSplit)
+    {
+      --counts((*it).second, 1);
+      ++counts((*it).second, 0);
+    }
+    if ((*it).first < min)
+      min = (*it).first;
+    if ((*it).first > max)
+      max = (*it).first;
   }
 
   // Calculate the majority classes of the children.
@@ -102,9 +121,7 @@ void BinaryNumericSplit<FitnessFunction, ObservationType>::Split(
   childMajorities[1] = size_t(maxIndex);
 
   // Create the according SplitInfo object.
-  arma::vec splitPoints(1);
-  splitPoints[0] = double(bestSplit);
-  splitInfo = SplitInfo(splitPoints);
+  splitInfo = SplitInfo(bestSplit);
 }
 
 template<typename FitnessFunction, typename ObservationType>



More information about the mlpack-git mailing list