[mlpack-svn] r17084 - mlpack/trunk/src/mlpack/core/tree/rectangle_tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Aug 20 13:23:38 EDT 2014


Author: andrewmw94
Date: Wed Aug 20 13:23:37 2014
New Revision: 17084

Log:
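Dual-tree traverser bug fixes: store each reference child's traversal info alongside its score and restore it before rescoring and recursing, and handle the first query child in the common loop instead of special-casing it.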

Modified:
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp

Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp	Wed Aug 20 13:23:37 2014
@@ -67,6 +67,7 @@
   public:
     RectangleTree<SplitType, DescentType, StatisticType, MatType>* node;
     double score;
+    typename RuleType::TraversalInfoType travInfo;
   };
 
   static bool nodeComparator(const NodeAndScore& obj1,

Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp	Wed Aug 20 13:23:37 2014
@@ -76,13 +76,7 @@
   else if(!queryNode.IsLeaf() && referenceNode.IsLeaf())
   {
     // We only need to traverse down the query node.  Order doesn't matter here.
-    ++numScores;
-    if(rule.Score(queryNode.Child(0), referenceNode) < DBL_MAX)
-      Traverse(queryNode.Child(0), referenceNode);
-    else
-      numPrunes++;
-    
-    for(size_t i = 1; i < queryNode.NumChildren(); ++i)
+    for(size_t i = 0; i < queryNode.NumChildren(); ++i)
     {
       // Before recursing, we have to set the traversal information correctly.
       rule.TraversalInfo() = traversalInfo;
@@ -101,16 +95,18 @@
     std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
     for(int i = 0; i < referenceNode.NumChildren(); i++)
     {
+      rule.TraversalInfo() = traversalInfo;
       nodesAndScores[i].node = referenceNode.Children()[i];
       nodesAndScores[i].score = rule.Score(queryNode, *(nodesAndScores[i].node));
+      nodesAndScores[i].travInfo = rule.TraversalInfo();
     }
     std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
     numScores += nodesAndScores.size();
     
     for(int i = 0; i < nodesAndScores.size(); i++)
     {
+      rule.TraversalInfo() = nodesAndScores[i].travInfo;
       if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
-        rule.TraversalInfo() = traversalInfo;
         Traverse(queryNode, *(nodesAndScores[i].node));
       } else {
         numPrunes += nodesAndScores.size() - i;
@@ -130,16 +126,18 @@
       std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
       for(int i = 0; i < referenceNode.NumChildren(); i++)
       {
+        rule.TraversalInfo() = traversalInfo;
         nodesAndScores[i].node = referenceNode.Children()[i];
         nodesAndScores[i].score = rule.Score(queryNode, *nodesAndScores[i].node);
+        nodesAndScores[i].travInfo = rule.TraversalInfo();
       }
       std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
       numScores += nodesAndScores.size();
     
       for(int i = 0; i < nodesAndScores.size(); i++)
       {
+        rule.TraversalInfo() = nodesAndScores[i].travInfo;
         if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
-          rule.TraversalInfo() = traversalInfo;
           Traverse(queryNode, *(nodesAndScores[i].node));
         } else {
           numPrunes += nodesAndScores.size() - i;



More information about the mlpack-svn mailing list