[mlpack-git] master: Improve expansion of node's bound. Do not include the overlapping buffer in the bounds of both nodes. (8b2dc10)

gitdub at mlpack.org gitdub at mlpack.org
Thu Aug 18 13:39:59 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit 8b2dc102fcf32cc645b6a257be5211117bb14dcf
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Thu Jul 7 12:40:32 2016 -0300

    Improve expansion of node's bound. Do not include the overlapping buffer in the bounds of both nodes.


>---------------------------------------------------------------

8b2dc102fcf32cc645b6a257be5211117bb14dcf
 src/mlpack/core/tree/spill_tree/spill_tree.hpp     | 16 +++++++++++-
 .../core/tree/spill_tree/spill_tree_impl.hpp       | 29 +++++++++++++++-------
 2 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/src/mlpack/core/tree/spill_tree/spill_tree.hpp b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
index fd7b227..db7f6c0 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
@@ -158,12 +158,14 @@ class SpillTree
    *
    * @param parent Parent of this node.  Its dataset will be modified!
    * @param points Vector of indexes of points to be included in this node.
+   * @param overlapIndex Index where the list of overlapping points starts.
    * @param tau Overlapping size.
    * @param maxLeafSize Size of each leaf in the tree.
    * @param rho Balance threshold.
    */
   SpillTree(SpillTree* parent,
             std::vector<size_t>& points,
+            const size_t overlapIndex,
             const double tau,
             const size_t maxLeafSize = 20,
             const double rho = 0.7);
@@ -361,11 +363,13 @@ class SpillTree
    * Splits the current node, assigning its left and right children recursively.
    *
    * @param points Vector of indexes of points to be included in this node.
+   * @param overlapIndex Index where the list of overlapping points starts.
    * @param maxLeafSize Maximum number of points held in a leaf.
    * @param tau Overlapping size.
    * @param rho Balance threshold.
    */
   void SplitNode(std::vector<size_t>& points,
+                 const size_t overlapIndex,
                  const size_t maxLeafSize,
                  const double tau,
                  const double rho);
@@ -380,6 +384,14 @@ class SpillTree
    * @param points Vector of indexes of points to be included.
    * @param leftPoints Indexes of points to be included in left child.
    * @param rightPoints Indexes of points to be included in right child.
+   * @param overlapIndexLeft Index in leftPoints where the list of overlapping
+         points starts ( [overlapIndexLeft, leftPoints.size()) represents the
+         indexes of the points from the right node that are included in the left
+         node).
+   * @param overlapIndexRight Index in rightPoints where the list of overlapping
+         points starts ( [overlapIndexRight, rightPoints.size()) represents the
+         indexes of the points from the left node that are included in the right
+         node).
    * @return Flag to know if the overlapping buffer was included.
    */
   bool SplitPoints(const size_t splitDimension,
@@ -388,7 +400,9 @@ class SpillTree
                    const double rho,
                    const std::vector<size_t>& points,
                    std::vector<size_t>& leftPoints,
-                   std::vector<size_t>& rightPoints);
+                   std::vector<size_t>& rightPoints,
+                   size_t& overlapIndexLeft,
+                   size_t& overlapIndexRight);
  protected:
   /**
    * A default constructor.  This is meant to only be used with
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
index 6ce80bd..25c83aa 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
@@ -43,7 +43,7 @@ SpillTree(
     points.push_back(i);
 
   // Do the actual splitting of this node.
-  SplitNode(points, maxLeafSize, tau, rho);
+  SplitNode(points, points.size(), maxLeafSize, tau, rho);
 
   // Create the statistic depending on if we are a leaf or not.
   stat = StatisticType(*this);
@@ -77,7 +77,7 @@ SpillTree(
     points.push_back(i);
 
   // Do the actual splitting of this node.
-  SplitNode(points, maxLeafSize, tau, rho);
+  SplitNode(points, points.size(), maxLeafSize, tau, rho);
 
   // Create the statistic depending on if we are a leaf or not.
   stat = StatisticType(*this);
@@ -93,6 +93,7 @@ SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
 SpillTree(
     SpillTree* parent,
     std::vector<size_t>& points,
+    const size_t overlapIndex,
     const double tau,
     const size_t maxLeafSize,
     const double rho) :
@@ -106,7 +107,7 @@ SpillTree(
     dataset(&parent->Dataset()) // Point to the parent's dataset.
 {
   // Perform the actual splitting.
-  SplitNode(points, maxLeafSize, tau, rho);
+  SplitNode(points, overlapIndex, maxLeafSize, tau, rho);
 
   // Create the statistic depending on if we are a leaf or not.
   stat = StatisticType(*this);
@@ -446,12 +447,14 @@ template<typename MetricType,
              class SplitType>
 void SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
     SplitNode(std::vector<size_t>& points,
+              const size_t overlapIndex,
               const size_t maxLeafSize,
               const double tau,
               const double rho)
 {
-  // We need to expand the bounds of this node properly.
-  for(size_t i = 0; i < points.size(); i++)
+  // We need to expand the bounds of this node properly, ignoring overlapping
+  // points (they will be included in the bound of the other node).
+  for(size_t i = 0; i < overlapIndex; i++)
     bound |= dataset->cols(points[i], points[i]);
 
   // Calculate the furthest descendant distance.
@@ -481,17 +484,20 @@ void SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
   }
 
   std::vector<size_t> leftPoints, rightPoints;
+  size_t overlapIndexLeft, overlapIndexRight;
   // Split the node.
   overlappingNode = SplitPoints(splitDimension, splitVal, tau, rho, points,
-      leftPoints, rightPoints);
+      leftPoints, rightPoints, overlapIndexLeft, overlapIndexRight);
 
   // We don't need the information in points, so lets clean it.
   std::vector<size_t>().swap(points);
 
   // Now we will recursively split the children by calling their constructors
   // (which perform this splitting process).
-  left = new SpillTree(this, leftPoints, maxLeafSize, tau, rho);
-  right = new SpillTree(this, rightPoints, maxLeafSize, tau, rho);
+  left = new SpillTree(this, leftPoints, overlapIndexLeft, maxLeafSize, tau,
+      rho);
+  right = new SpillTree(this, rightPoints, overlapIndexRight, maxLeafSize, tau,
+      rho);
 
   // Update count number, to represent the number of descendant points.
   count = left->NumDescendants() + right->NumDescendants();
@@ -523,7 +529,9 @@ bool SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
                 const double rho,
                 const std::vector<size_t>& points,
                 std::vector<size_t>& leftPoints,
-                std::vector<size_t>& rightPoints)
+                std::vector<size_t>& rightPoints,
+                size_t& overlapIndexLeft,
+                size_t& overlapIndexRight)
 {
   std::vector<size_t> leftFrontier, rightFrontier;
 
@@ -550,6 +558,9 @@ bool SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
   const double p2 = double (rightPoints.size() + leftFrontier.size()) /
       points.size();
 
+  overlapIndexLeft = leftPoints.size();
+  overlapIndexRight = rightPoints.size();
+
   if (p1 <= rho && p2 <= rho)
   {
     leftPoints.insert(leftPoints.end(), rightFrontier.begin(),




More information about the mlpack-git mailing list