[mlpack-svn] r13357 - mlpack/trunk/src/mlpack/core/tree/cover_tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Aug 7 01:27:29 EDT 2012


Author: rcurtin
Date: 2012-08-07 01:27:29 -0400 (Tue, 07 Aug 2012)
New Revision: 13357

Modified:
   mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
Log:
Rewrite the single-tree traverser for the cover tree.  It works; now it needs to
be sped up.


Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp	2012-08-07 05:27:05 UTC (rev 13356)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp	2012-08-07 05:27:29 UTC (rev 13357)
@@ -34,50 +34,67 @@
   // This is a non-recursive implementation (which should be faster than a
   // recursive implementation).
   std::queue<CoverTree<MetricType, RootPointPolicy, StatisticType>*> pointQueue;
-  std::queue<size_t> parentPoints; // For if this tree has self-children.
   std::queue<double> pointScores;
+  std::queue<double> baseCaseResults;
 
   pointQueue.push(&referenceNode);
-  parentPoints.push(size_t() - 1); // Invalid value.
   pointScores.push(0.0); // Cannot be pruned.
 
+  // Evaluate first base case.
+  baseCaseResults.push(rule.BaseCase(queryIndex, referenceNode.Point()));
+
   while (!pointQueue.empty())
   {
     CoverTree<MetricType, RootPointPolicy, StatisticType>* node =
         pointQueue.front();
-    const size_t parent = parentPoints.front();
     const double score = pointScores.front();
-    const size_t point = node->Point(); // The point held by this node.
+    const double baseCase = baseCaseResults.front();
 
     pointQueue.pop();
-    parentPoints.pop();
     pointScores.pop();
+    baseCaseResults.pop();
 
-    // See if this point should still be recursed into.
-    if (rule.Rescore(queryIndex, *node, score) == DBL_MAX)
+    // First we (re)calculate the score of this node to find if we can prune it.
+    double actualScore = rule.Rescore(queryIndex, *node, score);
+
+    if (actualScore == DBL_MAX)
     {
+      // Prune this node.
       ++numPrunes;
-      continue; // Pruned!
+      continue; // Skip to next in queue.
     }
 
-    // Evaluate the base case, but only if this node is not holding the same
-    // point as its parent.
-    if (parent != point)
-      rule.BaseCase(queryIndex, point);
+    // The base case is already evaluated.  So now we need to find out how to
+    // recurse into the children.
+    arma::vec baseCases(node->NumChildren());
+    arma::vec childScores(node->NumChildren());
 
-    // Now get the scores for recursion.
-    arma::vec scores(node->NumChildren());
+    // We already know the base case for the self child (that's the same as the
+    // base case for us).
+    baseCases[0] = baseCase;
+    if (node->Child(0).NumChildren() == 0)
+      childScores[0] = DBL_MAX; // Do not recurse into leaves (unnecessary).
+    else
+      childScores[0] = rule.Score(queryIndex, node->Child(0), baseCase);
 
-    for (size_t i = 0; i < node->NumChildren(); ++i)
-      scores[i] = rule.Score(queryIndex, node->Child(i));
+    // Fill the rest of the children.
+    for (size_t i = 1; i < node->NumChildren(); ++i)
+    {
+      baseCases[i] = rule.BaseCase(queryIndex, node->Child(i).Point());
+      if (node->Child(i).NumChildren() == 0)
+        childScores[i] = DBL_MAX; // Do not recurse into leaves (unnecessary).
+      else
+        childScores[i] = rule.Score(queryIndex, node->Child(i), baseCases[i]);
+    }
 
-    // Now sort by distance (smallest first).
-    arma::uvec order = arma::sort_index(scores);
+    // Now sort by score.
+    arma::uvec order = arma::sort_index(childScores);
 
+    // Now add each to the queue.
     for (size_t i = 0; i < order.n_elem; ++i)
     {
       // Ensure we haven't hit the limit yet.
-      const double childScore = scores[order[i]];
+      const double childScore = childScores[order[i]];
       if (childScore == DBL_MAX)
       {
         numPrunes += (order.n_elem - i); // Prune the rest of the children.
@@ -85,8 +102,8 @@
       }
 
       pointQueue.push(&node->Child(order[i]));
-      parentPoints.push(point);
       pointScores.push(childScore);
+      baseCaseResults.push(baseCases[order[i]]);
     }
   }
 }




More information about the mlpack-svn mailing list