[mlpack-git] master: Refactor to use BoundTraits. Remove nextDimension. (b9d5716)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon May 4 11:13:40 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/744b3268b46dbd04fc42b343d992bceda121bc11...b9d571606ded7e6261682ce4eddca40aa3015cc3

>---------------------------------------------------------------

commit b9d571606ded7e6261682ce4eddca40aa3015cc3
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon May 4 11:12:39 2015 -0400

    Refactor to use BoundTraits.  Remove nextDimension.
    
    The extra overhead of tracking the number of tied dimensions is overkill, since ties will happen so infrequently in practice.  Ties used to happen with the BallBound, where all dimensions are always the same, but the BoundTraits solution is better.


>---------------------------------------------------------------

b9d571606ded7e6261682ce4eddca40aa3015cc3
 .../core/tree/binary_space_tree/mean_split.hpp     |  54 ++++-----
 .../tree/binary_space_tree/mean_split_impl.hpp     | 126 +++++++++++----------
 2 files changed, 86 insertions(+), 94 deletions(-)

diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
index f5df981..219cc8c 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
@@ -24,13 +24,6 @@ class MeanSplit
 {
  public:
   /**
-   * Instantiate a MeanSplit object.  This will initialize the nextDimension
-   * object, which will determine which dimension to use next, in the instance
-   * of a tie.
-   */
-  MeanSplit() : nextDimension(0) { /* Nothing to do. */ }
-
-  /**
    * Split the node according to the mean value in the dimension with maximum
    * width.
    *
@@ -44,11 +37,11 @@ class MeanSplit
    * @param splitCol The index at which the dataset is divided into two parts
    *    after the rearrangement.
    */
-  bool SplitNode(const BoundType& bound,
-                 MatType& data,
-                 const size_t begin,
-                 const size_t count,
-                 size_t& splitCol);
+  static bool SplitNode(const BoundType& bound,
+                        MatType& data,
+                        const size_t begin,
+                        const size_t count,
+                        size_t& splitCol);
 
   /**
    * Split the node according to the mean value in the dimension with maximum
@@ -66,12 +59,12 @@ class MeanSplit
    * @param oldFromNew Vector which will be filled with the old positions for
    *    each new point.
    */
-  bool SplitNode(const BoundType& bound,
-                 MatType& data,
-                 const size_t begin,
-                 const size_t count,
-                 size_t& splitCol,
-                 std::vector<size_t>& oldFromNew);
+  static bool SplitNode(const BoundType& bound,
+                        MatType& data,
+                        const size_t begin,
+                        const size_t count,
+                        size_t& splitCol,
+                        std::vector<size_t>& oldFromNew);
 
  private:
   /**
@@ -86,11 +79,11 @@ class MeanSplit
    * @param splitVal The split in dimension splitDimension is based on this
    *    value.
    */
-  size_t PerformSplit(MatType& data,
-                      const size_t begin,
-                      const size_t count,
-                      const size_t splitDimension,
-                      const double splitVal);
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const size_t splitDimension,
+                             const double splitVal);
 
   /**
    * Reorder the dataset into two parts such that they lie on either side of
@@ -106,15 +99,12 @@ class MeanSplit
    * @param oldFromNew Vector which will be filled with the old positions for
    *    each new point.
    */
-  size_t PerformSplit(MatType& data,
-                      const size_t begin,
-                      const size_t count,
-                      const size_t splitDimension,
-                      const double splitVal,
-                      std::vector<size_t>& oldFromNew);
-
-  //! Tracks the next dimension to use in case of a tie.
-  size_t nextDimension;
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const size_t splitDimension,
+                             const double splitVal,
+                             std::vector<size_t>& oldFromNew);
 };
 
 }; // namespace tree
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
index 9a6c041..3d72938 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
@@ -23,40 +23,47 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
   size_t splitDimension = data.n_rows; // Indicate invalid.
   double maxWidth = -1;
 
-  // Find the split dimension.
-  size_t ties = 0;
-  for (size_t d = 0; d < data.n_rows; d++)
+  // Find the split dimension.  If the bound is tight, we only need to consult
+  // the bound's width.
+  if (bound::BoundTraits<BoundType>::HasTightBounds)
   {
-    const double width = bound[d].Width();
-
-    if (width > maxWidth)
+    for (size_t d = 0; d < data.n_rows; d++)
     {
-      maxWidth = width;
-      splitDimension = d;
+      const double width = bound[d].Width();
+  
+      if (width > maxWidth)
+      {
+        maxWidth = width;
+        splitDimension = d;
+      }
     }
-    if (width == maxWidth)
+  }
+  else
+  {
+    // We must individually calculate bounding boxes.
+    math::Range* ranges = new math::Range[data.n_rows];
+    for (size_t i = begin; i < begin + count; ++i)
     {
-      // There's a tie.  Record that.
-      ++ties;
+      // Expand each dimension as necessary.
+      for (size_t d = 0; d < data.n_rows; ++d)
+      {
+        const double val = data(d, i);
+        if (val < ranges[d].Lo())
+          ranges[d].Lo() = val;
+        if (val > ranges[d].Hi())
+          ranges[d].Hi() = val;
+      }
     }
-  }
 
-  if (ties > 0)
-  {
-    // Look through a second time, and determine the correct dimension.
-    size_t tieIndex = 0;
-    for (size_t d = 0; d < data.n_rows; ++d)
+    // Now, which is the widest?
+    for (size_t d = 0; d < data.n_rows; d++)
     {
       const double width = bound[d].Width();
 
-      if (width == maxWidth)
+      if (width > maxWidth)
       {
-        if (tieIndex == (nextDimension % ties))
-        {
-          splitDimension = d;
-          break;
-        }
-        ++tieIndex;
+        maxWidth = width;
+        splitDimension = d;
       }
     }
   }
@@ -79,9 +86,6 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
   // greater than splitVal are on the right side of splitCol.
   splitCol = PerformSplit(data, begin, count, splitDimension, splitVal);
 
-  // Increment the tie-breaker counter.
-  ++nextDimension;
-
   return true;
 }
 
@@ -96,40 +100,47 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
   size_t splitDimension = data.n_rows; // Indicate invalid.
   double maxWidth = -1;
 
-  // Find the split dimension.
-  size_t ties = 0;
-  for (size_t d = 0; d < data.n_rows; d++)
+  // Find the split dimension.  If the bound is tight, we only need to consult
+  // the bound's width.
+  if (bound::BoundTraits<BoundType>::HasTightBounds)
   {
-    const double width = bound[d].Width();
-
-    if (width > maxWidth)
+    for (size_t d = 0; d < data.n_rows; d++)
     {
-      maxWidth = width;
-      splitDimension = d;
+      const double width = bound[d].Width();
+
+      if (width > maxWidth)
+      {
+        maxWidth = width;
+        splitDimension = d;
+      }
     }
-    if (width == maxWidth)
+  }
+  else
+  {
+    // We must individually calculate bounding boxes.
+    math::Range* ranges = new math::Range[data.n_rows];
+    for (size_t i = begin; i < begin + count; ++i)
     {
-      // There's a tie.  Record that.
-      ++ties;
+      // Expand each dimension as necessary.
+      for (size_t d = 0; d < data.n_rows; ++d)
+      {
+        const double val = data(d, i);
+        if (val < ranges[d].Lo())
+          ranges[d].Lo() = val;
+        if (val > ranges[d].Hi())
+          ranges[d].Hi() = val;
+      }
     }
-  }
 
-  if (ties > 0)
-  {
-    // Look through a second time, and determine the correct dimension.
-    size_t tieIndex = 0;
-    for (size_t d = 0; d < data.n_rows; ++d)
+    // Now, which is the widest?
+    for (size_t d = 0; d < data.n_rows; d++)
     {
       const double width = bound[d].Width();
 
-      if (width == maxWidth)
+      if (width > maxWidth)
       {
-        if (tieIndex == (nextDimension % ties))
-        {
-          splitDimension = d;
-          break;
-        }
-        ++tieIndex;
+        maxWidth = width;
+        splitDimension = d;
       }
     }
   }
@@ -153,9 +164,6 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
   splitCol = PerformSplit(data, begin, count, splitDimension, splitVal,
       oldFromNew);
 
-  // Increment the tie-breaker counter.
-  ++nextDimension;
-
   return true;
 }
 
@@ -178,7 +186,7 @@ size_t MeanSplit<BoundType, MatType>::
   // condition is in the middle.
   while ((data(splitDimension, left) < splitVal) && (left <= right))
     left++;
-  while ((data(splitDimension, right) >= splitVal) && (left <= right))
+  while ((data(splitDimension, right) >= splitVal) && (left <= right) && (right > 0))
     right--;
 
   while (left <= right)
@@ -202,9 +210,6 @@ size_t MeanSplit<BoundType, MatType>::
 
   Log::Assert(left == right + 1);
 
-  Log::Assert(left > begin);
-  Log::Assert(left < begin + count - 1);
-
   return left;
 }
 
@@ -228,7 +233,7 @@ size_t MeanSplit<BoundType, MatType>::
   // condition is in the middle.
   while ((data(splitDimension, left) < splitVal) && (left <= right))
     left++;
-  while ((data(splitDimension, right) >= splitVal) && (left <= right))
+  while ((data(splitDimension, right) >= splitVal) && (left <= right) && (right > 0))
     right--;
 
   while (left <= right)
@@ -257,9 +262,6 @@ size_t MeanSplit<BoundType, MatType>::
 
   Log::Assert(left == right + 1);
 
-  Log::Assert(left > begin);
-  Log::Assert(left < begin + count - 1);
-
   return left;
 }
 



More information about the mlpack-git mailing list