[mlpack-git] master: Make sure the tree doesn't split on the same dimension over and over again when there's a tie. (77893ce)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sun May 3 19:34:30 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/174d2de995a3fe343cd92d158730f3afa03e622d...076156df78e26ba87012f2b5fbc6d45e84da918b

>---------------------------------------------------------------

commit 77893ce4a53079bc9ef3e7ebb0b669b8527b9c86
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun May 3 19:12:00 2015 -0400

    Make sure the tree doesn't split on the same dimension over and over again when there's a tie.


>---------------------------------------------------------------

77893ce4a53079bc9ef3e7ebb0b669b8527b9c86
 .../core/tree/binary_space_tree/mean_split.hpp     | 54 +++++++++++++---------
 .../tree/binary_space_tree/mean_split_impl.hpp     | 29 ++++++++++++
 2 files changed, 61 insertions(+), 22 deletions(-)

diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
index 219cc8c..f5df981 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
@@ -24,6 +24,13 @@ class MeanSplit
 {
  public:
   /**
+   * Instantiate a MeanSplit object.  This initializes the nextDimension
+   * member, which determines the dimension to split on next in the event
+   * of a tie.
+   */
+  MeanSplit() : nextDimension(0) { /* Nothing to do. */ }
+
+  /**
    * Split the node according to the mean value in the dimension with maximum
    * width.
    *
@@ -37,11 +44,11 @@ class MeanSplit
    * @param splitCol The index at which the dataset is divided into two parts
    *    after the rearrangement.
    */
-  static bool SplitNode(const BoundType& bound,
-                        MatType& data,
-                        const size_t begin,
-                        const size_t count,
-                        size_t& splitCol);
+  bool SplitNode(const BoundType& bound,
+                 MatType& data,
+                 const size_t begin,
+                 const size_t count,
+                 size_t& splitCol);
 
   /**
    * Split the node according to the mean value in the dimension with maximum
@@ -59,12 +66,12 @@ class MeanSplit
    * @param oldFromNew Vector which will be filled with the old positions for
    *    each new point.
    */
-  static bool SplitNode(const BoundType& bound,
-                        MatType& data,
-                        const size_t begin,
-                        const size_t count,
-                        size_t& splitCol,
-                        std::vector<size_t>& oldFromNew);
+  bool SplitNode(const BoundType& bound,
+                 MatType& data,
+                 const size_t begin,
+                 const size_t count,
+                 size_t& splitCol,
+                 std::vector<size_t>& oldFromNew);
 
  private:
   /**
@@ -79,11 +86,11 @@ class MeanSplit
    * @param splitVal The split in dimension splitDimension is based on this
    *    value.
    */
-  static size_t PerformSplit(MatType& data,
-                             const size_t begin,
-                             const size_t count,
-                             const size_t splitDimension,
-                             const double splitVal);
+  size_t PerformSplit(MatType& data,
+                      const size_t begin,
+                      const size_t count,
+                      const size_t splitDimension,
+                      const double splitVal);
 
   /**
    * Reorder the dataset into two parts such that they lie on either side of
@@ -99,12 +106,15 @@ class MeanSplit
    * @param oldFromNew Vector which will be filled with the old positions for
    *    each new point.
    */
-  static size_t PerformSplit(MatType& data,
-                             const size_t begin,
-                             const size_t count,
-                             const size_t splitDimension,
-                             const double splitVal,
-                             std::vector<size_t>& oldFromNew);
+  size_t PerformSplit(MatType& data,
+                      const size_t begin,
+                      const size_t count,
+                      const size_t splitDimension,
+                      const double splitVal,
+                      std::vector<size_t>& oldFromNew);
+
+  //! Tracks the next dimension to use in case of a tie.
+  size_t nextDimension;
 };
 
 }; // namespace tree
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
index f9d07fe..73edf6f 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
@@ -24,6 +24,7 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
   double maxWidth = -1;
 
   // Find the split dimension.
+  size_t ties = 0;
   for (size_t d = 0; d < data.n_rows; d++)
   {
     double width = bound[d].Width();
@@ -33,6 +34,31 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
       maxWidth = width;
       splitDimension = d;
     }
+    if (width == maxWidth)
+    {
+      // Count a tie (also true for a maximum that was just set above).
+      ++ties;
+    }
+  }
+
+  if (ties > 0)
+  {
+    // Look through a second time, and determine the correct dimension.
+    size_t tieIndex = 0;
+    for (size_t d = 0; d < data.n_rows; ++d)
+    {
+      const double width = bound[d].Width();
+
+      if (width == maxWidth)
+      {
+        if (tieIndex == (nextDimension % ties))
+        {
+          splitDimension = d;
+          break;
+        }
+        ++tieIndex;
+      }
+    }
   }
 
   if (maxWidth == 0) // All these points are the same.  We can't split.
@@ -132,6 +158,9 @@ size_t MeanSplit<BoundType, MatType>::
 
   Log::Assert(left == right + 1);
 
+  if (left >= begin + count)
+    Log::Fatal << "Left is count. Bad.\n";
+
   return left;
 }
 



More information about the mlpack-git mailing list