[mlpack-git] master: Make the search *actually* breadth-first. Significant speedups result! (4003537)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:46 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 40035373719652e9e01f84c5d172f9b89c201c2f
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Feb 26 18:49:39 2015 -0500
Make the search *actually* breadth-first. Significant speedups result!
>---------------------------------------------------------------
40035373719652e9e01f84c5d172f9b89c201c2f
.../breadth_first_dual_tree_traverser_impl.hpp | 432 ++++-----------------
1 file changed, 86 insertions(+), 346 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
index cc0fa13..5a32bfc 100644
--- a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
@@ -32,6 +32,27 @@ BreadthFirstDualTreeTraverser<RuleType>::BreadthFirstDualTreeTraverser(
numBaseCases(0)
{ /* Nothing to do. */ }
+template<typename TreeType, typename TraversalInfoType>
+struct QueueFrame
+{
+ TreeType* queryNode;
+ TreeType* referenceNode;
+ size_t queryDepth;
+ double score;
+ TraversalInfoType traversalInfo;
+};
+
+template<typename TreeType, typename TraversalInfoType>
+bool operator<(const QueueFrame<TreeType, TraversalInfoType>& a,
+ const QueueFrame<TreeType, TraversalInfoType>& b)
+{
+ if (a.queryDepth > b.queryDepth)
+ return true;
+ else if ((a.queryDepth == b.queryDepth) && (a.score > b.score))
+ return true;
+ return false;
+}
+
template<typename BoundType,
typename StatisticType,
typename MatType,
@@ -57,24 +78,38 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
if (rootScore == DBL_MAX)
return; // This probably means something is wrong.
- std::queue<TreeType*> queryList;
- std::queue<TreeType*> referenceList;
- std::queue<typename RuleType::TraversalInfoType> traversalInfos;
- queryList.push(&queryRoot);
- referenceList.push(&referenceRoot);
- traversalInfos.push(rule.TraversalInfo());
+ typedef QueueFrame<TreeType, typename RuleType::TraversalInfoType>
+ QueueFrameType;
+ std::priority_queue<QueueFrameType> queue;
- while (!queryList.empty())
- {
- TreeType& queryNode = *queryList.front();
- TreeType& referenceNode = *referenceList.front();
- typename RuleType::TraversalInfoType ti = traversalInfos.front();
+ QueueFrameType rootFrame;
+ rootFrame.queryNode = &queryRoot;
+ rootFrame.referenceNode = &referenceRoot;
+ rootFrame.queryDepth = 0;
+ rootFrame.score = 0.0;
+ rootFrame.traversalInfo = rule.TraversalInfo();
- queryList.pop();
- referenceList.pop();
- traversalInfos.pop();
+ queue.push(rootFrame);
+ while (!queue.empty())
+ {
+ QueueFrameType currentFrame = queue.top();
+ queue.pop();
+
+ TreeType& queryNode = *currentFrame.queryNode;
+ TreeType& referenceNode = *currentFrame.referenceNode;
+ typename RuleType::TraversalInfoType ti = currentFrame.traversalInfo;
rule.TraversalInfo() = ti;
+ const size_t queryDepth = currentFrame.queryDepth;
+
+ double score = rule.Score(queryNode, referenceNode);
+ ++numScores;
+
+ if (score == DBL_MAX)
+ {
+ ++numPrunes;
+ continue;
+ }
// If both are leaves, we must evaluate the base case.
if (queryNode.IsLeaf() && referenceNode.IsLeaf())
@@ -83,14 +118,15 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
for (size_t query = queryNode.Begin(); query < queryNode.End(); ++query)
{
// See if we need to investigate this point (this function should be
- // implemented for the single-tree recursion too). Restore the traversal
- // information first.
+ // implemented for the single-tree recursion too). Restore the
+ // traversal information first.
// const double childScore = rule.Score(query, referenceNode);
// if (childScore == DBL_MAX)
// continue; // We can't improve this particular point.
- for (size_t ref = referenceNode.Begin(); ref < referenceNode.End(); ++ref)
+ for (size_t ref = referenceNode.Begin(); ref < referenceNode.End();
+ ++ref)
rule.BaseCase(query, ref);
numBaseCases += referenceNode.Count();
@@ -98,345 +134,49 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
}
else if ((!queryNode.IsLeaf()) && referenceNode.IsLeaf())
{
- // We have to recurse down the query node. In this case the recursion order
- // does not matter.
- const double leftScore = rule.Score(*queryNode.Left(), referenceNode);
- ++numScores;
-
- if (leftScore != DBL_MAX)
- {
- queryList.push(queryNode.Left());
- referenceList.push(&referenceNode);
- traversalInfos.push(rule.TraversalInfo());
-// Log::Debug << "Push1 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
- }
- else
- {
- ++numPrunes;
- }
-
- // Before recursing, we have to set the traversal information correctly.
- rule.TraversalInfo() = ti;
- const double rightScore = rule.Score(*queryNode.Right(), referenceNode);
- ++numScores;
-
- if (rightScore != DBL_MAX)
- {
- queryList.push(queryNode.Right());
- referenceList.push(&referenceNode);
- traversalInfos.push(rule.TraversalInfo());
-// Log::Debug << "Push2 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
- }
- else
- ++numPrunes;
+ // We have to recurse down the query node.
+ QueueFrameType fl = { queryNode.Left(), &referenceNode, queryDepth + 1,
+ score, rule.TraversalInfo() };
+ queue.push(fl);
+
+ QueueFrameType fr = { queryNode.Right(), &referenceNode, queryDepth + 1,
+ score, ti };
+ queue.push(fr);
}
else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
{
// We have to recurse down the reference node. In this case the recursion
// order does matter. Before recursing, though, we have to set the
// traversal information correctly.
- double leftScore = rule.Score(queryNode, *referenceNode.Left());
- typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
- rule.TraversalInfo() = ti;
- double rightScore = rule.Score(queryNode, *referenceNode.Right());
- numScores += 2;
-
- if (leftScore < rightScore)
- {
- // Recurse to the left. Restore the left traversal info. Store the right
- // traversal info.
- queryList.push(&queryNode);
- referenceList.push(referenceNode.Left());
- traversalInfos.push(leftInfo);
-// Log::Debug << "Push3 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
-
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
-
- if (rightScore != DBL_MAX)
- {
- // Restore the right traversal info.
- queryList.push(&queryNode);
- referenceList.push(referenceNode.Right());
- traversalInfos.push(rule.TraversalInfo());
-// Log::Debug << "Push4 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
- }
- else
- ++numPrunes;
- }
- else if (rightScore < leftScore)
- {
- // Recurse to the right.
- queryList.push(&queryNode);
- referenceList.push(referenceNode.Right());
- traversalInfos.push(rule.TraversalInfo());
-// Log::Debug << "Push5 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
-
- // Is it still valid to recurse to the left?
- leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
-
- if (leftScore != DBL_MAX)
- {
- // Restore the left traversal info.
- queryList.push(&queryNode);
- referenceList.push(referenceNode.Left());
- traversalInfos.push(leftInfo);
-// Log::Debug << "Push6 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
- }
- else
- ++numPrunes;
- }
- else // leftScore is equal to rightScore.
- {
- if (leftScore == DBL_MAX)
- {
- numPrunes += 2;
- }
- else
- {
- // Choose the left first. Restore the left traversal info. Store the
- // right traversal info.
- queryList.push(&queryNode);
- referenceList.push(referenceNode.Left());
- traversalInfos.push(leftInfo);
-// Log::Debug << "Push7 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
-
- rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
- rightScore);
-
- if (rightScore != DBL_MAX)
- {
- // Restore the right traversal info.
- queryList.push(&queryNode);
- referenceList.push(referenceNode.Right());
- traversalInfos.push(rule.TraversalInfo());
-// Log::Debug << "Push8 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
- }
- else
- ++numPrunes;
- }
- }
- }
- else
- {
- // We have to recurse down both query and reference nodes. Because the
- // query descent order does not matter, we will go to the left query child
- // first. Before recursing, we have to set the traversal information
- // correctly.
- double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
- typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
- rule.TraversalInfo() = ti;
- double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
- typename RuleType::TraversalInfoType rightInfo;
- numScores += 2;
-
- if (leftScore < rightScore)
- {
- // Recurse to the left. Restore the left traversal info. Store the right
- // traversal info.
- queryList.push(queryNode.Left());
- referenceList.push(referenceNode.Left());
- traversalInfos.push(leftInfo);
-// Log::Debug << "Push9 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
-
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
- rightScore);
+ QueueFrameType fl = { &queryNode, referenceNode.Left(), queryDepth,
+ score, rule.TraversalInfo() };
+ queue.push(fl);
- if (rightScore != DBL_MAX)
- {
- // Restore the right traversal info.
- queryList.push(queryNode.Left());
- referenceList.push(referenceNode.Right());
- traversalInfos.push(rule.TraversalInfo());
-// Log::Debug << "Push10 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
- }
- else
- ++numPrunes;
- }
- else if (rightScore < leftScore)
- {
- // Recurse to the right.
- queryList.push(queryNode.Left());
- referenceList.push(referenceNode.Right());
- traversalInfos.push(rule.TraversalInfo());
-// Log::Debug << "Push11 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
-
- // Is it still valid to recurse to the left?
- leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
- leftScore);
-
- if (leftScore != DBL_MAX)
- {
- // Restore the left traversal info.
- queryList.push(queryNode.Left());
- referenceList.push(referenceNode.Left());
- traversalInfos.push(leftInfo);
-// Log::Debug << "Push12 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
- }
- else
- ++numPrunes;
+ QueueFrameType fr = { &queryNode, referenceNode.Right(), queryDepth,
+ score, ti };
+ queue.push(fr);
}
else
{
- if (leftScore == DBL_MAX)
- {
- numPrunes += 2;
- }
- else
- {
- // Choose the left first. Restore the left traversal info and store the
- // right traversal info.
- queryList.push(queryNode.Left());
- referenceList.push(referenceNode.Left());
- traversalInfos.push(leftInfo);
-// Log::Debug << "Push13 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
-
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
- rightScore);
-
- if (rightScore != DBL_MAX)
- {
- // Restore the right traversal information.
- queryList.push(queryNode.Left());
- referenceList.push(referenceNode.Right());
- traversalInfos.push(rule.TraversalInfo());
-// Log::Debug << "Push14 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
- }
- else
- ++numPrunes;
- }
- }
-
- // Restore the main traversal information.
- rule.TraversalInfo() = ti;
-
- // Now recurse down the right query node.
- leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
- leftInfo = rule.TraversalInfo();
- rule.TraversalInfo() = ti;
- rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
- numScores += 2;
-
- if (leftScore < rightScore)
- {
- // Recurse to the left. Restore the left traversal info. Store the right
- // traversal info.
- queryList.push(queryNode.Right());
- referenceList.push(referenceNode.Left());
- traversalInfos.push(leftInfo);
-// Log::Debug << "Push15 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
-
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
- rightScore);
-
- if (rightScore != DBL_MAX)
- {
- // Restore the right traversal info.
- queryList.push(queryNode.Right());
- referenceList.push(referenceNode.Right());
- traversalInfos.push(rule.TraversalInfo());
-// Log::Debug << "Push16 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
- }
- else
- ++numPrunes;
- }
- else if (rightScore < leftScore)
- {
- // Recurse to the right.
- queryList.push(queryNode.Right());
- referenceList.push(referenceNode.Right());
- traversalInfos.push(rule.TraversalInfo());
-// Log::Debug << "Push17 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
-
- // Is it still valid to recurse to the left?
- leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
- leftScore);
-
- if (leftScore != DBL_MAX)
- {
- // Restore the left traversal info.
- queryList.push(queryNode.Right());
- referenceList.push(referenceNode.Left());
- traversalInfos.push(leftInfo);
-// Log::Debug << "Push18 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
- }
- else
- ++numPrunes;
- }
- else
- {
- if (leftScore == DBL_MAX)
- {
- numPrunes += 2;
- }
- else
- {
- // Choose the left first. Restore the left traversal info. Store the
- // right traversal info.
- queryList.push(queryNode.Right());
- referenceList.push(referenceNode.Left());
- traversalInfos.push(leftInfo);
-// Log::Debug << "Push19 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
-
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
- rightScore);
-
- if (rightScore != DBL_MAX)
- {
- // Restore the right traversal info.
- queryList.push(queryNode.Right());
- referenceList.push(referenceNode.Right());
- traversalInfos.push(rule.TraversalInfo());
-// Log::Debug << "Push20 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-// << referenceList.back()->Count() << "\n";
- }
- else
- ++numPrunes;
- }
- }
+ // We have to recurse down both query and reference nodes. Because the
+ // query descent order does not matter, we will go to the left query child
+ // first. Before recursing, we have to set the traversal information
+ // correctly.
+ QueueFrameType fll = { queryNode.Left(), referenceNode.Left(),
+ queryDepth + 1, score, rule.TraversalInfo() };
+ queue.push(fll);
+
+ QueueFrameType flr = { queryNode.Left(), referenceNode.Right(),
+ queryDepth + 1, score, rule.TraversalInfo() };
+ queue.push(flr);
+
+ QueueFrameType frl = { queryNode.Right(), referenceNode.Left(),
+ queryDepth + 1, score, rule.TraversalInfo() };
+ queue.push(frl);
+
+ QueueFrameType frr = { queryNode.Right(), referenceNode.Right(),
+ queryDepth + 1, score, rule.TraversalInfo() };
+ queue.push(frr);
}
}
}
More information about the mlpack-git mailing list