[mlpack-svn] r13367 - mlpack/trunk/src/mlpack/core/tree/cover_tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Aug 7 22:02:47 EDT 2012
Author: rcurtin
Date: 2012-08-07 22:02:47 -0400 (Tue, 07 Aug 2012)
New Revision: 13367
Modified:
mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
Log:
Refactor single tree traverser again; don't use priority_queue, which is slow.
Instead, exploit some nice properties which allow us to use a map of vectors. It
seems as though this implementation is about twice as fast.
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp 2012-08-08 01:55:23 UTC (rev 13366)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp 2012-08-08 02:02:47 UTC (rev 13367)
@@ -32,8 +32,7 @@
//! Comparison operator.
bool operator<(const CoverTreeQueueEntry& other) const
{
- return ((node->Scale() < other.node->Scale()) ||
- ((node->Scale() == other.node->Scale()) && (score < other.score)));
+ return (score < other.score);
}
};
@@ -58,139 +57,140 @@
// recursive implementation).
typedef CoverTreeQueueEntry<MetricType, RootPointPolicy, StatisticType>
QueueType;
- std::priority_queue<QueueType> pointQueue;
- // Unsorted list of leaves we have to look through.
- std::queue<QueueType> leafQueue;
+ // We will use this map as a priority queue. Each key represents the scale,
+ // and then the vector is all the nodes in that scale which need to be
+ // investigated. Because no point in a scale can add a point in its own
+ // scale, we know that the vector for each scale is final when we get to it.
+ // In addition, map is organized in such a way that rbegin() will return the
+ // largest scale.
+ std::map<int, std::vector<QueueType> > mapQueue;
- QueueType first;
- first.node = &referenceNode;
- first.score = 0.0;
- first.parent = (size_t() - 1); // Invalid index.
- first.baseCase = -1.0;
- pointQueue.push(first);
+ // Manually add the children of the first node. These cannot be pruned
+ // anyway.
+ double rootBaseCase = rule.BaseCase(queryIndex, referenceNode.Point());
+
+ // Create the score for the children.
+ double rootChildScore = rule.Score(queryIndex, referenceNode, rootBaseCase);
- Log::Warn << "Beginning recursion for " << queryIndex << std::endl;
-
- while (!pointQueue.empty())
+ if (rootChildScore == DBL_MAX)
{
- QueueType frame = pointQueue.top();
+ numPrunes += referenceNode.NumChildren();
+ }
+ else
+ {
+ // Don't add the self-leaf.
+ size_t i = 0;
+ if (referenceNode.Child(0).NumChildren() == 0)
+ {
+ ++numPrunes;
+ i = 1;
+ }
- CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
- const double score = frame.score;
- const size_t parent = frame.parent;
- const size_t point = node->Point();
- double baseCase = frame.baseCase;
+ for (/* i was set above. */; i < referenceNode.NumChildren(); ++i)
+ {
+ QueueType newFrame;
+ newFrame.node = &referenceNode.Child(i);
+ newFrame.score = rootChildScore;
+ newFrame.baseCase = rootBaseCase;
+ newFrame.parent = referenceNode.Point();
- Log::Debug << "Current point is " << node->Point() << " and scale "
- << node->Scale() << ".\n";
+ // Put it into the map.
+ mapQueue[newFrame.node->Scale()].push_back(newFrame);
+ }
+ }
- pointQueue.pop();
+ // Now begin the iteration through the map.
+ typename std::map<int, std::vector<QueueType> >::reverse_iterator rit =
+ mapQueue.rbegin();
- // First we (re)calculate the score of this node to find if we can prune it.
- Log::Debug << "Before rescoring, score is " << score << " and base case of "
- << "parent is " << baseCase << std::endl;
- double actualScore = rule.Rescore(queryIndex, *node, score);
+ // We will treat the leaves differently (below).
+ while ((*rit).first != INT_MIN)
+ {
+ // Get a reference to the current scale.
+ std::vector<QueueType>& scaleVector = (*rit).second;
-// Log::Debug << "Actual score is " << actualScore << ".\n";
+ // Before beginning all the points in this scale, sort by score.
+ std::sort(scaleVector.begin(), scaleVector.end());
- if (actualScore == DBL_MAX)
+ // Now loop over each element.
+ for (size_t i = 0; i < scaleVector.size(); ++i)
{
- // Prune this node.
- Log::Debug << "Pruning after re-scoring (original score " << score << ")."
- << std::endl;
- ++numPrunes;
- continue; // Skip to next in queue.
- }
+ // Get a reference to the current element.
+ const QueueType& frame = scaleVector.at(i);
- // If we are a self-child, the base case has already been evaluated.
- if (point != parent)
- {
- baseCase = rule.BaseCase(queryIndex, point);
- Log::Debug << "Base case between " << queryIndex << " and " << point <<
- " evaluates to " << baseCase << ".\n";
- }
- else
- {
- Log::Debug << "Base case between " << queryIndex << " and " << point <<
- " already known to be " << baseCase << ".\n";
- }
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
+ const double score = frame.score;
+ const size_t parent = frame.parent;
+ const size_t point = node->Point();
+ double baseCase = frame.baseCase;
- // Create the score for the children.
- double childScore = rule.Score(queryIndex, *node, baseCase);
+ // First we recalculate the score of this node to find if we can prune it.
+ if (rule.Rescore(queryIndex, *node, score) == DBL_MAX)
+ {
+ ++numPrunes;
+ continue;
+ }
- // Now if the childScore is DBL_MAX we can prune all children. In this
- // recursion setup pruning is all or nothing for children.
- if (childScore == DBL_MAX)
- {
- Log::Debug << "Pruning all children.\n";
- numPrunes += node->NumChildren();
- }
- else
- {
- for (size_t i = 0; i < node->NumChildren(); ++i)
+ // If we are a self-child, the base case has already been evaluated.
+ if (point != parent)
+ baseCase = rule.BaseCase(queryIndex, point);
+
+ // Create the score for the children.
+ const double childScore = rule.Score(queryIndex, *node, baseCase);
+
+ // Now if this childScore is DBL_MAX we can prune all children. In this
+ // recursion setup pruning is all or nothing for children.
+ if (childScore == DBL_MAX)
{
+ numPrunes += node->NumChildren();
+ continue;
+ }
+
+ // Don't add the self-leaf.
+ size_t j = 0;
+ if (node->Child(0).NumChildren() == 0)
+ {
+ ++numPrunes;
+ j = 1;
+ }
+
+ for (/* j is already set. */; j < node->NumChildren(); ++j)
+ {
QueueType newFrame;
- newFrame.node = &node->Child(i);
+ newFrame.node = &node->Child(j);
newFrame.score = childScore;
newFrame.baseCase = baseCase;
newFrame.parent = point;
- // Put it into the regular priority queue if it has children.
- if (newFrame.node->NumChildren() > 0)
- {
- Log::Debug << "Push back child " << i << ": point " <<
- newFrame.node->Point() << ", scale " << newFrame.node->Scale()
- << ".\n";
- pointQueue.push(newFrame);
- }
- else if ((newFrame.node->NumChildren() == 0) && (i > 0))
- {
- // We don't add the self-leaf to the leaf queue (it can't possibly
- // help).
- Log::Debug << "Push back child " << i << ": point " <<
- newFrame.node->Point() << ", scale " << newFrame.node->Scale()
- << ".\n";
- leafQueue.push(newFrame);
- }
- else
- {
- Log::Debug << "Prune self-leaf point " << point << ".\n";
- ++numPrunes;
- }
+ mapQueue[newFrame.node->Scale()].push_back(newFrame);
}
}
+
+ // Now clear the memory for this scale; it isn't needed anymore.
+ mapQueue.erase((*rit).first);
}
- // Now look through all the leaves.
- while (!leafQueue.empty())
+ // Now deal with the leaves.
+ for (size_t i = 0; i < mapQueue[INT_MIN].size(); ++i)
{
- QueueType frame = leafQueue.front();
+ const QueueType& frame = mapQueue[INT_MIN].at(i);
CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
const double score = frame.score;
const size_t point = node->Point();
- Log::Debug << "Inspecting leaf " << point << " with score " << score <<
- "\n";
-
- leafQueue.pop();
-
// First, recalculate the score of this node to find if we can prune it.
double actualScore = rule.Rescore(queryIndex, *node, score);
if (actualScore == DBL_MAX)
{
- Log::Debug << "Pruned before base case.\n";
++numPrunes;
continue;
}
- // There are no self-leaves in this queue, so the only thing left to do is
- // evaluate the base case.
+ // There are no self-leaves; evaluate the base case.
const double baseCase = rule.BaseCase(queryIndex, point);
- Log::Debug << "Base case between " << queryIndex << " and " << point <<
- " evaluates to " << baseCase << ".\n";
}
}
More information about the mlpack-svn
mailing list