[mlpack-git] master: Make traversal depth-first in the queries. This results in less memory usage, and some amount of speedup. (43ab267)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:44 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 43ab2675140761ffd7450e6e2056d4d6cd0e87bc
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Feb 27 17:09:30 2015 -0500

    Make traversal depth-first in the queries. This results in less memory usage, and some amount of speedup.


>---------------------------------------------------------------

43ab2675140761ffd7450e6e2056d4d6cd0e87bc
 .../breadth_first_dual_tree_traverser.hpp          | 16 +++++
 .../breadth_first_dual_tree_traverser_impl.hpp     | 69 ++++++++++++----------
 2 files changed, 55 insertions(+), 30 deletions(-)

diff --git a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp
index 8a22a70..4af0372 100644
--- a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp
@@ -11,12 +11,23 @@
 #define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_HPP
 
 #include <mlpack/core.hpp>
+#include <queue>
 
 #include "binary_space_tree.hpp"
 
 namespace mlpack {
 namespace tree {
 
+template<typename TreeType, typename TraversalInfoType>
+struct QueueFrame
+{
+  TreeType* queryNode;
+  TreeType* referenceNode;
+  size_t queryDepth;
+  double score;
+  TraversalInfoType traversalInfo;
+};
+
 template<typename BoundType,
          typename StatisticType,
          typename MatType,
@@ -31,6 +42,9 @@ class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
    */
   BreadthFirstDualTreeTraverser(RuleType& rule);
 
+  typedef QueueFrame<BinarySpaceTree, typename RuleType::TraversalInfoType>
+      QueueFrameType;
+
   /**
    * Traverse the two trees.  This does not reset the number of prunes.
    *
@@ -40,6 +54,8 @@ class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
    */
   void Traverse(BinarySpaceTree& queryNode,
                 BinarySpaceTree& referenceNode);
+  void Traverse(BinarySpaceTree& queryNode,
+                std::priority_queue<QueueFrameType>& referenceQueue);
 
   //! Get the number of prunes.
   size_t NumPrunes() const { return numPrunes; }
diff --git a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
index 5a32bfc..f48bf95 100644
--- a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
@@ -12,8 +12,6 @@
 // In case it hasn't been included yet.
 #include "breadth_first_dual_tree_traverser.hpp"
 
-#include <queue>
-
 namespace mlpack {
 namespace tree {
 
@@ -33,16 +31,6 @@ BreadthFirstDualTreeTraverser<RuleType>::BreadthFirstDualTreeTraverser(
 { /* Nothing to do. */ }
 
 template<typename TreeType, typename TraversalInfoType>
-struct QueueFrame
-{
-  TreeType* queryNode;
-  TreeType* referenceNode;
-  size_t queryDepth;
-  double score;
-  TraversalInfoType traversalInfo;
-};
-
-template<typename TreeType, typename TraversalInfoType>
 bool operator<(const QueueFrame<TreeType, TraversalInfoType>& a,
                const QueueFrame<TreeType, TraversalInfoType>& b)
 {
@@ -70,16 +58,11 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
   // Store the current traversal info.
   traversalInfo = rule.TraversalInfo();
 
-  typedef BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>
-      TreeType;
-
   // Must score the root combination.
   const double rootScore = rule.Score(queryRoot, referenceRoot);
   if (rootScore == DBL_MAX)
     return; // This probably means something is wrong.
 
-  typedef QueueFrame<TreeType, typename RuleType::TraversalInfoType>
-      QueueFrameType;
   std::priority_queue<QueueFrameType> queue;
 
   QueueFrameType rootFrame;
@@ -91,13 +74,32 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
 
   queue.push(rootFrame);
 
-  while (!queue.empty())
+  // Start the traversal.
+  Traverse(queryRoot, queue);
+}
+
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+template<typename RuleType>
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+BreadthFirstDualTreeTraverser<RuleType>::Traverse(
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>& queryNode,
+    std::priority_queue<QueueFrameType>& referenceQueue)
+{
+  // Store queues for the children.  We will recurse into the children once the
+  // reference queue for this query node is empty.
+  std::priority_queue<QueueFrameType> leftChildQueue;
+  std::priority_queue<QueueFrameType> rightChildQueue;
+
+  while (!referenceQueue.empty())
   {
-    QueueFrameType currentFrame = queue.top();
-    queue.pop();
+    QueueFrameType currentFrame = referenceQueue.top();
+    referenceQueue.pop();
 
-    TreeType& queryNode = *currentFrame.queryNode;
-    TreeType& referenceNode = *currentFrame.referenceNode;
+    BinarySpaceTree& queryNode = *currentFrame.queryNode;
+    BinarySpaceTree& referenceNode = *currentFrame.referenceNode;
     typename RuleType::TraversalInfoType ti = currentFrame.traversalInfo;
     rule.TraversalInfo() = ti;
     const size_t queryDepth = currentFrame.queryDepth;
@@ -137,11 +139,11 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
       // We have to recurse down the query node.
       QueueFrameType fl = { queryNode.Left(), &referenceNode, queryDepth + 1,
           score, rule.TraversalInfo() };
-      queue.push(fl);
+      leftChildQueue.push(fl);
 
       QueueFrameType fr = { queryNode.Right(), &referenceNode, queryDepth + 1,
           score, ti };
-      queue.push(fr);
+      rightChildQueue.push(fr);
     }
     else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
     {
@@ -150,11 +152,11 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
       // traversal information correctly.
       QueueFrameType fl = { &queryNode, referenceNode.Left(), queryDepth,
           score, rule.TraversalInfo() };
-      queue.push(fl);
+      referenceQueue.push(fl);
 
       QueueFrameType fr = { &queryNode, referenceNode.Right(), queryDepth,
           score, ti };
-      queue.push(fr);
+      referenceQueue.push(fr);
     }
     else
     {
@@ -164,21 +166,28 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
       // correctly.
       QueueFrameType fll = { queryNode.Left(), referenceNode.Left(),
           queryDepth + 1, score, rule.TraversalInfo() };
-      queue.push(fll);
+      leftChildQueue.push(fll);
 
       QueueFrameType flr = { queryNode.Left(), referenceNode.Right(),
           queryDepth + 1, score, rule.TraversalInfo() };
-      queue.push(flr);
+      leftChildQueue.push(flr);
 
       QueueFrameType frl = { queryNode.Right(), referenceNode.Left(),
           queryDepth + 1, score, rule.TraversalInfo() };
-      queue.push(frl);
+      rightChildQueue.push(frl);
 
       QueueFrameType frr = { queryNode.Right(), referenceNode.Right(),
           queryDepth + 1, score, rule.TraversalInfo() };
-      queue.push(frr);
+      rightChildQueue.push(frr);
     }
   }
+
+  // Now, recurse into the left and right child queues.  The order doesn't
+  // matter.
+  if (leftChildQueue.size() > 0)
+    Traverse(*queryNode.Left(), leftChildQueue);
+  if (rightChildQueue.size() > 0)
+    Traverse(*queryNode.Right(), rightChildQueue);
 }
 
 }; // namespace tree



More information about the mlpack-git mailing list