[mlpack-git] mlpack-1.0.x: Backport trunk fixes for NeighborSearch. (9763294)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jan 7 11:57:22 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/0000000000000000000000000000000000000000...904762495c039e345beba14c1142fd719b3bd50e

>---------------------------------------------------------------

commit 9763294d98771f1bb9894cd450ae9c1fba35896b
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Aug 25 21:49:46 2014 +0000

    Backport trunk fixes for NeighborSearch.


>---------------------------------------------------------------

9763294d98771f1bb9894cd450ae9c1fba35896b
 .../neighbor_search/neighbor_search_impl.hpp       |  21 +--
 .../neighbor_search/neighbor_search_rules_impl.hpp | 209 ++++++++-------------
 2 files changed, 79 insertions(+), 151 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index c8ee610..33b2577 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -4,21 +4,6 @@
  *
  * Implementation of Neighbor-Search class to perform all-nearest-neighbors on
  * two specified data sets.
- *
- * This file is part of MLPACK 1.0.9.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK.  If not, see <http://www.gnu.org/licenses/>.
  */
 #ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_IMPL_HPP
 #define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_IMPL_HPP
@@ -267,7 +252,6 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
   }
   else if (singleMode)
   {
-
     // The search doesn't work if the root node is also a leaf node.
     // if this is the case, it is suggested that you use the naive method.
     assert(!(referenceTree->IsLeaf()));
@@ -278,8 +262,11 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
     // Now have it traverse for each point.
     for (size_t i = 0; i < querySet.n_cols; ++i)
       traverser.Traverse(i, *referenceTree);
+
+    Log::Info << rules.Scores() << " node combinations were scored.\n";
+    Log::Info << rules.BaseCases() << " base cases were calculated.\n";
   }
-  else // Dual-tree recursion.referenceTree
+  else // Dual-tree recursion.
   {
     // Create the traverser.
     typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index 5d1cc99..bdafbb1 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -3,21 +3,6 @@
  * @author Ryan Curtin
  *
  * Implementation of NearestNeighborRules.
- *
- * This file is part of MLPACK 1.0.9.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK.  If not, see <http://www.gnu.org/licenses/>.
  */
 #ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
 #define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
@@ -328,136 +313,92 @@ template<typename SortPolicy, typename MetricType, typename TreeType>
 inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
     CalculateBound(TreeType& queryNode) const
 {
-  // We have five possible bounds, and we must take the best of them all.  We
-  // don't use min/max here, but instead "best/worst", because this is general
-  // to the nearest-neighbors/furthest-neighbors cases.  For nearest neighbors,
-  // min = best, max = worst.
-  //
-  // (1) worst ( worst_{all points p in queryNode} D_p[k],
-  //             worst_{all children c in queryNode} B(c) );
-  // (2) best_{all points p in queryNode} D_p[k] + worst child distance +
-  //        worst descendant distance;
-  // (3) best_{all children c in queryNode} B(c) +
-  //      2 ( worst descendant distance of queryNode -
-  //          worst descendant distance of c );
-  // (4) B_1(parent of queryNode)
-  // (5) B_2(parent of queryNode);
-  //
-  // D_p[k] is the current k'th candidate distance for point p.
-  // So we will loop over the points in queryNode and the children in queryNode
-  // to calculate all five of these quantities.
-
-  // Hm, can we populate our distances vector with estimates from the parent?
-  // This is written specifically for the cover tree and assumes only one point
-  // in a node.
-//  if (queryNode.Parent() != NULL && queryNode.NumPoints() > 0)
-//  {
-//    size_t parentIndexStart = 0;
-//    for (size_t i = 0; i < neighbors.n_rows; ++i)
-//    {
-//      const double pointDistance = distances(i, queryNode.Point(0));
-//      if (pointDistance == DBL_MAX)
-//      {
-//      // Cool, can we take an estimate from the parent?
-//        const double parentWorstBound = distances(distances.n_rows - 1,
-//              queryNode.Parent()->Point(0));
-//        if (parentWorstBound != DBL_MAX)
-//        {
-//          const double parentAdjustedDistance = parentWorstBound +
-//              queryNode.ParentDistance();
-//          distances(i, queryNode.Point(0)) = parentAdjustedDistance;
-//        }
-//      }
-//    }
-//  }
-
-  double worstPointDistance = SortPolicy::BestDistance();
-  double bestPointDistance = SortPolicy::WorstDistance();
-
-  // Loop over all points in this node to find the best and worst distance
-  // candidates (for (1) and (2)).
+  // This is an adapted form of the B(N_q) function in the paper
+  // ``Tree-Independent Dual-Tree Algorithms'' by Curtin et al.; the goal is to
+  // place a bound on the worst possible distance a point combination could have
+  // to improve any of the current neighbor estimates.  If the best possible
+  // distance between two nodes is greater than this bound, then the node
+  // combination can be pruned (see Score()).
+
+  // There are a couple ways we can assemble a bound.  For simplicity, this is
+  // described for nearest neighbor search (SortPolicy = NearestNeighborSort),
+  // but the code that is written is adapted for whichever SortPolicy.
+
+  // First, we can consider the current worst neighbor candidate distance of any
+  // descendant point.  This is assembled with 'worstDistance' by looping
+  // through the points held by the query node, and then by taking the cached
+  // worst distance from any child nodes (Stat().FirstBound()).  This
+  // corresponds roughly to B_1(N_q) in the paper.
+
+  // The other way of bounding is to use the triangle inequality.  To do this,
+  // we find the current best kth-neighbor candidate distance of any descendant
+  // query point, and use the triangle inequality to place a bound on the
+  // distance that candidate would have to any other descendant query point.
+  // This corresponds roughly to B_2(N_q) in the paper, and is the bounding
+  // style for cover trees.
+
+  // Then, to assemble the final bound, since both bounds are valid, we simply
+  // take the better of the two.
+
+  double worstDistance = SortPolicy::BestDistance();
+  double bestDistance = SortPolicy::WorstDistance();
+
+  // Loop over points held in the node.
   for (size_t i = 0; i < queryNode.NumPoints(); ++i)
   {
-    const double distance = distances(distances.n_rows - 1,
-        queryNode.Point(i));
-    if (SortPolicy::IsBetter(distance, bestPointDistance))
-      bestPointDistance = distance;
-    if (SortPolicy::IsBetter(worstPointDistance, distance))
-      worstPointDistance = distance;
+    const double distance = distances(distances.n_rows - 1, queryNode.Point(i));
+    if (SortPolicy::IsBetter(worstDistance, distance))
+      worstDistance = distance;
+    if (SortPolicy::IsBetter(distance, bestDistance))
+      bestDistance = distance;
   }
 
-  // Loop over all the children in this node to find the worst bound (for (1))
-  // and the best bound with the correcting factor for descendant distances (for
-  // (3)).
-  double worstChildBound = SortPolicy::BestDistance();
-  double bestAdjustedChildBound = SortPolicy::WorstDistance();
-  const double queryMaxDescendantDistance =
+  // Add triangle inequality adjustment to best distance.  It is possible this
+  // could be tighter for some certain types of trees.
+  bestDistance += queryNode.FurthestPointDistance() +
       queryNode.FurthestDescendantDistance();
 
+  // Loop over children of the node, and use their cached information to
+  // assemble bounds.
   for (size_t i = 0; i < queryNode.NumChildren(); ++i)
   {
     const double firstBound = queryNode.Child(i).Stat().FirstBound();
-    const double secondBound = queryNode.Child(i).Stat().SecondBound();
-    const double childMaxDescendantDistance =
-        queryNode.Child(i).FurthestDescendantDistance();
-
-    if (SortPolicy::IsBetter(worstChildBound, firstBound))
-      worstChildBound = firstBound;
-
-    // Now calculate adjustment for maximum descendant distances.
-    const double adjustedBound = SortPolicy::CombineWorst(secondBound,
-        2 * (queryMaxDescendantDistance - childMaxDescendantDistance));
-    if (SortPolicy::IsBetter(adjustedBound, bestAdjustedChildBound))
-      bestAdjustedChildBound = adjustedBound;
+    const double adjustedSecondBound = queryNode.Child(i).Stat().SecondBound() +
+        2 * (queryNode.FurthestDescendantDistance() -
+             queryNode.Child(i).FurthestDescendantDistance());
+
+    if (SortPolicy::IsBetter(worstDistance, firstBound))
+      worstDistance = firstBound;
+    if (SortPolicy::IsBetter(adjustedSecondBound, bestDistance))
+      bestDistance = adjustedSecondBound;
   }
 
-  // This is bound (1).
-  const double firstBound =
-      (SortPolicy::IsBetter(worstPointDistance, worstChildBound)) ?
-      worstChildBound : worstPointDistance;
-
-  // This is bound (2).
-  const double secondBound = SortPolicy::CombineWorst(
-      SortPolicy::CombineWorst(bestPointDistance, queryMaxDescendantDistance),
-      queryNode.FurthestPointDistance());
-
-  // Bound (3) is bestAdjustedChildBound.
-
-  // Bounds (4) and (5) are the parent bounds.
-  const double fourthBound = (queryNode.Parent() != NULL) ?
-      queryNode.Parent()->Stat().FirstBound() : SortPolicy::WorstDistance();
-//  const double fifthBound = (queryNode.Parent() != NULL) ?
-//      queryNode.Parent()->Stat().SecondBound() -
-//      queryNode.Parent()->FurthestDescendantDistance() -
-//      queryNode.Parent()->FurthestPointDistance() + queryMaxDescendantDistance +
-//      queryNode.FurthestPointDistance() + queryNode.ParentDistance() :
-//      SortPolicy::WorstDistance();
-
-  // Now, we will take the best of these.  Unfortunately due to the way
-  // IsBetter() is defined, this sort of has to be a little ugly.
-  // The variable interA represents the first bound (B_1), which is the worst
-  // candidate distance of any descendants of this node.
-  // The variable interC represents the second bound (B_2), which is a bound on
-  // the worst distance of any descendants of this node assembled using the best
-  // descendant candidate distance modified using the furthest descendant
-  // distance.
-  const double interA = (SortPolicy::IsBetter(firstBound, fourthBound)) ?
-      firstBound : fourthBound;
-  const double interB =
-      (SortPolicy::IsBetter(bestAdjustedChildBound, secondBound)) ?
-      bestAdjustedChildBound : secondBound;
-//  const double interC = (SortPolicy::IsBetter(interB, fifthBound)) ? interB :
-//      fifthBound;
-
-  // Update the first and second bounds of the node.
-  queryNode.Stat().FirstBound() = interA;
-  queryNode.Stat().SecondBound() = interB;
-
-  // Update the actual bound of the node.
-  queryNode.Stat().Bound() = (SortPolicy::IsBetter(interA, interB)) ? interB :
-      interB;
-
-  return queryNode.Stat().Bound();
+  // Now consider the parent bounds.
+  if (queryNode.Parent() != NULL)
+  {
+    // The parent's worst distance bound implies that the bound for this node
+    // must be at least as good.  Thus, if the parent worst distance bound is
+    // better, then take it.
+    if (SortPolicy::IsBetter(queryNode.Parent()->Stat().FirstBound(),
+        worstDistance))
+      worstDistance = queryNode.Parent()->Stat().FirstBound();
+
+    // The parent's best distance bound implies that the bound for this node
+    // must be at least as good.  Thus, if the parent best distance bound is
+    // better, then take it.
+    if (SortPolicy::IsBetter(queryNode.Parent()->Stat().SecondBound(),
+        bestDistance))
+      bestDistance = queryNode.Parent()->Stat().SecondBound();
+  }
+
+  // Cache bounds for later.
+  queryNode.Stat().FirstBound() = worstDistance;
+  queryNode.Stat().SecondBound() = bestDistance;
+
+  if (SortPolicy::IsBetter(worstDistance, bestDistance))
+    return worstDistance;
+  else
+    return bestDistance;
 }
 
 /**



More information about the mlpack-git mailing list