[mlpack-svn] r13353 - mlpack/trunk/src/mlpack/core/tree/cover_tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Aug 6 17:19:23 EDT 2012
Author: rcurtin
Date: 2012-08-06 17:19:23 -0400 (Mon, 06 Aug 2012)
New Revision: 13353
Modified:
mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
Log:
Update single tree cover tree traverser to be smarter about recursion order.
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp 2012-08-06 20:48:02 UTC (rev 13352)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp 2012-08-06 21:19:23 UTC (rev 13353)
@@ -35,44 +35,58 @@
// recursive implementation).
std::queue<CoverTree<MetricType, RootPointPolicy, StatisticType>*> pointQueue;
std::queue<size_t> parentPoints; // For if this tree has self-children.
+ std::queue<double> pointScores;
pointQueue.push(&referenceNode);
parentPoints.push(size_t() - 1); // Invalid value.
+ pointScores.push(0.0); // Cannot be pruned.
while (!pointQueue.empty())
{
CoverTree<MetricType, RootPointPolicy, StatisticType>* node =
pointQueue.front();
+ const size_t parent = parentPoints.front();
+ const double score = pointScores.front();
+ const size_t point = node->Point(); // The point held by this node.
+
pointQueue.pop();
+ parentPoints.pop();
+ pointScores.pop();
- // Check if we can prune this node.
- if (rule.CanPrune(queryIndex, *node))
+ // See if this point should still be recursed into.
+ if (rule.Rescore(queryIndex, *node, score) == DBL_MAX)
{
- parentPoints.pop(); // Pop the parent point off.
-
++numPrunes;
- continue;
+ continue; // Pruned!
}
- // If this tree type has self-children, we need to make sure we don't run
- // the base case if the parent already had it run.
- size_t baseCaseStart = 0;
- if (parentPoints.front() == node->Point(0))
- baseCaseStart = 1; // Skip base case we've already evaluated.
+ // Evaluate the base case, but only if this node is not holding the same
+ // point as its parent.
+ if (parent != point)
+ rule.BaseCase(queryIndex, point);
- parentPoints.pop();
+ // Now get the scores for recursion.
+ arma::vec scores(node->NumChildren());
- // First run the base case for any points this node might hold.
- for (size_t i = baseCaseStart; i < node->NumPoints(); ++i)
- rule.BaseCase(queryIndex, node->Point(i));
+ for (size_t i = 0; i < node->NumChildren(); ++i)
+ scores[i] = rule.Score(queryIndex, node->Child(i));
- // Now push children (and their parent points) into the FIFO. Maybe it
- // would be better to push these back in a particular order.
- const size_t parentPoint = node->Point(0);
- for (size_t i = 0; i < node->NumChildren(); ++i)
+ // Now sort by distance (smallest first).
+ arma::uvec order = arma::sort_index(scores);
+
+ for (size_t i = 0; i < order.n_elem; ++i)
{
- pointQueue.push(&(node->Child(i)));
- parentPoints.push(parentPoint);
+ // Ensure we haven't hit the limit yet.
+ const double childScore = scores[order[i]];
+ if (childScore == DBL_MAX)
+ {
+ numPrunes += (order.n_elem - i); // Prune the rest of the children.
+ break; // Go on to next point.
+ }
+
+ pointQueue.push(&node->Child(order[i]));
+ parentPoints.push(point);
+ pointScores.push(childScore);
}
}
}
More information about the mlpack-svn
mailing list