[mlpack-git] master: Add a semi-hackish breadth-first traverser. The tree abstractions will need to change to support arbitrary traverser types (probably by adding a template parameter) but for now this works to make DualTreeKMeans work. (019ecbc)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:02:34 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 019ecbc045b09c000504eb633585109053c63bae
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Nov 5 21:40:37 2014 +0000

    Add a semi-hackish breadth-first traverser.  The tree abstractions will need to
    change to support arbitrary traverser types (probably by adding a template
    parameter), but for now this is enough to make DualTreeKMeans work.


>---------------------------------------------------------------

019ecbc045b09c000504eb633585109053c63bae
 .../tree/binary_space_tree/binary_space_tree.hpp   |   3 +
 ...r.hpp => breadth_first_dual_tree_traverser.hpp} |  22 +-
 .../breadth_first_dual_tree_traverser_impl.hpp     | 442 +++++++++++++++++++++
 3 files changed, 456 insertions(+), 11 deletions(-)

diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index eadd300..db1ece7 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -87,6 +87,9 @@ class BinarySpaceTree
   template<typename RuleType>
   class DualTreeTraverser;
 
+  template<typename RuleType>
+  class BreadthFirstDualTreeTraverser;
+
   /**
    * Construct this as the root node of a binary space tree using the given
    * dataset.  This will modify the ordering of the points in the dataset!
diff --git a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp
similarity index 75%
copy from src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
copy to src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp
index 7cd1871..8a22a70 100644
--- a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp
@@ -1,14 +1,14 @@
 /**
- * @file dual_tree_traverser.hpp
+ * @file breadth_first_dual_tree_traverser.hpp
  * @author Ryan Curtin
  *
- * Defines the DualTreeTraverser for the BinarySpaceTree tree type.  This is a
- * nested class of BinarySpaceTree which traverses two trees in a depth-first
- * manner with a given set of rules which indicate the branches which can be
- * pruned and the order in which to recurse.
+ * Defines the BreadthFirstDualTreeTraverser for the BinarySpaceTree tree type.
+ * This is a nested class of BinarySpaceTree which traverses two trees in a
+ * breadth-first manner with a given set of rules which indicate the branches
+ * which can be pruned and the order in which to recurse.
  */
-#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
-#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_HPP
 
 #include <mlpack/core.hpp>
 
@@ -23,13 +23,13 @@ template<typename BoundType,
          typename SplitType>
 template<typename RuleType>
 class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
-    DualTreeTraverser
+    BreadthFirstDualTreeTraverser
 {
  public:
   /**
    * Instantiate the dual-tree traverser with the given rule set.
    */
-  DualTreeTraverser(RuleType& rule);
+  BreadthFirstDualTreeTraverser(RuleType& rule);
 
   /**
    * Traverse the two trees.  This does not reset the number of prunes.
@@ -86,7 +86,7 @@ class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
 }; // namespace mlpack
 
 // Include implementation.
-#include "dual_tree_traverser_impl.hpp"
+#include "breadth_first_dual_tree_traverser_impl.hpp"
 
-#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
+#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_HPP
 
diff --git a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
new file mode 100644
index 0000000..bd81df2
--- /dev/null
+++ b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
@@ -0,0 +1,442 @@
+/**
+ * @file breadth_first_dual_tree_traverser_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the BreadthFirstDualTreeTraverser for BinarySpaceTree.
+ * This performs a breadth-first dual-tree traversal of two trees, which must
+ * be of the same type.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_IMPL_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "breadth_first_dual_tree_traverser.hpp"
+
+#include <queue>
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Build a breadth-first dual-tree traverser around the given rule set.  The
+ * rule is stored in the `rule` member, and every statistics counter (prunes,
+ * visited combinations, scores, base cases) starts at zero.
+ *
+ * @param rule Rule set that supplies BaseCase(), Score(), and Rescore().
+ */
+template<typename BoundType, typename StatisticType, typename MatType,
+         typename SplitType>
+template<typename RuleType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+BreadthFirstDualTreeTraverser<RuleType>::
+BreadthFirstDualTreeTraverser(RuleType& rule) :
+    rule(rule), numPrunes(0), numVisited(0), numScores(0), numBaseCases(0)
+{
+  // Everything is set up by the member initializer list above.
+}
+
+/**
+ * Traverse the query tree and the reference tree together in breadth-first
+ * order.  Pending (query node, reference node) combinations are kept in FIFO
+ * queues together with the rule's traversal information captured when each
+ * combination was scored.  That information is restored into the rule before
+ * the combination is expanded.
+ *
+ * Statistics: numVisited is incremented once per combination taken off the
+ * queue, numScores once per Score() call, numPrunes once per child
+ * combination that is not enqueued, and numBaseCases by the reference node's
+ * point count for each query point in a leaf-leaf combination.  Counters are
+ * not reset here.
+ *
+ * @param queryRoot Root of the query tree.
+ * @param referenceRoot Root of the reference tree.
+ */
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+template<typename RuleType>
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+BreadthFirstDualTreeTraverser<RuleType>::Traverse(
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>& queryRoot,
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>&
+        referenceRoot)
+{
+  typedef BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>
+      TreeType;
+  typedef typename RuleType::TraversalInfoType TraversalInfoType;
+
+  // Store the current traversal info.
+  traversalInfo = rule.TraversalInfo();
+
+  // Three parallel FIFO queues; the i-th entry of each describes one pending
+  // combination.  They are always pushed and popped together.
+  std::queue<TreeType*> queryList;
+  std::queue<TreeType*> referenceList;
+  std::queue<TraversalInfoType> traversalInfos;
+  queryList.push(&queryRoot);
+  referenceList.push(&referenceRoot);
+  traversalInfos.push(rule.TraversalInfo());
+
+  while (!queryList.empty())
+  {
+    TreeType& queryNode = *queryList.front();
+    TreeType& referenceNode = *referenceList.front();
+    // Copy the info by value: the queue entry is destroyed by pop() below.
+    const TraversalInfoType ti = traversalInfos.front();
+
+    queryList.pop();
+    referenceList.pop();
+    traversalInfos.pop();
+
+    // Count every node combination that is expanded, not just the root one.
+    ++numVisited;
+
+    // Restore the rule state that was current when this pair was scored.
+    rule.TraversalInfo() = ti;
+
+    if (queryNode.IsLeaf() && referenceNode.IsLeaf())
+    {
+      // Leaf-leaf: evaluate the base case for every pair of points.  No
+      // per-point pruning is attempted here.
+      for (size_t query = queryNode.Begin(); query < queryNode.End(); ++query)
+      {
+        for (size_t ref = referenceNode.Begin(); ref < referenceNode.End();
+            ++ref)
+          rule.BaseCase(query, ref);
+
+        numBaseCases += referenceNode.Count();
+      }
+    }
+    else if (referenceNode.IsLeaf())
+    {
+      // Only the query node has children.  Each query child is scored
+      // independently against the reference leaf, so the order does not
+      // matter; each score starts from the restored rule state.
+      TreeType* queryChildren[2] = { queryNode.Left(), queryNode.Right() };
+      for (size_t i = 0; i < 2; ++i)
+      {
+        rule.TraversalInfo() = ti;
+        const double score = rule.Score(*queryChildren[i], referenceNode);
+        ++numScores;
+
+        if (score != DBL_MAX)
+        {
+          queryList.push(queryChildren[i]);
+          referenceList.push(&referenceNode);
+          traversalInfos.push(rule.TraversalInfo());
+        }
+        else
+        {
+          ++numPrunes;
+        }
+      }
+    }
+    else
+    {
+      // The reference node has children.  A leaf query node is paired with
+      // them directly; otherwise each query child (left, then right) is.
+      // For each query node, the reference child with the smaller score is
+      // enqueued first and the other one is rescored before being enqueued.
+      TreeType* queryChildren[2];
+      size_t numQueryChildren;
+      if (queryNode.IsLeaf())
+      {
+        queryChildren[0] = &queryNode;
+        numQueryChildren = 1;
+      }
+      else
+      {
+        queryChildren[0] = queryNode.Left();
+        queryChildren[1] = queryNode.Right();
+        numQueryChildren = 2;
+      }
+
+      for (size_t i = 0; i < numQueryChildren; ++i)
+      {
+        TreeType& queryChild = *queryChildren[i];
+
+        // Score both reference children from the same starting rule state,
+        // saving the state produced by the left score for later use.
+        rule.TraversalInfo() = ti;
+        double leftScore = rule.Score(queryChild, *referenceNode.Left());
+        const TraversalInfoType leftInfo = rule.TraversalInfo();
+        rule.TraversalInfo() = ti;
+        double rightScore = rule.Score(queryChild, *referenceNode.Right());
+        numScores += 2;
+
+        if (rightScore < leftScore)
+        {
+          // Right child first; it keeps the state left by its own Score().
+          queryList.push(&queryChild);
+          referenceList.push(referenceNode.Right());
+          traversalInfos.push(rule.TraversalInfo());
+
+          // Is it still valid to recurse to the left?
+          leftScore = rule.Rescore(queryChild, *referenceNode.Left(),
+              leftScore);
+          if (leftScore != DBL_MAX)
+          {
+            queryList.push(&queryChild);
+            referenceList.push(referenceNode.Left());
+            traversalInfos.push(leftInfo);
+          }
+          else
+          {
+            ++numPrunes;
+          }
+        }
+        else if ((leftScore < rightScore) || (leftScore != DBL_MAX))
+        {
+          // Left child first (ties go left); it uses the state saved right
+          // after its Score().
+          queryList.push(&queryChild);
+          referenceList.push(referenceNode.Left());
+          traversalInfos.push(leftInfo);
+
+          // Is it still valid to recurse to the right?
+          rightScore = rule.Rescore(queryChild, *referenceNode.Right(),
+              rightScore);
+          if (rightScore != DBL_MAX)
+          {
+            queryList.push(&queryChild);
+            referenceList.push(referenceNode.Right());
+            traversalInfos.push(rule.TraversalInfo());
+          }
+          else
+          {
+            ++numPrunes;
+          }
+        }
+        else
+        {
+          // Neither score is smaller and the left score is DBL_MAX: both
+          // reference children are pruned.
+          numPrunes += 2;
+        }
+      }
+    }
+  }
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_IMPL_HPP



More information about the mlpack-git mailing list