[mlpack-svn] r13340 - mlpack/trunk/src/mlpack/core/tree/binary_space_tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Aug 6 14:18:21 EDT 2012
Author: rcurtin
Date: 2012-08-06 14:18:21 -0400 (Mon, 06 Aug 2012)
New Revision: 13340
Modified:
mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp
Log:
Refactor BinarySpaceTree traversal to use Score() and Rescore() and thereby
avoid unnecessary repeated calculations (so I hope).
Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp 2012-08-06 18:17:55 UTC (rev 13339)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp 2012-08-06 18:18:21 UTC (rev 13340)
@@ -30,13 +30,6 @@
BinarySpaceTree<BoundType, StatisticType, MatType>& queryNode,
BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
{
- // Check if pruning can occur.
- if (rule.CanPrune(queryNode, referenceNode))
- {
- ++numPrunes;
- return;
- }
-
// If both are leaves, we must evaluate the base case.
if (queryNode.IsLeaf() && referenceNode.IsLeaf())
{
@@ -47,56 +40,175 @@
}
else if ((!queryNode.IsLeaf()) && referenceNode.IsLeaf())
{
- // We have to recurse down the query node.
- if (rule.LeftFirst(referenceNode, queryNode))
- {
- Traverse(*queryNode.Left(), referenceNode);
- Traverse(*queryNode.Right(), referenceNode);
- }
- else
- {
- Traverse(*queryNode.Right(), referenceNode);
- Traverse(*queryNode.Left(), referenceNode);
- }
+ // We have to recurse down the query node. In this case the recursion order
+ // does not matter.
+ Traverse(*queryNode.Left(), referenceNode);
+ Traverse(*queryNode.Right(), referenceNode);
}
else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
{
- // We have to recurse down the reference node.
- if (rule.LeftFirst(queryNode, referenceNode))
+ // We have to recurse down the reference node. In this case the recursion
+ // order does matter.
+ double leftScore = rule.Score(queryNode, *referenceNode.Left());
+ double rightScore = rule.Score(queryNode, *referenceNode.Right());
+
+ if (leftScore < rightScore)
{
+ // Recurse to the left.
Traverse(queryNode, *referenceNode.Left());
- Traverse(queryNode, *referenceNode.Right());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(queryNode, *referenceNode.Right());
+ else
+ ++numPrunes;
}
- else
+ else if (rightScore < leftScore)
{
+ // Recurse to the right.
Traverse(queryNode, *referenceNode.Right());
- Traverse(queryNode, *referenceNode.Left());
+
+ // Is it still valid to recurse to the left?
+ leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
+
+ if (leftScore != DBL_MAX)
+ Traverse(queryNode, *referenceNode.Left());
+ else
+ ++numPrunes;
}
+ else // leftScore is equal to rightScore.
+ {
+ if (leftScore == DBL_MAX)
+ {
+ numPrunes += 2;
+ }
+ else
+ {
+ // Choose the left first.
+ Traverse(queryNode, *referenceNode.Left());
+
+ rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(queryNode, *referenceNode.Right());
+ else
+ ++numPrunes;
+ }
+ }
}
else
{
- // We have to recurse down both.
- if (rule.LeftFirst(*queryNode.Left(), referenceNode))
+ // We have to recurse down both query and reference nodes. Because the
+ // query descent order does not matter, we will go to the left query child
+ // first.
+ double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
+ double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
+
+ if (leftScore < rightScore)
{
+ // Recurse to the left.
Traverse(*queryNode.Left(), *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(*queryNode.Left(), *referenceNode.Right());
+ else
+ ++numPrunes;
+ }
+ else if (rightScore < leftScore)
+ {
+ // Recurse to the right.
Traverse(*queryNode.Left(), *referenceNode.Right());
+
+ // Is it still valid to recurse to the left?
+ leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
+ leftScore);
+
+ if (leftScore != DBL_MAX)
+ Traverse(*queryNode.Left(), *referenceNode.Left());
+ else
+ ++numPrunes;
}
else
{
- Traverse(*queryNode.Left(), *referenceNode.Right());
- Traverse(*queryNode.Left(), *referenceNode.Left());
+ if (leftScore == DBL_MAX)
+ {
+ numPrunes += 2;
+ }
+ else
+ {
+ // Choose the left first.
+ Traverse(*queryNode.Left(), *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(*queryNode.Left(), *referenceNode.Right());
+ else
+ ++numPrunes;
+ }
}
- // Now recurse to the right query child.
- if (rule.LeftFirst(*queryNode.Right(), referenceNode))
+ // Now recurse down the right query node.
+ leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
+ rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
+
+ if (leftScore < rightScore)
{
+ // Recurse to the left.
Traverse(*queryNode.Right(), *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(*queryNode.Right(), *referenceNode.Right());
+ else
+ ++numPrunes;
+ }
+ else if (rightScore < leftScore)
+ {
+ // Recurse to the right.
Traverse(*queryNode.Right(), *referenceNode.Right());
+
+ // Is it still valid to recurse to the left?
+ leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
+ leftScore);
+
+ if (leftScore != DBL_MAX)
+ Traverse(*queryNode.Right(), *referenceNode.Left());
+ else
+ ++numPrunes;
}
else
{
- Traverse(*queryNode.Right(), *referenceNode.Right());
- Traverse(*queryNode.Right(), *referenceNode.Left());
+ if (leftScore == DBL_MAX)
+ {
+ numPrunes += 2;
+ }
+ else
+ {
+ // Choose the left first.
+ Traverse(*queryNode.Right(), *referenceNode.Right());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(*queryNode.Right(), *referenceNode.Right());
+ else
+ ++numPrunes;
+ }
}
}
Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp 2012-08-06 18:17:55 UTC (rev 13339)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp 2012-08-06 18:18:21 UTC (rev 13340)
@@ -32,42 +32,63 @@
const size_t queryIndex,
BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
{
- // This is a non-recursive implementation (which should be faster).
+ // If we are a leaf, run the base case as necessary.
+ if (referenceNode.IsLeaf())
+ {
+ for (size_t i = referenceNode.Begin(); i < referenceNode.End(); ++i)
+ rule.BaseCase(queryIndex, i);
+ }
+ else
+ {
+ // If either score is DBL_MAX, we do not recurse into that node.
+ double leftScore = rule.Score(queryIndex, *referenceNode.Left());
+ double rightScore = rule.Score(queryIndex, *referenceNode.Right());
- // The stack of points to look at.
- std::stack<BinarySpaceTree<BoundType, StatisticType, MatType>*> pointStack;
- pointStack.push(&referenceNode);
+ if (leftScore < rightScore)
+ {
+ // Recurse to the left.
+ Traverse(queryIndex, *referenceNode.Left());
- while (!pointStack.empty())
- {
- BinarySpaceTree<BoundType, StatisticType, MatType>* node = pointStack.top();
- pointStack.pop();
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(queryIndex, *referenceNode.Right(), rightScore);
- // Check if we can prune this node.
- if (rule.CanPrune(queryIndex, *node))
+ if (rightScore != DBL_MAX)
+ Traverse(queryIndex, *referenceNode.Right()); // Recurse to the right.
+ else
+ ++numPrunes;
+ }
+ else if (rightScore < leftScore)
{
- ++numPrunes;
- continue;
- }
+ // Recurse to the right.
+ Traverse(queryIndex, *referenceNode.Right());
- // If we are a leaf, run the base case as necessary.
- if (node->IsLeaf())
- {
- for (size_t i = node->Begin(); i < node->End(); ++i)
- rule.BaseCase(queryIndex, i);
+ // Is it still valid to recurse to the left?
+ leftScore = rule.Rescore(queryIndex, *referenceNode.Left(), leftScore);
+
+ if (leftScore != DBL_MAX)
+ Traverse(queryIndex, *referenceNode.Left()); // Recurse to the left.
+ else
+ ++numPrunes;
}
- else
+ else // leftScore is equal to rightScore.
{
- // Otherwise recurse by distance.
- if (rule.LeftFirst(queryIndex, *node))
+ if (leftScore == DBL_MAX)
{
- pointStack.push(node->Right());
- pointStack.push(node->Left());
+ numPrunes += 2; // Pruned both left and right.
}
else
{
- pointStack.push(node->Left());
- pointStack.push(node->Right());
+ // Choose the left first.
+ Traverse(queryIndex, *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(queryIndex, *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(queryIndex, *referenceNode.Right());
+ else
+ ++numPrunes;
}
}
}
More information about the mlpack-svn
mailing list