[mlpack-svn] r13364 - mlpack/trunk/src/mlpack/core/tree/cover_tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Aug 7 18:40:38 EDT 2012
Author: rcurtin
Date: 2012-08-07 18:40:37 -0400 (Tue, 07 Aug 2012)
New Revision: 13364
Modified:
mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
Log:
Revamp single tree traverser... again. For now it has debugging output.
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp 2012-08-07 22:37:20 UTC (rev 13363)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp 2012-08-07 22:40:37 UTC (rev 13364)
@@ -16,7 +16,30 @@
namespace mlpack {
namespace tree {
+//! One pending node of the single-tree traversal; used as the element type of the traverser's priority queue and leaf queue.
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+struct CoverTreeQueueEntry
+{
+ //! The node this entry refers to (non-owning pointer into the tree).
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* node;
+ //! The score the node was queued with; it is passed to Rescore() when the entry is popped.
+ double score;
+ //! The point index of the parent node ((size_t() - 1) for the root entry, which has no parent).
+ size_t parent;
+ //! The parent's base case evaluation, reused for self-children (-1.0 if it has not been performed).
+ double baseCase;
+
+ //! Orders entries by node scale first, then by score when scales are equal; std::priority_queue pops the greatest entry first.
+ bool operator<(const CoverTreeQueueEntry& other) const
+ {
+ return ((node->Scale() < other.node->Scale()) ||
+ ((node->Scale() == other.node->Scale()) && (score < other.score)));
+ }
+};
+
+
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
template<typename RuleType>
CoverTree<MetricType, RootPointPolicy, StatisticType>::
SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
@@ -33,78 +56,141 @@
{
// This is a non-recursive implementation (which should be faster than a
// recursive implementation).
- std::queue<CoverTree<MetricType, RootPointPolicy, StatisticType>*> pointQueue;
- std::queue<double> pointScores;
- std::queue<double> baseCaseResults;
+ typedef CoverTreeQueueEntry<MetricType, RootPointPolicy, StatisticType>
+ QueueType;
+ std::priority_queue<QueueType> pointQueue;
- pointQueue.push(&referenceNode);
- pointScores.push(0.0); // Cannot be pruned.
+ // Unsorted list of leaves we have to look through.
+ std::queue<QueueType> leafQueue;
- // Evaluate first base case.
- baseCaseResults.push(rule.BaseCase(queryIndex, referenceNode.Point()));
+ QueueType first;
+ first.node = &referenceNode;
+ first.score = 0.0;
+ first.parent = (size_t() - 1); // Invalid index.
+ first.baseCase = -1.0;
+ pointQueue.push(first);
+ Log::Warn << "Beginning recursion for " << queryIndex << std::endl;
+
while (!pointQueue.empty())
{
- CoverTree<MetricType, RootPointPolicy, StatisticType>* node =
- pointQueue.front();
- const double score = pointScores.front();
- const double baseCase = baseCaseResults.front();
+ QueueType frame = pointQueue.top();
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
+ const double score = frame.score;
+ const size_t parent = frame.parent;
+ const size_t point = node->Point();
+ double baseCase = frame.baseCase;
+
+ Log::Debug << "Current point is " << node->Point() << " and scale "
+ << node->Scale() << ".\n";
+
pointQueue.pop();
- pointScores.pop();
- baseCaseResults.pop();
// First we (re)calculate the score of this node to find if we can prune it.
+ Log::Debug << "Before rescoring, score is " << score << " and base case of "
+ << "parent is " << baseCase << std::endl;
double actualScore = rule.Rescore(queryIndex, *node, score);
+// Log::Debug << "Actual score is " << actualScore << ".\n";
+
if (actualScore == DBL_MAX)
{
// Prune this node.
+ Log::Debug << "Pruning after re-scoring (original score " << score << ")."
+ << std::endl;
++numPrunes;
continue; // Skip to next in queue.
}
- // The base case is already evaluated. So now we need to find out how to
- // recurse into the children.
- arma::vec baseCases(node->NumChildren());
- arma::vec childScores(node->NumChildren());
-
- // We already know the base case for the self child (that's the same as the
- // base case for us).
- baseCases[0] = baseCase;
- if (node->Child(0).NumChildren() == 0)
- childScores[0] = DBL_MAX; // Do not recurse into leaves (unnecessary).
+ // If we are a self-child, the base case has already been evaluated.
+ if (point != parent)
+ {
+ baseCase = rule.BaseCase(queryIndex, point);
+ Log::Debug << "Base case between " << queryIndex << " and " << point <<
+ " evaluates to " << baseCase << ".\n";
+ }
else
- childScores[0] = rule.Score(queryIndex, node->Child(0), baseCase);
-
- // Fill the rest of the children.
- for (size_t i = 1; i < node->NumChildren(); ++i)
{
- baseCases[i] = rule.BaseCase(queryIndex, node->Child(i).Point());
- if (node->Child(i).NumChildren() == 0)
- childScores[i] = DBL_MAX; // Do not recurse into leaves (unnecessary).
- else
- childScores[i] = rule.Score(queryIndex, node->Child(i), baseCases[i]);
+ Log::Debug << "Base case between " << queryIndex << " and " << point <<
+ " already known to be " << baseCase << ".\n";
}
- // Now sort by score.
- arma::uvec order = arma::sort_index(childScores);
+ // Create the score for the children.
+ double childScore = rule.Score(queryIndex, *node, baseCase);
- // Now add each to the queue.
- for (size_t i = 0; i < order.n_elem; ++i)
+ // Now if the childScore is DBL_MAX we can prune all children. In this
+ // recursion setup pruning is all or nothing for children.
+ if (childScore == DBL_MAX)
{
- // Ensure we haven't hit the limit yet.
- const double childScore = childScores[order[i]];
- if (childScore == DBL_MAX)
+ Log::Debug << "Pruning all children.\n";
+ numPrunes += node->NumChildren();
+ }
+ else
+ {
+ for (size_t i = 0; i < node->NumChildren(); ++i)
{
- numPrunes += (order.n_elem - i); // Prune the rest of the children.
- break; // Go on to next point.
+ QueueType newFrame;
+ newFrame.node = &node->Child(i);
+ newFrame.score = childScore;
+ newFrame.baseCase = baseCase;
+ newFrame.parent = point;
+
+ // Put it into the regular priority queue if it has children.
+ if (newFrame.node->NumChildren() > 0)
+ {
+ Log::Debug << "Push back child " << i << ": point " <<
+ newFrame.node->Point() << ", scale " << newFrame.node->Scale()
+ << ".\n";
+ pointQueue.push(newFrame);
+ }
+ else if ((newFrame.node->NumChildren() == 0) && (i > 0))
+ {
+ // We don't add the self-leaf to the leaf queue (it can't possibly
+ // help).
+ Log::Debug << "Push back child " << i << ": point " <<
+ newFrame.node->Point() << ", scale " << newFrame.node->Scale()
+ << ".\n";
+ leafQueue.push(newFrame);
+ }
+ else
+ {
+ Log::Debug << "Prune self-leaf point " << point << ".\n";
+ ++numPrunes;
+ }
}
+ }
+ }
- pointQueue.push(&node->Child(order[i]));
- pointScores.push(childScore);
- baseCaseResults.push(baseCases[order[i]]);
+ // Now look through all the leaves.
+ while (!leafQueue.empty())
+ {
+ QueueType frame = leafQueue.front();
+
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
+ const double score = frame.score;
+ const size_t point = node->Point();
+
+ Log::Debug << "Inspecting leaf " << point << " with score " << score <<
+ "\n";
+
+ leafQueue.pop();
+
+ // First, recalculate the score of this node to find if we can prune it.
+ double actualScore = rule.Rescore(queryIndex, *node, score);
+
+ if (actualScore == DBL_MAX)
+ {
+ Log::Debug << "Pruned before base case.\n";
+ ++numPrunes;
+ continue;
}
+
+ // There are no self-leaves in this queue, so the only thing left to do is
+ // evaluate the base case.
+ const double baseCase = rule.BaseCase(queryIndex, point);
+ Log::Debug << "Base case between " << queryIndex << " and " << point <<
+ " evaluates to " << baseCase << ".\n";
}
}
More information about the mlpack-svn
mailing list