[mlpack-svn] r13357 - mlpack/trunk/src/mlpack/core/tree/cover_tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Aug 7 01:27:29 EDT 2012
Author: rcurtin
Date: 2012-08-07 01:27:29 -0400 (Tue, 07 Aug 2012)
New Revision: 13357
Modified:
mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
Log:
Rewrite the single tree traverser for the cover tree. It works; now it needs to
be sped up.
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp 2012-08-07 05:27:05 UTC (rev 13356)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp 2012-08-07 05:27:29 UTC (rev 13357)
@@ -34,50 +34,67 @@
// This is a non-recursive implementation (which should be faster than a
// recursive implementation).
std::queue<CoverTree<MetricType, RootPointPolicy, StatisticType>*> pointQueue;
- std::queue<size_t> parentPoints; // For if this tree has self-children.
std::queue<double> pointScores;
+ std::queue<double> baseCaseResults;
pointQueue.push(&referenceNode);
- parentPoints.push(size_t() - 1); // Invalid value.
pointScores.push(0.0); // Cannot be pruned.
+ // Evaluate first base case.
+ baseCaseResults.push(rule.BaseCase(queryIndex, referenceNode.Point()));
+
while (!pointQueue.empty())
{
CoverTree<MetricType, RootPointPolicy, StatisticType>* node =
pointQueue.front();
- const size_t parent = parentPoints.front();
const double score = pointScores.front();
- const size_t point = node->Point(); // The point held by this node.
+ const double baseCase = baseCaseResults.front();
pointQueue.pop();
- parentPoints.pop();
pointScores.pop();
+ baseCaseResults.pop();
- // See if this point should still be recursed into.
- if (rule.Rescore(queryIndex, *node, score) == DBL_MAX)
+ // First we (re)calculate the score of this node to find if we can prune it.
+ double actualScore = rule.Rescore(queryIndex, *node, score);
+
+ if (actualScore == DBL_MAX)
{
+ // Prune this node.
++numPrunes;
- continue; // Pruned!
+ continue; // Skip to next in queue.
}
- // Evaluate the base case, but only if this node is not holding the same
- // point as its parent.
- if (parent != point)
- rule.BaseCase(queryIndex, point);
+ // The base case is already evaluated. So now we need to find out how to
+ // recurse into the children.
+ arma::vec baseCases(node->NumChildren());
+ arma::vec childScores(node->NumChildren());
- // Now get the scores for recursion.
- arma::vec scores(node->NumChildren());
+ // We already know the base case for the self child (that's the same as the
+ // base case for us).
+ baseCases[0] = baseCase;
+ if (node->Child(0).NumChildren() == 0)
+ childScores[0] = DBL_MAX; // Do not recurse into leaves (unnecessary).
+ else
+ childScores[0] = rule.Score(queryIndex, node->Child(0), baseCase);
- for (size_t i = 0; i < node->NumChildren(); ++i)
- scores[i] = rule.Score(queryIndex, node->Child(i));
+ // Fill the rest of the children.
+ for (size_t i = 1; i < node->NumChildren(); ++i)
+ {
+ baseCases[i] = rule.BaseCase(queryIndex, node->Child(i).Point());
+ if (node->Child(i).NumChildren() == 0)
+ childScores[i] = DBL_MAX; // Do not recurse into leaves (unnecessary).
+ else
+ childScores[i] = rule.Score(queryIndex, node->Child(i), baseCases[i]);
+ }
- // Now sort by distance (smallest first).
- arma::uvec order = arma::sort_index(scores);
+ // Now sort by score.
+ arma::uvec order = arma::sort_index(childScores);
+ // Now add each to the queue.
for (size_t i = 0; i < order.n_elem; ++i)
{
// Ensure we haven't hit the limit yet.
- const double childScore = scores[order[i]];
+ const double childScore = childScores[order[i]];
if (childScore == DBL_MAX)
{
numPrunes += (order.n_elem - i); // Prune the rest of the children.
@@ -85,8 +102,8 @@
}
pointQueue.push(&node->Child(order[i]));
- parentPoints.push(point);
pointScores.push(childScore);
+ baseCaseResults.push(baseCases[order[i]]);
}
}
}
More information about the mlpack-svn
mailing list