[mlpack-svn] r14376 - in mlpack/trunk/src/mlpack/methods/neighbor_search: . sort_policies

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Feb 22 18:30:47 EST 2013


Author: rcurtin
Date: 2013-02-22 18:30:47 -0500 (Fri, 22 Feb 2013)
New Revision: 14376

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
Log:
Use the bound in the ICML paper.


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2013-02-22 23:15:20 UTC (rev 14375)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2013-02-22 23:30:47 UTC (rev 14376)
@@ -203,76 +203,84 @@
 inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
     CalculateBound(TreeType& queryNode) const
 {
-  double pointBound = SortPolicy::BestDistance();
-  double childBound = SortPolicy::BestDistance();
-  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+  // We have four possible bounds, and we must take the best of them all.  We
+  // don't use min/max here, but instead "best/worst", because this is general
+  // to the nearest-neighbors/furthest-neighbors cases.  For nearest neighbors,
+  // min = best, max = worst.
+  //
+  // (1) worst ( worst_{all points p in queryNode} D_p[k],
+  //             worst_{all children c in queryNode} B(c) );
+  // (2) best_{all points p in queryNode} D_p[k] +
+  //        2 * (worst descendant distance of queryNode);
+  // (3) best_{all children c in queryNode} B(c) +
+  //      2 ( worst descendant distance of queryNode -
+  //          worst descendant distance of c );
+  // (4) B(parent of queryNode);
+  //
+  // D_p[k] is the current k'th candidate distance for point p.
+  // So we will loop over the points in queryNode and the children in queryNode
+  // to calculate all four of these quantities.
 
-  // Find the bound of the points contained in this node.
+  double worstPointDistance = SortPolicy::BestDistance();
+  double bestPointDistance = SortPolicy::WorstDistance();
+
+  // Loop over all points in this node to find the best and worst distance
+  // candidates (for (1) and (2)).
   for (size_t i = 0; i < queryNode.NumPoints(); ++i)
   {
-    // The bound for this point is the k-th best distance plus the maximum
-    // distance to a child of this node.
-    const double bound = distances(distances.n_rows - 1, queryNode.Point(i));
-    if (SortPolicy::IsBetter(pointBound, bound))
-      pointBound = bound;
+    const double distance = distances(distances.n_rows - 1, queryNode.Point(i));
+    if (SortPolicy::IsBetter(distance, bestPointDistance))
+      bestPointDistance = distance;
+    if (SortPolicy::IsBetter(worstPointDistance, distance))
+      worstPointDistance = distance;
   }
 
-  // Find the bound of the children.
+  // Loop over all the children in this node to find the worst bound (for (1))
+  // and the best bound with the correcting factor for descendant distances (for
+  // (3)).
+  double worstChildBound = SortPolicy::BestDistance();
+  double bestAdjustedChildBound = SortPolicy::WorstDistance();
+  const double queryMaxDescendantDistance =
+      queryNode.FurthestDescendantDistance();
+
   for (size_t i = 0; i < queryNode.NumChildren(); ++i)
   {
     const double bound = queryNode.Child(i).Stat().Bound();
-    if (SortPolicy::IsBetter(childBound, bound))
-      childBound = bound;
-  }
+    const double childMaxDescendantDistance =
+        queryNode.Child(i).FurthestDescendantDistance();
 
-  // If there are no points, then break; the bound must be the child bound.
-  if (queryNode.NumPoints() == 0)
-    return childBound;
+    if (SortPolicy::IsBetter(worstChildBound, bound))
+      worstChildBound = bound;
 
-  // If there are no children, then break; the bound must be the point bound.
-  if (queryNode.NumChildren() == 0)
-    return pointBound;
+    // Now calculate adjustment for maximum descendant distances.
+    const double adjustedBound = SortPolicy::CombineWorst(bound,
+        2 * (queryMaxDescendantDistance - childMaxDescendantDistance));
+    if (SortPolicy::IsBetter(adjustedBound, bestAdjustedChildBound))
+      bestAdjustedChildBound = adjustedBound;
+  }
 
-//  Log::Debug << "Point bound " << pointBound << std::endl;
-//  Log::Debug << "Child bound " << childBound << std::endl;
-//  Log::Debug << "Furthest descendant distance " << maxDescendantDistance <<
-//      std::endl;
+  // This is bound (1).
+  const double firstBound =
+      (SortPolicy::IsBetter(worstPointDistance, worstChildBound)) ?
+      worstChildBound : worstPointDistance;
 
-  // If the bound of the children is uninitialized
-  // (SortPolicy::WorstDistance()), then maybe we can create a bound for the
-  // children.  But this requires a point bound to exist.
+  // This is bound (2).
+  const double secondBound = SortPolicy::CombineWorst(bestPointDistance,
+      2 * queryMaxDescendantDistance);
 
-  // It is possible that we could calculate a better bound for the children.
-  if (pointBound != SortPolicy::WorstDistance())
-  {
-    const double pointChildBound = pointBound + maxDescendantDistance;
-//    Log::Debug << "Point-child bound is " << pointChildBound << std::endl;
+  // Bound (3) is bestAdjustedChildBound; (4) is the parent's bound, if any.
+  const double fourthBound = (queryNode.Parent() != NULL) ?
+      queryNode.Parent()->Stat().Bound() : SortPolicy::WorstDistance();
 
-    if (SortPolicy::IsBetter(pointChildBound, childBound))
-    {
-      // The calculated bound is a tighter bound than the existing child bounds.
-      // Update all of the child bounds to this new, tighter bound.
-      for (size_t i = 0; i < queryNode.NumChildren(); ++i)
-      {
-//        Log::Debug << "Update child " << i << " bound from " <<
-//            queryNode.Child(i).Stat().Bound() << " to " << pointChildBound <<
-//            std::endl;
-        if (SortPolicy::IsBetter(pointChildBound,
-            queryNode.Child(i).Stat().Bound()))
-          queryNode.Child(i).Stat().Bound() = pointChildBound;
-//        else
-//          Log::Debug << "Did not update child!\n";
-      }
+  // Now, we will take the best of these.  Unfortunately due to the way
+  // IsBetter() is defined, this sort of has to be a little ugly...
+  const double interA = (SortPolicy::IsBetter(firstBound, secondBound)) ?
+      firstBound : secondBound;
+  const double interB =
+      (SortPolicy::IsBetter(bestAdjustedChildBound, fourthBound)) ?
+      bestAdjustedChildBound : fourthBound;
 
-      childBound = pointChildBound;
-    }
-  }
-
-  // Return the worse of the two bounds.
-  if (SortPolicy::IsBetter(childBound, pointBound))
-    return pointBound;
-  else
-    return childBound;
+  return (SortPolicy::IsBetter(interA, interB)) ? interA : interB;
 }
 
 /**

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp	2013-02-22 23:15:20 UTC (rev 14375)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp	2013-02-22 23:30:47 UTC (rev 14376)
@@ -127,6 +127,12 @@
    * @return DBL_MAX
    */
   static inline double BestDistance() { return DBL_MAX; }
+
+  /**
+   * Return the worst combination of the two distances: a - b, floored at 0.
+   */
+  static inline double CombineWorst(const double a, const double b)
+  { return std::max(a - b, 0.0); }
 };
 
 }; // namespace neighbor

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp	2013-02-22 23:15:20 UTC (rev 14375)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp	2013-02-22 23:30:47 UTC (rev 14376)
@@ -130,6 +130,15 @@
    * @return 0.0
    */
   static inline double BestDistance() { return 0.0; }
+
+  /**
+   * Return a + b, the worst combination; DBL_MAX if either input is DBL_MAX.
+   */
+  static inline double CombineWorst(const double a, const double b)
+  {
+    if (a == DBL_MAX || b == DBL_MAX)
+      return DBL_MAX;
+    return a + b; }
 };
 
 }; // namespace neighbor




More information about the mlpack-svn mailing list