[mlpack-svn] r14999 - mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu May 2 00:13:36 EDT 2013


Author: rcurtin
Date: 2013-05-02 00:13:35 -0400 (Thu, 02 May 2013)
New Revision: 14999

Modified:
   mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
   mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Log:
Revert to 1.0.4 NeighborSearchRules bounds because the current ones aren't quite
as good.


Modified: mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2013-05-02 03:14:47 UTC (rev 14998)
+++ mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2013-05-02 04:13:35 UTC (rev 14999)
@@ -51,10 +51,6 @@
                   TreeType& referenceNode,
                   TreeType& referenceChildNode,
                   const double baseCaseResult) const;
-  double PrescoreQ(TreeType& queryNode,
-                   TreeType& queryChildNode,
-                   TreeType& referenceNode,
-                   const double baseCaseResult) const;
 
   /**
    * Get the score for recursion order.  A low score indicates priority for
@@ -151,11 +147,6 @@
   MetricType& metric;
 
   /**
-   * Recalculate the bound for a given query node.
-   */
-  double CalculateBound(TreeType& queryNode) const;
-
-  /**
    * Insert a point into the neighbors and distances matrices; this is a helper
    * function.
    *

Modified: mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2013-05-02 03:14:47 UTC (rev 14998)
+++ mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2013-05-02 04:13:35 UTC (rev 14999)
@@ -58,7 +58,7 @@
   // If this distance is better than any of the current candidates, the
   // SortDistance() function will give us the position to insert it into.
   arma::vec queryDist = distances.unsafe_col(queryIndex);
-  const size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
+  size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
 
   // SortDistance() returns (size_t() - 1) if we shouldn't add it.
   if (insertPosition != (size_t() - 1))
@@ -77,24 +77,35 @@
   const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
       &referenceNode, &referenceChildNode, baseCaseResult);
 
-  // Update our bound.
-  const double bestDistance = CalculateBound(queryNode);
+  // Calculate the bound on the fly.  This bound will be the minimum of
+  // pointBound (the bounds given by the points in this node) and childBound
+  // (the bounds given by the children of this node).
+  double pointBound = SortPolicy::WorstDistance();
+  double childBound = SortPolicy::WorstDistance();
+  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
 
-  return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
+  // Find the bound of the points contained in this node.
+  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+  {
+    // The bound for this point is the k-th best distance plus the maximum
+    // distance to a child of this node.
+    const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
+        maxDescendantDistance;
+    if (SortPolicy::IsBetter(bound, pointBound))
+      pointBound = bound;
+  }
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::PrescoreQ(
-    TreeType& queryNode,
-    TreeType& queryChildNode,
-    TreeType& referenceNode,
-    const double baseCaseResult) const
-{
-  const double distance = SortPolicy::BestNodeToNodeDistance(&referenceNode,
-      &queryNode, &queryChildNode, baseCaseResult);
+  // Find the bound of the children.
+  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+  {
+    const double bound = queryNode.Child(i).Stat().Bound();
+    if (SortPolicy::IsBetter(bound, childBound))
+      childBound = bound;
+  }
 
   // Update our bound.
-  const double bestDistance = CalculateBound(queryNode);
+  queryNode.Stat().Bound() = std::min(pointBound, childBound);
+  const double bestDistance = queryNode.Stat().Bound();
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
 }
@@ -150,8 +161,35 @@
   const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
       &referenceNode);
 
+  // Calculate the bound on the fly.  This bound will be the minimum of
+  // pointBound (the bounds given by the points in this node) and childBound
+  // (the bounds given by the children of this node).
+  double pointBound = SortPolicy::WorstDistance();
+  double childBound = SortPolicy::WorstDistance();
+  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
+  // Find the bound of the points contained in this node.
+  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+  {
+    // The bound for this point is the k-th best distance plus the maximum
+    // distance to a child of this node.
+    const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
+        maxDescendantDistance;
+    if (SortPolicy::IsBetter(bound, pointBound))
+      pointBound = bound;
+  }
+
+  // Find the bound of the children.
+  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+  {
+    const double bound = queryNode.Child(i).Stat().Bound();
+    if (SortPolicy::IsBetter(bound, childBound))
+      childBound = bound;
+  }
+
   // Update our bound.
-  const double bestDistance = CalculateBound(queryNode);
+  queryNode.Stat().Bound() = std::min(pointBound, childBound);
+  const double bestDistance = queryNode.Stat().Bound();
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
 }
@@ -165,8 +203,35 @@
   const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
       &referenceNode, baseCaseResult);
 
+  // Calculate the bound on the fly.  This bound will be the minimum of
+  // pointBound (the bounds given by the points in this node) and childBound
+  // (the bounds given by the children of this node).
+  double pointBound = SortPolicy::WorstDistance();
+  double childBound = SortPolicy::WorstDistance();
+  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
+  // Find the bound of the points contained in this node.
+  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+  {
+    // The bound for this point is the k-th best distance plus the maximum
+    // distance to a child of this node.
+    const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
+        maxDescendantDistance;
+    if (SortPolicy::IsBetter(bound, pointBound))
+      pointBound = bound;
+  }
+
+  // Find the bound of the children.
+  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+  {
+    const double bound = queryNode.Child(i).Stat().Bound();
+    if (SortPolicy::IsBetter(bound, childBound))
+      childBound = bound;
+  }
+
   // Update our bound.
-  const double bestDistance = CalculateBound(queryNode);
+  queryNode.Stat().Bound() = std::min(pointBound, childBound);
+  const double bestDistance = queryNode.Stat().Bound();
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
 }
@@ -180,118 +245,37 @@
   if (oldScore == DBL_MAX)
     return oldScore;
 
-  // Update our bound.
-  const double bestDistance = CalculateBound(queryNode);
+  // Calculate the bound on the fly.  This bound will be the minimum of
+  // pointBound (the bounds given by the points in this node) and childBound
+  // (the bounds given by the children of this node).
+  double pointBound = SortPolicy::WorstDistance();
+  double childBound = SortPolicy::WorstDistance();
+  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
 
-  return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
-}
-
-// Calculate the bound for a given query node in its current state and update
-// it.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
-    CalculateBound(TreeType& queryNode) const
-{
-  // We have five possible bounds, and we must take the best of them all.  We
-  // don't use min/max here, but instead "best/worst", because this is general
-  // to the nearest-neighbors/furthest-neighbors cases.  For nearest neighbors,
-  // min = best, max = worst.
-  //
-  // (1) worst ( worst_{all points p in queryNode} D_p[k],
-  //             worst_{all children c in queryNode} B(c) );
-  // (2) best_{all points p in queryNode} D_p[k] + worst child distance +
-  //        worst descendant distance;
-  // (3) best_{all children c in queryNode} B(c) +
-  //      2 ( worst descendant distance of queryNode -
-  //          worst descendant distance of c );
-  // (4) B_1(parent of queryNode)
-  // (5) B_2(parent of queryNode);
-  //
-  // D_p[k] is the current k'th candidate distance for point p.
-  // So we will loop over the points in queryNode and the children in queryNode
-  // to calculate all five of these quantities.
-
-  double worstPointDistance = SortPolicy::BestDistance();
-  double bestPointDistance = SortPolicy::WorstDistance();
-
-  // Loop over all points in this node to find the best and worst distance
-  // candidates (for (1) and (2)).
+  // Find the bound of the points contained in this node.
   for (size_t i = 0; i < queryNode.NumPoints(); ++i)
   {
-    const double distance = distances(distances.n_rows - 1, queryNode.Point(i));
-    if (SortPolicy::IsBetter(distance, bestPointDistance))
-      bestPointDistance = distance;
-    if (SortPolicy::IsBetter(worstPointDistance, distance))
-      worstPointDistance = distance;
+    // The bound for this point is the k-th best distance plus the maximum
+    // distance to a child of this node.
+    const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
+        maxDescendantDistance;
+    if (SortPolicy::IsBetter(bound, pointBound))
+      pointBound = bound;
   }
 
-  // Loop over all the children in this node to find the worst bound (for (1))
-  // and the best bound with the correcting factor for descendant distances (for
-  // (3)).
-  double worstChildBound = SortPolicy::BestDistance();
-  double bestAdjustedChildBound = SortPolicy::WorstDistance();
-  const double queryMaxDescendantDistance =
-      queryNode.FurthestDescendantDistance();
-
+  // Find the bound of the children.
   for (size_t i = 0; i < queryNode.NumChildren(); ++i)
   {
-    const double firstBound = queryNode.Child(i).Stat().FirstBound();
-    const double secondBound = queryNode.Child(i).Stat().SecondBound();
-    const double childMaxDescendantDistance =
-        queryNode.Child(i).FurthestDescendantDistance();
-
-    if (SortPolicy::IsBetter(worstChildBound, firstBound))
-      worstChildBound = firstBound;
-
-    // Now calculate adjustment for maximum descendant distances.
-    const double adjustedBound = SortPolicy::CombineWorst(secondBound,
-        2 * (queryMaxDescendantDistance - childMaxDescendantDistance));
-    if (SortPolicy::IsBetter(adjustedBound, bestAdjustedChildBound))
-      bestAdjustedChildBound = adjustedBound;
+    const double bound = queryNode.Child(i).Stat().Bound();
+    if (SortPolicy::IsBetter(bound, childBound))
+      childBound = bound;
   }
 
-  // This is bound (1).
-  const double firstBound =
-      (SortPolicy::IsBetter(worstPointDistance, worstChildBound)) ?
-      worstChildBound : worstPointDistance;
+  // Update our bound.
+  queryNode.Stat().Bound() = std::min(pointBound, childBound);
+  const double bestDistance = queryNode.Stat().Bound();
 
-  // This is bound (2).
-  const double secondBound = SortPolicy::CombineWorst(bestPointDistance,
-      2 * queryMaxDescendantDistance);
-
-  // Bound (3) is bestAdjustedChildBound.
-
-  // Bounds (4) and (5) are the parent bounds.
-  const double fourthBound = (queryNode.Parent() != NULL) ?
-      queryNode.Parent()->Stat().FirstBound() : SortPolicy::WorstDistance();
-  const double fifthBound = (queryNode.Parent() != NULL) ?
-      queryNode.Parent()->Stat().SecondBound() : SortPolicy::WorstDistance();
-
-  // Now, we will take the best of these.  Unfortunately due to the way
-  // IsBetter() is defined, this sort of has to be a little ugly.
-  // The variable interA represents the first bound (B_1), which is the worst
-  // candidate distance of any descendants of this node.
-  // The variable interC represents the second bound (B_2), which is a bound on
-  // the worst distance of any descendants of this node assembled using the best
-  // descendant candidate distance modified using the furthest descendant
-  // distance.
-  const double interA = (SortPolicy::IsBetter(firstBound, fourthBound)) ?
-      firstBound : fourthBound;
-  const double interB =
-      (SortPolicy::IsBetter(bestAdjustedChildBound, secondBound)) ?
-      bestAdjustedChildBound : secondBound;
-  const double interC = (SortPolicy::IsBetter(interB, fifthBound)) ? interB :
-      fifthBound;
-
-  // Update the first and second bounds of the node.
-  queryNode.Stat().FirstBound() = interA;
-  queryNode.Stat().SecondBound() = interC;
-
-  // Update the actual bound of the node.
-  queryNode.Stat().Bound() = (SortPolicy::IsBetter(interA, interC)) ? interA :
-      interC;
-
-  return queryNode.Stat().Bound();
+  return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
 }
 
 /**




More information about the mlpack-svn mailing list