[mlpack-git] master: Make the search *actually* breadth-first. Significant speedups result! (4003537)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:46 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 40035373719652e9e01f84c5d172f9b89c201c2f
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Feb 26 18:49:39 2015 -0500

    Make the search *actually* breadth-first. Significant speedups result!


>---------------------------------------------------------------

40035373719652e9e01f84c5d172f9b89c201c2f
 .../breadth_first_dual_tree_traverser_impl.hpp     | 432 ++++-----------------
 1 file changed, 86 insertions(+), 346 deletions(-)

diff --git a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
index cc0fa13..5a32bfc 100644
--- a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
@@ -32,6 +32,27 @@ BreadthFirstDualTreeTraverser<RuleType>::BreadthFirstDualTreeTraverser(
     numBaseCases(0)
 { /* Nothing to do. */ }
 
+template<typename TreeType, typename TraversalInfoType>
+struct QueueFrame // One pending (query, reference) pair in the traversal queue.
+{
+  TreeType* queryNode;     //!< Query node of this pair.
+  TreeType* referenceNode; //!< Reference node of this pair.
+  size_t queryDepth;       //!< Query-tree depth (root frame = 0); primary key.
+  double score;            //!< Score of the parent pair (0.0 for the root).
+  TraversalInfoType traversalInfo; //!< Restored into the rule before scoring.
+};
+
+template<typename TreeType, typename TraversalInfoType>
+bool operator<(const QueueFrame<TreeType, TraversalInfoType>& a,
+               const QueueFrame<TreeType, TraversalInfoType>& b)
+{ // priority_queue pops the greatest frame; "a < b" means b is popped first.
+  if (a.queryDepth > b.queryDepth)
+    return true;  // Deeper frame a waits: shallower query depth pops first.
+  else if ((a.queryDepth == b.queryDepth) && (a.score > b.score))
+    return true;  // Same depth: the frame with the lower score pops first.
+  return false;   // Otherwise a does not rank below b.
+}
+
 template<typename BoundType,
          typename StatisticType,
          typename MatType,
@@ -57,24 +78,38 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
   if (rootScore == DBL_MAX)
     return; // This probably means something is wrong.
 
-  std::queue<TreeType*> queryList;
-  std::queue<TreeType*> referenceList;
-  std::queue<typename RuleType::TraversalInfoType> traversalInfos;
-  queryList.push(&queryRoot);
-  referenceList.push(&referenceRoot);
-  traversalInfos.push(rule.TraversalInfo());
+  typedef QueueFrame<TreeType, typename RuleType::TraversalInfoType>
+      QueueFrameType;
+  std::priority_queue<QueueFrameType> queue;
 
-  while (!queryList.empty())
-  {
-    TreeType& queryNode = *queryList.front();
-    TreeType& referenceNode = *referenceList.front();
-    typename RuleType::TraversalInfoType ti = traversalInfos.front();
+  QueueFrameType rootFrame;
+  rootFrame.queryNode = &queryRoot;
+  rootFrame.referenceNode = &referenceRoot;
+  rootFrame.queryDepth = 0;
+  rootFrame.score = 0.0;
+  rootFrame.traversalInfo = rule.TraversalInfo();
 
-    queryList.pop();
-    referenceList.pop();
-    traversalInfos.pop();
+  queue.push(rootFrame);
 
+  while (!queue.empty())
+  {
+    QueueFrameType currentFrame = queue.top();
+    queue.pop();
+
+    TreeType& queryNode = *currentFrame.queryNode;
+    TreeType& referenceNode = *currentFrame.referenceNode;
+    typename RuleType::TraversalInfoType ti = currentFrame.traversalInfo;
     rule.TraversalInfo() = ti;
+    const size_t queryDepth = currentFrame.queryDepth;
+
+    double score = rule.Score(queryNode, referenceNode);
+    ++numScores;
+
+    if (score == DBL_MAX)
+    {
+      ++numPrunes;
+      continue;
+    }
 
     // If both are leaves, we must evaluate the base case.
     if (queryNode.IsLeaf() && referenceNode.IsLeaf())
@@ -83,14 +118,15 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
       for (size_t query = queryNode.Begin(); query < queryNode.End(); ++query)
       {
         // See if we need to investigate this point (this function should be
-        // implemented for the single-tree recursion too).  Restore the traversal
-        // information first.
+        // implemented for the single-tree recursion too).  Restore the
+        // traversal information first.
 //        const double childScore = rule.Score(query, referenceNode);
 
 //        if (childScore == DBL_MAX)
 //          continue; // We can't improve this particular point.
 
-        for (size_t ref = referenceNode.Begin(); ref < referenceNode.End(); ++ref)
+        for (size_t ref = referenceNode.Begin(); ref < referenceNode.End();
+            ++ref)
           rule.BaseCase(query, ref);
 
         numBaseCases += referenceNode.Count();
@@ -98,345 +134,49 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
     }
     else if ((!queryNode.IsLeaf()) && referenceNode.IsLeaf())
     {
-      // We have to recurse down the query node.  In this case the recursion order
-      // does not matter.
-      const double leftScore = rule.Score(*queryNode.Left(), referenceNode);
-      ++numScores;
-
-      if (leftScore != DBL_MAX)
-      {
-        queryList.push(queryNode.Left());
-        referenceList.push(&referenceNode);
-        traversalInfos.push(rule.TraversalInfo());
-//        Log::Debug << "Push1 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-      }
-      else
-      {
-        ++numPrunes;
-      }
-
-      // Before recursing, we have to set the traversal information correctly.
-      rule.TraversalInfo() = ti;
-      const double rightScore = rule.Score(*queryNode.Right(), referenceNode);
-      ++numScores;
-
-      if (rightScore != DBL_MAX)
-      {
-        queryList.push(queryNode.Right());
-        referenceList.push(&referenceNode);
-        traversalInfos.push(rule.TraversalInfo());
-//        Log::Debug << "Push2 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-      }
-      else
-        ++numPrunes;
+      // We have to recurse down the query node.
+      QueueFrameType fl = { queryNode.Left(), &referenceNode, queryDepth + 1,
+          score, rule.TraversalInfo() };
+      queue.push(fl);
+
+      QueueFrameType fr = { queryNode.Right(), &referenceNode, queryDepth + 1,
+          score, ti };
+      queue.push(fr);
     }
     else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
     {
       // We have to recurse down the reference node.  In this case the recursion
       // order does matter.  Before recursing, though, we have to set the
       // traversal information correctly.
-      double leftScore = rule.Score(queryNode, *referenceNode.Left());
-      typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
-      rule.TraversalInfo() = ti;
-      double rightScore = rule.Score(queryNode, *referenceNode.Right());
-      numScores += 2;
-
-      if (leftScore < rightScore)
-      {
-        // Recurse to the left.  Restore the left traversal info.  Store the right
-        // traversal info.
-        queryList.push(&queryNode);
-        referenceList.push(referenceNode.Left());
-        traversalInfos.push(leftInfo);
-//        Log::Debug << "Push3 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-
-        // Is it still valid to recurse to the right?
-        rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
-
-        if (rightScore != DBL_MAX)
-        {
-          // Restore the right traversal info.
-          queryList.push(&queryNode);
-          referenceList.push(referenceNode.Right());
-          traversalInfos.push(rule.TraversalInfo());
-//        Log::Debug << "Push4 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-        }
-        else
-          ++numPrunes;
-      }
-      else if (rightScore < leftScore)
-    {
-      // Recurse to the right.
-      queryList.push(&queryNode);
-      referenceList.push(referenceNode.Right());
-      traversalInfos.push(rule.TraversalInfo());
-//        Log::Debug << "Push5 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-
-      // Is it still valid to recurse to the left?
-      leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
-
-      if (leftScore != DBL_MAX)
-      {
-        // Restore the left traversal info.
-        queryList.push(&queryNode);
-        referenceList.push(referenceNode.Left());
-        traversalInfos.push(leftInfo);
-//        Log::Debug << "Push6 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-      }
-      else
-        ++numPrunes;
-    }
-    else // leftScore is equal to rightScore.
-    {
-      if (leftScore == DBL_MAX)
-      {
-        numPrunes += 2;
-      }
-      else
-      {
-        // Choose the left first.  Restore the left traversal info.  Store the
-        // right traversal info.
-        queryList.push(&queryNode);
-        referenceList.push(referenceNode.Left());
-        traversalInfos.push(leftInfo);
-//        Log::Debug << "Push7 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-
-        rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
-            rightScore);
-
-        if (rightScore != DBL_MAX)
-        {
-          // Restore the right traversal info.
-          queryList.push(&queryNode);
-          referenceList.push(referenceNode.Right());
-          traversalInfos.push(rule.TraversalInfo());
-//        Log::Debug << "Push8 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-        }
-        else
-          ++numPrunes;
-      }
-    }
-  }
-  else
-  {
-    // We have to recurse down both query and reference nodes.  Because the
-    // query descent order does not matter, we will go to the left query child
-    // first.  Before recursing, we have to set the traversal information
-    // correctly.
-    double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
-    typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
-    rule.TraversalInfo() = ti;
-    double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
-    typename RuleType::TraversalInfoType rightInfo;
-    numScores += 2;
-
-    if (leftScore < rightScore)
-    {
-      // Recurse to the left.  Restore the left traversal info.  Store the right
-      // traversal info.
-      queryList.push(queryNode.Left());
-      referenceList.push(referenceNode.Left());
-      traversalInfos.push(leftInfo);
-//        Log::Debug << "Push9 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-
-      // Is it still valid to recurse to the right?
-      rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
-          rightScore);
+      QueueFrameType fl = { &queryNode, referenceNode.Left(), queryDepth,
+          score, rule.TraversalInfo() };
+      queue.push(fl);
 
-      if (rightScore != DBL_MAX)
-      {
-        // Restore the right traversal info.
-        queryList.push(queryNode.Left());
-        referenceList.push(referenceNode.Right());
-        traversalInfos.push(rule.TraversalInfo());
-//        Log::Debug << "Push10 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-      }
-      else
-        ++numPrunes;
-    }
-    else if (rightScore < leftScore)
-    {
-      // Recurse to the right.
-      queryList.push(queryNode.Left());
-      referenceList.push(referenceNode.Right());
-      traversalInfos.push(rule.TraversalInfo());
-//        Log::Debug << "Push11 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-
-      // Is it still valid to recurse to the left?
-      leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
-          leftScore);
-
-      if (leftScore != DBL_MAX)
-      {
-        // Restore the left traversal info.
-        queryList.push(queryNode.Left());
-        referenceList.push(referenceNode.Left());
-        traversalInfos.push(leftInfo);
-//        Log::Debug << "Push12 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-      }
-      else
-        ++numPrunes;
+      QueueFrameType fr = { &queryNode, referenceNode.Right(), queryDepth,
+          score, ti };
+      queue.push(fr);
     }
     else
     {
-      if (leftScore == DBL_MAX)
-      {
-        numPrunes += 2;
-      }
-      else
-      {
-        // Choose the left first.  Restore the left traversal info and store the
-        // right traversal info.
-        queryList.push(queryNode.Left());
-        referenceList.push(referenceNode.Left());
-        traversalInfos.push(leftInfo);
-//        Log::Debug << "Push13 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-
-        // Is it still valid to recurse to the right?
-        rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
-            rightScore);
-
-        if (rightScore != DBL_MAX)
-        {
-          // Restore the right traversal information.
-          queryList.push(queryNode.Left());
-          referenceList.push(referenceNode.Right());
-          traversalInfos.push(rule.TraversalInfo());
-//        Log::Debug << "Push14 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-        }
-        else
-          ++numPrunes;
-      }
-    }
-
-    // Restore the main traversal information.
-    rule.TraversalInfo() = ti;
-
-    // Now recurse down the right query node.
-    leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
-    leftInfo = rule.TraversalInfo();
-    rule.TraversalInfo() = ti;
-    rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
-    numScores += 2;
-
-    if (leftScore < rightScore)
-    {
-      // Recurse to the left.  Restore the left traversal info.  Store the right
-      // traversal info.
-      queryList.push(queryNode.Right());
-      referenceList.push(referenceNode.Left());
-      traversalInfos.push(leftInfo);
-//        Log::Debug << "Push15 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-
-      // Is it still valid to recurse to the right?
-      rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
-          rightScore);
-
-      if (rightScore != DBL_MAX)
-      {
-        // Restore the right traversal info.
-        queryList.push(queryNode.Right());
-        referenceList.push(referenceNode.Right());
-        traversalInfos.push(rule.TraversalInfo());
-//        Log::Debug << "Push16 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-      }
-      else
-        ++numPrunes;
-    }
-    else if (rightScore < leftScore)
-    {
-      // Recurse to the right.
-      queryList.push(queryNode.Right());
-      referenceList.push(referenceNode.Right());
-      traversalInfos.push(rule.TraversalInfo());
-//        Log::Debug << "Push17 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-
-      // Is it still valid to recurse to the left?
-      leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
-          leftScore);
-
-      if (leftScore != DBL_MAX)
-      {
-        // Restore the left traversal info.
-        queryList.push(queryNode.Right());
-        referenceList.push(referenceNode.Left());
-        traversalInfos.push(leftInfo);
-//        Log::Debug << "Push18 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-      }
-      else
-        ++numPrunes;
-    }
-    else
-    {
-      if (leftScore == DBL_MAX)
-      {
-        numPrunes += 2;
-      }
-      else
-      {
-        // Choose the left first.  Restore the left traversal info.  Store the
-        // right traversal info.
-        queryList.push(queryNode.Right());
-        referenceList.push(referenceNode.Left());
-        traversalInfos.push(leftInfo);
-//        Log::Debug << "Push19 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-
-        // Is it still valid to recurse to the right?
-        rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
-            rightScore);
-
-        if (rightScore != DBL_MAX)
-        {
-          // Restore the right traversal info.
-          queryList.push(queryNode.Right());
-          referenceList.push(referenceNode.Right());
-          traversalInfos.push(rule.TraversalInfo());
-//        Log::Debug << "Push20 " << queryList.back()->Begin() << ", " <<
-//queryList.back()->Count() << "; " << referenceList.back()->Begin() << ", "
-//    << referenceList.back()->Count() << "\n";
-        }
-        else
-          ++numPrunes;
-      }
-    }
+      // We have to recurse down both query and reference nodes.  Because the
+      // query descent order does not matter, we will go to the left query child
+      // first.  Before recursing, we have to set the traversal information
+      // correctly.
+      QueueFrameType fll = { queryNode.Left(), referenceNode.Left(),
+          queryDepth + 1, score, rule.TraversalInfo() };
+      queue.push(fll);
+
+      QueueFrameType flr = { queryNode.Left(), referenceNode.Right(),
+          queryDepth + 1, score, rule.TraversalInfo() };
+      queue.push(flr);
+
+      QueueFrameType frl = { queryNode.Right(), referenceNode.Left(),
+          queryDepth + 1, score, rule.TraversalInfo() };
+      queue.push(frl);
+
+      QueueFrameType frr = { queryNode.Right(), referenceNode.Right(),
+          queryDepth + 1, score, rule.TraversalInfo() };
+      queue.push(frr);
     }
   }
 }



More information about the mlpack-git mailing list