[mlpack-git] master: Oops, did not apply changes to other overload. (9d8c552)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sun May 3 21:37:57 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/076156df78e26ba87012f2b5fbc6d45e84da918b...744b3268b46dbd04fc42b343d992bceda121bc11

>---------------------------------------------------------------

commit 9d8c55236df28ad3a8856f835cb4926e258dee7a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun May 3 21:37:37 2015 -0400

    Oops, did not apply changes to other overload.


>---------------------------------------------------------------

9d8c55236df28ad3a8856f835cb4926e258dee7a
 .../tree/binary_space_tree/mean_split_impl.hpp     | 54 ++++++++++++++++++++--
 1 file changed, 49 insertions(+), 5 deletions(-)

diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
index 1fe630b..9a6c041 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
@@ -70,12 +70,18 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
     splitVal += data(splitDimension, i);
   splitVal /= count;
 
+  Log::Assert(splitVal >= bound[splitDimension].Lo());
+  Log::Assert(splitVal <= bound[splitDimension].Hi());
+
   // Perform the actual splitting.  This will order the dataset such that points
   // with value in dimension splitDimension less than or equal to splitVal are
   // on the left of splitCol, and points with value in dimension splitDimension
   // greater than splitVal are on the right side of splitCol.
   splitCol = PerformSplit(data, begin, count, splitDimension, splitVal);
 
+  // Increment the tie-breaker counter.
+  ++nextDimension;
+
   return true;
 }
 
@@ -91,22 +97,54 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
   double maxWidth = -1;
 
   // Find the split dimension.
+  size_t ties = 0;
   for (size_t d = 0; d < data.n_rows; d++)
   {
-    double width = bound[d].Width();
+    const double width = bound[d].Width();
 
     if (width > maxWidth)
     {
       maxWidth = width;
       splitDimension = d;
     }
+    if (width == maxWidth)
+    {
+      // There's a tie.  Record that.
+      ++ties;
+    }
+  }
+
+  if (ties > 0)
+  {
+    // Look through a second time, and determine the correct dimension.
+    size_t tieIndex = 0;
+    for (size_t d = 0; d < data.n_rows; ++d)
+    {
+      const double width = bound[d].Width();
+
+      if (width == maxWidth)
+      {
+        if (tieIndex == (nextDimension % ties))
+        {
+          splitDimension = d;
+          break;
+        }
+        ++tieIndex;
+      }
+    }
   }
 
   if (maxWidth == 0) // All these points are the same.  We can't split.
     return false;
 
-  // Split in the middle of that dimension.
-  double splitVal = bound[splitDimension].Mid();
+  // Split in the mean of that dimension.
+  double splitVal = 0.0;
+  for (size_t i = begin; i < begin + count; ++i)
+    splitVal += data(splitDimension, i);
+  splitVal /= count;
+
+  Log::Assert(splitVal >= bound[splitDimension].Lo());
+  Log::Assert(splitVal <= bound[splitDimension].Hi());
 
   // Perform the actual splitting.  This will order the dataset such that points
   // with value in dimension splitDimension less than or equal to splitVal are
@@ -115,6 +153,9 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
   splitCol = PerformSplit(data, begin, count, splitDimension, splitVal,
       oldFromNew);
 
+  // Increment the tie-breaker counter.
+  ++nextDimension;
+
   return true;
 }
 
@@ -161,8 +202,8 @@ size_t MeanSplit<BoundType, MatType>::
 
   Log::Assert(left == right + 1);
 
-  if (left >= begin + count)
-    Log::Fatal << "Left is count. Bad.\n";
+  Log::Assert(left > begin);
+  Log::Assert(left < begin + count - 1);
 
   return left;
 }
@@ -216,6 +257,9 @@ size_t MeanSplit<BoundType, MatType>::
 
   Log::Assert(left == right + 1);
 
+  Log::Assert(left > begin);
+  Log::Assert(left < begin + count - 1);
+
   return left;
 }
 



More information about the mlpack-git mailing list