[mlpack-git] master, mlpack-1.0.x: Overhaul NeighborSearchRules to work correctly with TraversalInfo objects. This is related to #243. (e57ad28)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:42:25 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit e57ad282e684ae2802faedaa0e7ca13ffc1d3d61
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Feb 6 20:17:50 2014 +0000

    Overhaul NeighborSearchRules to work correctly with TraversalInfo objects.  This
    is related to #243.


>---------------------------------------------------------------

e57ad282e684ae2802faedaa0e7ca13ffc1d3d61
 .../neighbor_search/neighbor_search_impl.hpp       |   6 +-
 .../neighbor_search/neighbor_search_rules.hpp      |  29 +++
 .../neighbor_search/neighbor_search_rules_impl.hpp | 235 ++++++++++++++-------
 3 files changed, 186 insertions(+), 84 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 0ff34fc..0905d81 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -184,6 +184,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
 
   // Set the size of the neighbor and distance matrices.
   neighborPtr->set_size(k, querySet.n_cols);
+  neighborPtr->fill(size_t() - 1); // Presumably marks slots with no neighbor.
   distancePtr->set_size(k, querySet.n_cols);
   distancePtr->fill(SortPolicy::WorstDistance());
   
@@ -210,9 +211,8 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
 
     traverser.Traverse(*queryTree, *referenceTree);
 
-    Log::Info << traverser.NumVisited() << " node combinations were visited.\n";
-    Log::Info << traverser.NumScores() << " node combinations were scored.\n";
-    Log::Info << traverser.NumBaseCases() << " base cases were calculated.\n";
+    Log::Info << rules.Scores() << " node combinations were scored.\n";
+    Log::Info << rules.BaseCases() << " base cases were calculated.\n";
   }
 
   Timer::Stop("computing_neighbors");
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index 6d24a75..3f46b0d 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -8,6 +8,8 @@
 #ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
 #define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
 
+#include "ns_traversal_info.hpp"
+
 namespace mlpack {
 namespace neighbor {
 
@@ -73,6 +75,24 @@ class NeighborSearchRules
                  TreeType& referenceNode,
                  const double oldScore) const;
 
+  //! Get the number of base cases that have been performed.
+  size_t BaseCases() const { return baseCases; }
+  //! Modify the number of base cases that have been performed.
+  size_t& BaseCases() { return baseCases; }
+
+  //! Get the number of Score() calls that have been performed.
+  size_t Scores() const { return scores; }
+  //! Modify the number of Score() calls that have been performed.
+  size_t& Scores() { return scores; }
+
+  //! Convenience typedef.
+  typedef NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+  //! Get the traversal info.
+  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+  //! Modify the traversal info.
+  TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
  private:
   //! The reference set.
   const arma::mat& referenceSet;
@@ -96,6 +116,15 @@ class NeighborSearchRules
   //! The last base case result.
   double lastBaseCase;
 
+  //! The number of base cases that have been performed.
+  size_t baseCases;
+  //! The number of Score() calls that have been performed.
+  size_t scores;
+
+  //! Traversal info of the last unpruned node combination (set by Score());
+  //! the traversal can read or overwrite it through TraversalInfo().
+  TraversalInfoType traversalInfo;
+
   /**
    * Recalculate the bound for a given query node.
    */
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index fe103ec..031965c 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -26,8 +26,16 @@ NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
     distances(distances),
     metric(metric),
     lastQueryIndex(querySet.n_cols),
-    lastReferenceIndex(referenceSet.n_cols)
-{ /* Nothing left to do. */ }
+    lastReferenceIndex(referenceSet.n_cols),
+    baseCases(0),
+    scores(0)
+{
+  // `this` is a non-NULL sentinel that is not a tree node, so the Parent()
+  // comparisons in Score() cannot match it.  NOTE(review): Score() also
+  // dereferences these in some branches; confirm they are reset before that.
+  traversalInfo.LastQueryNode() = (TreeType*) this;
+  traversalInfo.LastReferenceNode() = (TreeType*) this;
+}
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
 inline force_inline // Absolutely MUST be inline so optimizations can happen.
@@ -45,11 +53,14 @@ BaseCase(const size_t queryIndex, const size_t referenceIndex)
 
   double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
                                     referenceSet.unsafe_col(referenceIndex));
+  ++baseCases; // Count this metric evaluation.
 
   // If this distance is better than any of the current candidates, the
   // SortDistance() function will give us the position to insert it into.
   arma::vec queryDist = distances.unsafe_col(queryIndex);
-  const size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
+  arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
+  const size_t insertPosition = SortPolicy::SortDistance(queryDist,
+      queryIndices, distance);
 
   // SortDistance() returns (size_t() - 1) if we shouldn't add it.
   if (insertPosition != (size_t() - 1))
@@ -68,6 +79,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
     const size_t queryIndex,
     TreeType& referenceNode)
 {
+  ++scores; // Count number of Score() calls.
   double distance;
   if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
   {
@@ -124,97 +136,128 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
     TreeType& queryNode,
     TreeType& referenceNode)
 {
+  ++scores; // Count number of Score() calls.
+
+  // Update our bound.
+  const double bestDistance = CalculateBound(queryNode);
+
+  // Use the traversal info to see if a parent-child or parent-parent prune is
+  // possible.  This is a looser bound than we could make, but it might be
+  // sufficient.
+  const double queryParentDist = queryNode.ParentDistance();
+  const double queryDescDist = queryNode.FurthestDescendantDistance();
+  const double refParentDist = referenceNode.ParentDistance();
+  const double refDescDist = referenceNode.FurthestDescendantDistance();
+  const double score = traversalInfo.LastScore();
+  double adjustedScore;
+
+  // In some cases we can just use the last base case.
+  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+  {
+    adjustedScore = traversalInfo.LastBaseCase();
+  }
+  else if (score == 0.0) // Nothing we can do here.
+  {
+    adjustedScore = 0.0;
+  }
+  else // Combine the last score with both last nodes' descendant distances.
+  {
+    const double lastQueryDescDist =
+        traversalInfo.LastQueryNode()->FurthestDescendantDistance();
+    const double lastRefDescDist =
+        traversalInfo.LastReferenceNode()->FurthestDescendantDistance();
+    adjustedScore = SortPolicy::CombineWorst(score, lastQueryDescDist);
+    adjustedScore = SortPolicy::CombineWorst(adjustedScore, lastRefDescDist);
+  }
+
+  // Assemble an adjusted score.  For nearest neighbor search, this adjusted
+  // score is a lower bound on MinDistance(queryNode, referenceNode) that is
+  // assembled without actually calculating MinDistance().  For furthest
+  // neighbor search, it is an upper bound on
+  // MaxDistance(queryNode, referenceNode).  If the traversalInfo isn't usable
+  // then the node should not be pruned by this.
+  if (traversalInfo.LastQueryNode() == queryNode.Parent())
+  {
+    const double queryAdjust = queryParentDist + queryDescDist;
+    adjustedScore = SortPolicy::CombineBest(adjustedScore, queryAdjust);
+  }
+  else
+  {
+    adjustedScore = SortPolicy::CombineBest(adjustedScore, queryDescDist);
+  }
+
+  if (traversalInfo.LastReferenceNode() == referenceNode.Parent())
+  {
+    const double refAdjust = refParentDist + refDescDist;
+    adjustedScore = SortPolicy::CombineBest(adjustedScore, refAdjust);
+  }
+  else
+  {
+    adjustedScore = SortPolicy::CombineBest(adjustedScore, refDescDist);
+  }
+
+  // Can we prune?
+  if (SortPolicy::IsBetter(bestDistance, adjustedScore))
+  {
+    if (!(tree::TreeTraits<TreeType>::FirstPointIsCentroid && score == 0.0))
+    {
+      // There isn't any need to set the traversal information because no
+      // descendant combinations will be visited, and those are the only
+      // combinations that would depend on the traversal information.
+      return DBL_MAX;
+    }
+  }
+
   double distance;
   if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
   {
     // The first point in the node is the centroid, so we can calculate the
     // distance between the two points using BaseCase() and then find the
     // bounds.  This is potentially loose for non-ball bounds.
-    bool alreadyDone = false;
-    double baseCase;
-    if (tree::TreeTraits<TreeType>::HasSelfChildren)
+    double baseCase = -1.0;
+    if (tree::TreeTraits<TreeType>::HasSelfChildren &&
+       (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) &&
+       (traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0)))
     {
-      // In this case, we may have already calculated the base case.
-      TreeType* lastRef = (TreeType*) queryNode.Stat().LastDistanceNode();
-      TreeType* lastQuery = (TreeType*) referenceNode.Stat().LastDistanceNode();
-
-      // Does the query node have the base case cached?
-      if ((lastRef != NULL) && (referenceNode.Point(0) == lastRef->Point(0)))
-      {
-        baseCase = queryNode.Stat().LastDistance();
-        alreadyDone = true;
-      }
-
-      // Does the reference node have the base case cached?
-      if ((lastQuery != NULL) &&
-          (queryNode.Point(0) == lastQuery->Point(0)))
-      {
-        baseCase = referenceNode.Stat().LastDistance();
-        alreadyDone = true;
-      }
-
-      // Is the query node a self-child, and if so, does the query node's parent
-      // have the base case cached?
-      if ((queryNode.Parent() != NULL) &&
-          (queryNode.Parent()->Point(0) == queryNode.Point(0)))
-      {
-        TreeType* lastParentRef = (TreeType*)
-            queryNode.Parent()->Stat().LastDistanceNode();
-        if (lastParentRef->Point(0) == referenceNode.Point(0))
-        {
-          baseCase = queryNode.Parent()->Stat().LastDistance();
-          alreadyDone = true;
-        }
-      }
-
-      // Is the reference node a self-child, and if so, does the reference
-      // node's parent have the base case cached?
-      if ((referenceNode.Parent() != NULL) &&
-          (referenceNode.Parent()->Point(0) == referenceNode.Point(0)))
-      {
-        TreeType* lastParentRef = (TreeType*)
-            referenceNode.Parent()->Stat().LastDistanceNode();
-        if (lastParentRef->Point(0) == queryNode.Point(0))
-        {
-          baseCase = referenceNode.Parent()->Stat().LastDistance();
-          alreadyDone = true;
-        }
-      }
-    }
-
-    // If we did not find a cached base case, then recalculate it.
-    if (!alreadyDone)
-    {
-      baseCase = BaseCase(queryNode.Point(0), referenceNode.Point(0));
+      // We already calculated it.
+      baseCase = traversalInfo.LastBaseCase();
     }
     else
     {
-      // Set lastQueryIndex and lastReferenceIndex, so that BaseCase() does not
-      // duplicate work.
-      lastQueryIndex = queryNode.Point(0);
-      lastReferenceIndex = referenceNode.Point(0);
-      lastBaseCase = baseCase;
+      baseCase = BaseCase(queryNode.Point(0), referenceNode.Point(0));
     }
 
     distance = SortPolicy::CombineBest(baseCase,
         queryNode.FurthestDescendantDistance() +
         referenceNode.FurthestDescendantDistance());
 
-    // Update the last distance calculation for the query and reference nodes.
-    queryNode.Stat().LastDistanceNode() = (void*) &referenceNode;
-    queryNode.Stat().LastDistance() = baseCase;
-    referenceNode.Stat().LastDistanceNode() = (void*) &queryNode;
-    referenceNode.Stat().LastDistance() = baseCase;
+    lastQueryIndex = queryNode.Point(0);
+    lastReferenceIndex = referenceNode.Point(0);
+    lastBaseCase = baseCase;
+
+    traversalInfo.LastBaseCase() = baseCase;
   }
   else
   {
     distance = SortPolicy::BestNodeToNodeDistance(&queryNode, &referenceNode);
   }
 
-  // Update our bound.
-  const double bestDistance = CalculateBound(queryNode);
+  if (SortPolicy::IsBetter(distance, bestDistance))
+  {
+    // Set traversal information.
+    traversalInfo.LastQueryNode() = &queryNode;
+    traversalInfo.LastReferenceNode() = &referenceNode;
+    traversalInfo.LastScore() = distance;
 
-  return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+    return distance;
+  }
+  else
+  {
+    // There isn't any need to set the traversal information because no
+    // descendant combinations will be visited, and those are the only
+    // combinations that would depend on the traversal information.
+    return DBL_MAX;
+  }
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
@@ -257,6 +300,30 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
   // So we will loop over the points in queryNode and the children in queryNode
   // to calculate all five of these quantities.
 
+  // Disabled: can we populate our distances vector with parent estimates?
+  // This is written specifically for the cover tree and assumes only one point
+  // in a node.
+//  if (queryNode.Parent() != NULL && queryNode.NumPoints() > 0)
+//  {
+//    size_t parentIndexStart = 0;
+//    for (size_t i = 0; i < neighbors.n_rows; ++i)
+//    {
+//      const double pointDistance = distances(i, queryNode.Point(0));
+//      if (pointDistance == DBL_MAX)
+//      {
+//      // Cool, can we take an estimate from the parent?
+//        const double parentWorstBound = distances(distances.n_rows - 1,
+//              queryNode.Parent()->Point(0));
+//        if (parentWorstBound != DBL_MAX)
+//        {
+//          const double parentAdjustedDistance = parentWorstBound +
+//              queryNode.ParentDistance();
+//          distances(i, queryNode.Point(0)) = parentAdjustedDistance;
+//        }
+//      }
+//    }
+//  }
+
   double worstPointDistance = SortPolicy::BestDistance();
   double bestPointDistance = SortPolicy::WorstDistance();
 
@@ -264,7 +331,8 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
   // candidates (for (1) and (2)).
   for (size_t i = 0; i < queryNode.NumPoints(); ++i)
   {
-    const double distance = distances(distances.n_rows - 1, queryNode.Point(i));
+    const double distance = distances(distances.n_rows - 1,
+        queryNode.Point(i));
     if (SortPolicy::IsBetter(distance, bestPointDistance))
       bestPointDistance = distance;
     if (SortPolicy::IsBetter(worstPointDistance, distance))
@@ -302,16 +370,21 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
       worstChildBound : worstPointDistance;
 
   // This is bound (2).
-  const double secondBound = SortPolicy::CombineWorst(bestPointDistance,
-      2 * queryMaxDescendantDistance);
+  const double secondBound = SortPolicy::CombineWorst(
+      SortPolicy::CombineWorst(bestPointDistance, queryMaxDescendantDistance),
+      queryNode.FurthestPointDistance());
 
   // Bound (3) is bestAdjustedChildBound.
 
   // Bounds (4) and (5) are the parent bounds.
   const double fourthBound = (queryNode.Parent() != NULL) ?
       queryNode.Parent()->Stat().FirstBound() : SortPolicy::WorstDistance();
-  const double fifthBound = (queryNode.Parent() != NULL) ?
-      queryNode.Parent()->Stat().SecondBound() : SortPolicy::WorstDistance();
+//  const double fifthBound = (queryNode.Parent() != NULL) ?
+//      queryNode.Parent()->Stat().SecondBound() -
+//      queryNode.Parent()->FurthestDescendantDistance() -
+//      queryNode.Parent()->FurthestPointDistance() + queryMaxDescendantDistance +
+//      queryNode.FurthestPointDistance() + queryNode.ParentDistance() :
+//      SortPolicy::WorstDistance();
 
   // Now, we will take the best of these.  Unfortunately due to the way
   // IsBetter() is defined, this sort of has to be a little ugly.
@@ -326,16 +399,16 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
   const double interB =
       (SortPolicy::IsBetter(bestAdjustedChildBound, secondBound)) ?
       bestAdjustedChildBound : secondBound;
-  const double interC = (SortPolicy::IsBetter(interB, fifthBound)) ? interB :
-      fifthBound;
+//  const double interC = (SortPolicy::IsBetter(interB, fifthBound)) ? interB :
+//      fifthBound;
 
   // Update the first and second bounds of the node.
   queryNode.Stat().FirstBound() = interA;
-  queryNode.Stat().SecondBound() = interC;
+  queryNode.Stat().SecondBound() = interB;
 
   // Update the actual bound of the node.
-  queryNode.Stat().Bound() = (SortPolicy::IsBetter(interA, interC)) ? interA :
-      interC;
+  queryNode.Stat().Bound() = (SortPolicy::IsBetter(interA, interB)) ? interA :
+      interB; // Keep whichever of the two bounds IsBetter() prefers.
 
   return queryNode.Stat().Bound();
 }



More information about the mlpack-git mailing list