[mlpack-git] master: dual tree traverser bug fixes. (b88ae3b)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:59:02 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit b88ae3b5fd041a1bbe05732d17a5e70be5d9097c
Author: andrewmw94 <andrewmw94 at gmail.com>
Date:   Wed Aug 20 17:23:37 2014 +0000

    dual tree traverser bug fixes.


>---------------------------------------------------------------

b88ae3b5fd041a1bbe05732d17a5e70be5d9097c
 .../core/tree/rectangle_tree/dual_tree_traverser.hpp     |  1 +
 .../tree/rectangle_tree/dual_tree_traverser_impl.hpp     | 16 +++++++---------
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
index 6610da5..50f971d 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
@@ -67,6 +67,7 @@ class RectangleTree<SplitType, DescentType, StatisticType, MatType>::
   public:
     RectangleTree<SplitType, DescentType, StatisticType, MatType>* node;
     double score;
+    typename RuleType::TraversalInfoType travInfo;
   };
 
   static bool nodeComparator(const NodeAndScore& obj1,
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
index 7ac9b20..8bdad30 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
@@ -76,13 +76,7 @@ DualTreeTraverser<RuleType>::Traverse(
   else if(!queryNode.IsLeaf() && referenceNode.IsLeaf())
   {
     // We only need to traverse down the query node.  Order doesn't matter here.
-    ++numScores;
-    if(rule.Score(queryNode.Child(0), referenceNode) < DBL_MAX)
-      Traverse(queryNode.Child(0), referenceNode);
-    else
-      numPrunes++;
-    
-    for(size_t i = 1; i < queryNode.NumChildren(); ++i)
+    for(size_t i = 0; i < queryNode.NumChildren(); ++i)
     {
       // Before recursing, we have to set the traversal information correctly.
       rule.TraversalInfo() = traversalInfo;
@@ -101,16 +95,18 @@ DualTreeTraverser<RuleType>::Traverse(
     std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
     for(int i = 0; i < referenceNode.NumChildren(); i++)
     {
+      rule.TraversalInfo() = traversalInfo;
       nodesAndScores[i].node = referenceNode.Children()[i];
       nodesAndScores[i].score = rule.Score(queryNode, *(nodesAndScores[i].node));
+      nodesAndScores[i].travInfo = rule.TraversalInfo();
     }
     std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
     numScores += nodesAndScores.size();
     
     for(int i = 0; i < nodesAndScores.size(); i++)
     {
+      rule.TraversalInfo() = nodesAndScores[i].travInfo;
       if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
-        rule.TraversalInfo() = traversalInfo;
         Traverse(queryNode, *(nodesAndScores[i].node));
       } else {
         numPrunes += nodesAndScores.size() - i;
@@ -130,16 +126,18 @@ DualTreeTraverser<RuleType>::Traverse(
       std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
       for(int i = 0; i < referenceNode.NumChildren(); i++)
       {
+        rule.TraversalInfo() = traversalInfo;
         nodesAndScores[i].node = referenceNode.Children()[i];
         nodesAndScores[i].score = rule.Score(queryNode, *nodesAndScores[i].node);
+        nodesAndScores[i].travInfo = rule.TraversalInfo();
       }
       std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
       numScores += nodesAndScores.size();
     
       for(int i = 0; i < nodesAndScores.size(); i++)
       {
+        rule.TraversalInfo() = nodesAndScores[i].travInfo;
         if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
-          rule.TraversalInfo() = traversalInfo;
           Traverse(queryNode, *(nodesAndScores[i].node));
         } else {
           numPrunes += nodesAndScores.size() - i;



More information about the mlpack-git mailing list