[mlpack-svn] r13353 - mlpack/trunk/src/mlpack/core/tree/cover_tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Aug 6 17:19:23 EDT 2012


Author: rcurtin
Date: 2012-08-06 17:19:23 -0400 (Mon, 06 Aug 2012)
New Revision: 13353

Modified:
   mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
Log:
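Update the single-tree cover tree traverser to be smarter about recursion order: each node's children are now scored, enqueued in order of increasing score, and the remaining children are pruned as soon as a score of DBL_MAX is reached.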


Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp	2012-08-06 20:48:02 UTC (rev 13352)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp	2012-08-06 21:19:23 UTC (rev 13353)
@@ -35,44 +35,58 @@
   // recursive implementation).
   std::queue<CoverTree<MetricType, RootPointPolicy, StatisticType>*> pointQueue;
   std::queue<size_t> parentPoints; // For if this tree has self-children.
+  std::queue<double> pointScores;
 
   pointQueue.push(&referenceNode);
   parentPoints.push(size_t() - 1); // Invalid value.
+  pointScores.push(0.0); // Cannot be pruned.
 
   while (!pointQueue.empty())
   {
     CoverTree<MetricType, RootPointPolicy, StatisticType>* node =
         pointQueue.front();
+    const size_t parent = parentPoints.front();
+    const double score = pointScores.front();
+    const size_t point = node->Point(); // The point held by this node.
+
     pointQueue.pop();
+    parentPoints.pop();
+    pointScores.pop();
 
-    // Check if we can prune this node.
-    if (rule.CanPrune(queryIndex, *node))
+    // See if this point should still be recursed into.
+    if (rule.Rescore(queryIndex, *node, score) == DBL_MAX)
     {
-      parentPoints.pop(); // Pop the parent point off.
-
       ++numPrunes;
-      continue;
+      continue; // Pruned!
     }
 
-    // If this tree type has self-children, we need to make sure we don't run
-    // the base case if the parent already had it run.
-    size_t baseCaseStart = 0;
-    if (parentPoints.front() == node->Point(0))
-      baseCaseStart = 1; // Skip base case we've already evaluated.
+    // Evaluate the base case, but only if this node is not holding the same
+    // point as its parent.
+    if (parent != point)
+      rule.BaseCase(queryIndex, point);
 
-    parentPoints.pop();
+    // Now get the scores for recursion.
+    arma::vec scores(node->NumChildren());
 
-    // First run the base case for any points this node might hold.
-    for (size_t i = baseCaseStart; i < node->NumPoints(); ++i)
-      rule.BaseCase(queryIndex, node->Point(i));
+    for (size_t i = 0; i < node->NumChildren(); ++i)
+      scores[i] = rule.Score(queryIndex, node->Child(i));
 
-    // Now push children (and their parent points) into the FIFO.  Maybe it
-    // would be better to push these back in a particular order.
-    const size_t parentPoint = node->Point(0);
-    for (size_t i = 0; i < node->NumChildren(); ++i)
+    // Now sort by distance (smallest first).
+    arma::uvec order = arma::sort_index(scores);
+
+    for (size_t i = 0; i < order.n_elem; ++i)
     {
-      pointQueue.push(&(node->Child(i)));
-      parentPoints.push(parentPoint);
+      // Ensure we haven't hit the limit yet.
+      const double childScore = scores[order[i]];
+      if (childScore == DBL_MAX)
+      {
+        numPrunes += (order.n_elem - i); // Prune the rest of the children.
+        break; // Go on to next point.
+      }
+
+      pointQueue.push(&node->Child(order[i]));
+      parentPoints.push(point);
+      pointScores.push(childScore);
     }
   }
 }




More information about the mlpack-svn mailing list