[mlpack-git] master: Pre-emptive prunes. Potentially a slowdown. Not always, though, I don't think. I may revert this. (b2fc22a)

gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:48 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit b2fc22aefd46992b4aa12339322e5a37d695c80b
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Feb 19 15:05:51 2015 -0500

    Pre-emptive prunes. Potentially a slowdown. Not always, though, I don't think. I may revert this.


>---------------------------------------------------------------

b2fc22aefd46992b4aa12339322e5a37d695c80b
 src/mlpack/methods/kmeans/dtnn_rules.hpp      |  10 +-
 src/mlpack/methods/kmeans/dtnn_rules_impl.hpp | 178 ++++++++++++++++++++++----
 2 files changed, 163 insertions(+), 25 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_rules.hpp b/src/mlpack/methods/kmeans/dtnn_rules.hpp
index 5252050..c2f7873 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules.hpp
@@ -9,7 +9,7 @@
 #ifndef __MLPACK_METHODS_KMEANS_DTNN_RULES_HPP
 #define __MLPACK_METHODS_KMEANS_DTNN_RULES_HPP
 
-#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+#include <mlpack/methods/neighbor_search/ns_traversal_info.hpp>
 
 namespace mlpack {
 namespace kmeans {
@@ -39,7 +39,7 @@ class DTNNKMeansRules
                  TreeType& referenceNode,
                  const double oldScore);
 
-  typedef int TraversalInfoType;
+  typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
 
   TraversalInfoType& TraversalInfo() { return traversalInfo; }
   const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
@@ -67,7 +67,11 @@ class DTNNKMeansRules
   size_t baseCases;
   size_t scores;
 
-  int traversalInfo;
+  TraversalInfoType traversalInfo;
+
+  size_t lastQueryIndex;
+  size_t lastReferenceIndex;
+  size_t lastBaseCase;
 };
 
 } // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index ca8943c..c87c1da 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -33,9 +33,15 @@ DTNNKMeansRules<MetricType, TreeType>::DTNNKMeansRules(
     oldFromNewCentroids(oldFromNewCentroids),
     visited(visited),
     baseCases(0),
-    scores(0)
+    scores(0),
+    lastQueryIndex(dataset.n_cols),
+    lastReferenceIndex(centroids.n_cols)
 {
-  // Nothing to do.
+  // We must set the traversal info last query and reference node pointers to
+  // something that is both invalid (i.e. not a tree node) and not NULL.  We'll
+  // use the this pointer.
+  traversalInfo.LastQueryNode() = (TreeType*) this;
+  traversalInfo.LastReferenceNode() = (TreeType*) this;
 }
 
 template<typename MetricType, typename TreeType>
@@ -46,6 +52,10 @@ inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
   if (prunedPoints[queryIndex])
     return 0.0; // Returning 0 shouldn't be a problem.
 
+  // If we have already performed this base case, then do not perform it again.
+  if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
+    return lastBaseCase;
+
   // Any base cases imply that we will get a result.
   visited[queryIndex] = true;
 
@@ -66,6 +76,11 @@ inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
     lowerBounds[queryIndex] = distance;
   }
 
+  // Cache this information for the next time BaseCase() is called.
+  lastQueryIndex = queryIndex;
+  lastReferenceIndex = referenceIndex;
+  lastBaseCase = distance;
+
   return distance;
 }
 
@@ -102,30 +117,144 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
   if (queryNode.Stat().Pruned() == centroids.n_cols)
     return DBL_MAX;
 
-  // Get minimum and maximum distances.
-  math::Range distances = queryNode.RangeDistance(&referenceNode);
-  double score = distances.Lo();
-  ++scores;
-  if (distances.Lo() > queryNode.Stat().UpperBound())
+  // This looks a lot like the hackery used in NeighborSearchRules to avoid
+  // distance computations.  We'll use the traversal info to see if a
+  // parent-child or parent-parent prune is possible.
+  const double queryParentDist = queryNode.ParentDistance();
+  const double queryDescDist = queryNode.FurthestDescendantDistance();
+  const double refParentDist = referenceNode.ParentDistance();
+  const double refDescDist = referenceNode.FurthestDescendantDistance();
+  const double lastScore = traversalInfo.LastScore();
+  double adjustedScore;
+  double score = 0.0;
+
+  // We want adjustedScore to be the distance between the centroids of the last
+  // query node and the last reference node.  We will do this by adjusting the
+  // last score.  In some cases, we can just use the last base case.
+  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+  {
+    adjustedScore = traversalInfo.LastBaseCase();
+  }
+  else if (lastScore == 0.0) // Nothing we can do here.
+  {
+    adjustedScore = 0.0;
+  }
+  else
   {
-    // The reference node can own no points in this query node.  We may improve
-    // the lower bound on pruned nodes, though.
-    if (distances.Lo() < queryNode.Stat().LowerBound())
-      queryNode.Stat().LowerBound() = distances.Lo();
+    // The last score is equal to the distance between the centroids minus the
+    // radii of the query and reference bounds along the axis of the line
+    // between the two centroids.  In the best case, these radii are the
+    // furthest descendant distances, but that is not always true.  It would
+    // take too long to calculate the exact radii, so we are forced to use
+    // MinimumBoundDistance() as a lower-bound approximation.
+    const double lastQueryDescDist =
+        traversalInfo.LastQueryNode()->MinimumBoundDistance();
+    const double lastRefDescDist =
+        traversalInfo.LastReferenceNode()->MinimumBoundDistance();
+    adjustedScore = lastScore + lastQueryDescDist;
+    adjustedScore += lastRefDescDist;
+  }
 
-    // This assumes that reference clusters don't appear elsewhere in the tree.
-    queryNode.Stat().Pruned() += referenceNode.NumDescendants();
-    score = DBL_MAX;
+  // Assemble an adjusted score.  For nearest neighbor search, this adjusted
+  // score is a lower bound on MinDistance(queryNode, referenceNode) that is
+  // assembled without actually calculating MinDistance().  For furthest
+  // neighbor search, it is an upper bound on
+  // MaxDistance(queryNode, referenceNode).  If the traversalInfo isn't usable
+  // then the node should not be pruned by this.
+  if (traversalInfo.LastQueryNode() == queryNode.Parent())
+  {
+    const double queryAdjust = queryParentDist + queryDescDist;
+    adjustedScore -= queryAdjust;
+  }
+  else if (traversalInfo.LastQueryNode() == &queryNode)
+  {
+    adjustedScore -= queryDescDist;
+  }
+  else
+  {
+    // The last query node wasn't this query node or its parent.  So we force
+    // the adjustedScore to be such that this combination can't be pruned here,
+    // because we don't really know anything about it.
+
+    // It would be possible to modify this section to try and make a prune based
+    // on the query descendant distance and the distance between the query node
+    // and last traversal query node, but this case doesn't actually happen for
+    // kd-trees or cover trees.
+    adjustedScore = 0.0;
+  }
+  if (traversalInfo.LastReferenceNode() == referenceNode.Parent())
+  {
+    const double refAdjust = refParentDist + refDescDist;
+    adjustedScore -= refAdjust;
+  }
+  else if (traversalInfo.LastReferenceNode() == &referenceNode)
+  {
+    adjustedScore -= refDescDist;
   }
-  else if (distances.Hi() < queryNode.Stat().UpperBound())
+  else
   {
-    // We can improve the best estimate.
-    queryNode.Stat().UpperBound() = distances.Hi();
-    // If this node has only one descendant, then it may be the owner.
-    if (referenceNode.NumDescendants() == 1)
-      queryNode.Stat().Owner() = (tree::TreeTraits<TreeType>::RearrangesDataset)
-          ? oldFromNewCentroids[referenceNode.Descendant(0)]
-          : referenceNode.Descendant(0);
+    // The last reference node wasn't this reference node or its parent.  So we
+    // force the adjustedScore to be such that this combination can't be pruned
+    // here, because we don't really know anything about it.
+
+    // It would be possible to modify this section to try and make a prune based
+    // on the reference descendant distance and the distance between the
+    // reference node and last traversal reference node, but this case doesn't
+    // actually happen for kd-trees or cover trees.
+    adjustedScore = 0.0;
+  }
+
+  // Now, check if we can prune.
+  //Log::Warn << "adjusted score: " << adjustedScore << ".\n";
+  if (adjustedScore > queryNode.Stat().UpperBound())
+  {
+//    Log::Warn << "Pre-emptive prune!\n";
+    if (!(tree::TreeTraits<TreeType>::FirstPointIsCentroid && score == 0.0))
+    {
+      // There isn't any need to set the traversal information because no
+      // descendant combinations will be visited, and those are the only
+      // combinations that would depend on the traversal information.
+      if (adjustedScore < queryNode.Stat().LowerBound())
+      {
+        // If this might affect the lower bound, make it more exact; only keep
+        // the exact distance if it actually tightens (lowers) the bound.
+        const double minDistance = queryNode.MinDistance(&referenceNode);
+        ++scores;
+        if (minDistance < queryNode.Stat().LowerBound())
+          queryNode.Stat().LowerBound() = minDistance;
+      }
+
+      queryNode.Stat().Pruned() += referenceNode.NumDescendants();
+      score = DBL_MAX;
+    }
+  }
+
+  if (score != DBL_MAX)
+  {
+    // Get minimum and maximum distances.
+    math::Range distances = queryNode.RangeDistance(&referenceNode);
+    score = distances.Lo();
+    ++scores;
+    if (distances.Lo() > queryNode.Stat().UpperBound())
+    {
+      // The reference node can own no points in this query node.  We may
+      // improve the lower bound on pruned nodes, though.
+      if (distances.Lo() < queryNode.Stat().LowerBound())
+        queryNode.Stat().LowerBound() = distances.Lo();
+
+      // This assumes that reference clusters don't appear elsewhere in the
+      // tree.
+      queryNode.Stat().Pruned() += referenceNode.NumDescendants();
+      score = DBL_MAX;
+    }
+    else if (distances.Hi() < queryNode.Stat().UpperBound())
+    {
+      // We can improve the best estimate.
+      queryNode.Stat().UpperBound() = distances.Hi();
+      // If this node has only one descendant, then it may be the owner.
+      if (referenceNode.NumDescendants() == 1)
+        queryNode.Stat().Owner() =
+            (tree::TreeTraits<TreeType>::RearrangesDataset) ?
+            oldFromNewCentroids[referenceNode.Descendant(0)] :
+            referenceNode.Descendant(0);
+    }
   }
 
   // Is everything pruned?
@@ -135,6 +264,11 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
     return DBL_MAX;
   }
 
+  // Set traversal information.
+  traversalInfo.LastQueryNode() = &queryNode;
+  traversalInfo.LastReferenceNode() = &referenceNode;
+  traversalInfo.LastScore() = score;
+
   return score;
 }
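A side note on the BaseCase() caching above: the traversal can ask for the same
(query point, reference point) distance twice in a row, presumably because for
trees whose first point is the node's centroid the same pair is visited by both
Score() and BaseCase(), so the rules cache the last result instead of
recomputing it.  Below is a minimal, self-contained sketch of just that caching
idea, using Armadillo for the distance; it is not mlpack's DTNNKMeansRules, and
the class and member names are hypothetical.

#include <armadillo>
#include <iostream>

// Toy rule set that only demonstrates the "repeat the last base case" cache.
class CachedBaseCase
{
 public:
  CachedBaseCase(const arma::mat& queries, const arma::mat& references) :
      queries(queries),
      references(references),
      // Initialize the cached indices to values no real point can have, so the
      // cache cannot be hit before it is filled.
      lastQueryIndex(queries.n_cols),
      lastReferenceIndex(references.n_cols),
      lastDistance(0.0),
      distanceCalculations(0) { }

  double BaseCase(const size_t queryIndex, const size_t referenceIndex)
  {
    // If this is the same combination as last time, reuse the cached result.
    if (queryIndex == lastQueryIndex && referenceIndex == lastReferenceIndex)
      return lastDistance;

    const double distance = arma::norm(queries.col(queryIndex) -
        references.col(referenceIndex), 2);
    ++distanceCalculations;

    // Cache this result for the next call.
    lastQueryIndex = queryIndex;
    lastReferenceIndex = referenceIndex;
    lastDistance = distance;
    return distance;
  }

  size_t DistanceCalculations() const { return distanceCalculations; }

 private:
  const arma::mat& queries;
  const arma::mat& references;
  size_t lastQueryIndex;
  size_t lastReferenceIndex;
  double lastDistance;
  size_t distanceCalculations;
};

int main()
{
  arma::mat q(3, 5, arma::fill::randu);
  arma::mat r(3, 4, arma::fill::randu);
  CachedBaseCase rule(q, r);
  rule.BaseCase(2, 1);
  rule.BaseCase(2, 1); // Served from the cache; no second norm computation.
  std::cout << "distance calculations: " << rule.DistanceCalculations() << "\n";
  return 0;
}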
 

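The pre-emptive prune itself rests on a triangle-inequality argument: if
adjustedScore starts as (a lower bound on) the distance between the centroids
of the last query node and the last reference node, then subtracting each
current node's parent distance and furthest descendant distance leaves a lower
bound on MinDistance(queryNode, referenceNode), so the combination can be
discarded without ever calling RangeDistance().  The snippet below only shows
that arithmetic for the parent-parent case and under those assumptions; the
function name and parameters are hypothetical, not part of the commit.

// Lower bound on MinDistance(Q, R) reconstructed from the distance between
// the centroids of the previously visited (parent) nodes lastQ and lastR:
//
//   MinDist(Q, R) >= d(centroid(lastQ), centroid(lastR))
//                    - (ParentDistance(Q) + FurthestDescendantDistance(Q))
//                    - (ParentDistance(R) + FurthestDescendantDistance(R)).
//
// If this bound already exceeds the query node's upper bound on its nearest
// centroid distance, the node combination can be pruned pre-emptively.
double PreemptivePruneBound(const double lastCentroidDistance,
                            const double queryParentDist,
                            const double queryDescDist,
                            const double refParentDist,
                            const double refDescDist)
{
  double adjustedScore = lastCentroidDistance;
  adjustedScore -= (queryParentDist + queryDescDist);
  adjustedScore -= (refParentDist + refDescDist);
  return adjustedScore; // Prune if this exceeds queryNode.Stat().UpperBound().
}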

