[mlpack-svn] r16228 - mlpack/trunk/src/mlpack/core/tree/binary_space_tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Feb 6 15:18:58 EST 2014


Author: rcurtin
Date: Thu Feb  6 15:18:58 2014
New Revision: 16228

Log:
Modify BinarySpaceTree::DualTreeTraverser to properly handle TraversalInfo
objects.


Modified:
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp

Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp	Thu Feb  6 15:18:58 2014
@@ -28,12 +28,17 @@
   DualTreeTraverser(RuleType& rule);
 
   /**
-   * Traverse the two trees.  This does not reset the number of prunes.
+   * Traverse the two trees.  This does not reset the number of prunes.  If you
+   * are starting a traversal, the score for the parent node combination is
+   * irrelevant and can be left as 0 and thus does not need to be specified.
    *
    * @param queryNode The query node to be traversed.
    * @param referenceNode The reference node to be traversed.
+   * @param score The score of the current node combination.
    */
-  void Traverse(BinarySpaceTree& queryNode, BinarySpaceTree& referenceNode);
+  void Traverse(BinarySpaceTree& queryNode,
+                BinarySpaceTree& referenceNode,
+                const double score = 0.0);
 
   //! Get the number of prunes.
   size_t NumPrunes() const { return numPrunes; }
@@ -70,6 +75,10 @@
 
   //! The number of times a base case was calculated.
   size_t numBaseCases;
+
+  //! Traversal information, held in the class so that it isn't continually
+  //! being reallocated.
+  typename RuleType::TraversalInfoType traversalInfo;
 };
 
 }; // namespace tree

Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp	Thu Feb  6 15:18:58 2014
@@ -31,11 +31,15 @@
 void BinarySpaceTree<BoundType, StatisticType, MatType>::
 DualTreeTraverser<RuleType>::Traverse(
     BinarySpaceTree<BoundType, StatisticType, MatType>& queryNode,
-    BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
+    BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode,
+    const double score /* = 0.0 */)
 {
   // Increment the visit counter.
   ++numVisited;
 
+  // Store the current traversal info.
+  traversalInfo = rule.TraversalInfo();
+
   // If both are leaves, we must evaluate the base case.
   if (queryNode.IsLeaf() && referenceNode.IsLeaf())
   {
@@ -43,10 +47,12 @@
     for (size_t query = queryNode.Begin(); query < queryNode.End(); ++query)
     {
       // See if we need to investigate this point (this function should be
-      // implemented for the single-tree recursion too).
-      const double score = rule.Score(query, referenceNode);
+      // implemented for the single-tree recursion too).  Restore the traversal
+      // information first.
+      rule.TraversalInfo() = traversalInfo;
+      const double childScore = rule.Score(query, referenceNode);
 
-      if (score == DBL_MAX)
+      if (childScore == DBL_MAX)
         continue; // We can't improve this particular point.
 
       for (size_t ref = referenceNode.Begin(); ref < referenceNode.End(); ++ref)
@@ -59,53 +65,69 @@
   {
     // We have to recurse down the query node.  In this case the recursion order
     // does not matter.
-    double leftScore = rule.Score(*queryNode.Left(), referenceNode);
+    const double leftScore = rule.Score(*queryNode.Left(), referenceNode);
     ++numScores;
 
     if (leftScore != DBL_MAX)
-      Traverse(*queryNode.Left(), referenceNode);
+      Traverse(*queryNode.Left(), referenceNode, leftScore);
     else
       ++numPrunes;
 
-    double rightScore = rule.Score(*queryNode.Right(), referenceNode);
+    // Before recursing, we have to set the traversal information correctly.
+    rule.TraversalInfo() = traversalInfo;
+    const double rightScore = rule.Score(*queryNode.Right(), referenceNode);
     ++numScores;
 
     if (rightScore != DBL_MAX)
-      Traverse(*queryNode.Right(), referenceNode);
+      Traverse(*queryNode.Right(), referenceNode, rightScore);
     else
       ++numPrunes;
   }
   else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
   {
     // We have to recurse down the reference node.  In this case the recursion
-    // order does matter.
+    // order does matter.  Before recursing, though, we have to set the
+    // traversal information correctly.
     double leftScore = rule.Score(queryNode, *referenceNode.Left());
+    typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
+    rule.TraversalInfo() = traversalInfo;
     double rightScore = rule.Score(queryNode, *referenceNode.Right());
     numScores += 2;
 
     if (leftScore < rightScore)
     {
-      // Recurse to the left.
-      Traverse(queryNode, *referenceNode.Left());
+      // Recurse to the left.  Restore the left traversal info.  Store the right
+      // traversal info.
+      traversalInfo = rule.TraversalInfo();
+      rule.TraversalInfo() = leftInfo;
+      Traverse(queryNode, *referenceNode.Left(), leftScore);
 
       // Is it still valid to recurse to the right?
       rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
 
       if (rightScore != DBL_MAX)
-        Traverse(queryNode, *referenceNode.Right());
+      {
+        // Restore the right traversal info.
+        rule.TraversalInfo() = traversalInfo;
+        Traverse(queryNode, *referenceNode.Right(), rightScore);
+      }
       else
         ++numPrunes;
     }
     else if (rightScore < leftScore)
     {
       // Recurse to the right.
-      Traverse(queryNode, *referenceNode.Right());
+      Traverse(queryNode, *referenceNode.Right(), rightScore);
 
       // Is it still valid to recurse to the left?
       leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
 
       if (leftScore != DBL_MAX)
-        Traverse(queryNode, *referenceNode.Left());
+      {
+        // Restore the left traversal info.
+        rule.TraversalInfo() = leftInfo;
+        Traverse(queryNode, *referenceNode.Left(), leftScore);
+      }
       else
         ++numPrunes;
     }
@@ -117,14 +139,21 @@
       }
       else
       {
-        // Choose the left first.
-        Traverse(queryNode, *referenceNode.Left());
+        // Choose the left first.  Restore the left traversal info.  Store the
+        // right traversal info.
+        traversalInfo = rule.TraversalInfo();
+        rule.TraversalInfo() = leftInfo;
+        Traverse(queryNode, *referenceNode.Left(), leftScore);
 
         rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
             rightScore);
 
         if (rightScore != DBL_MAX)
-          Traverse(queryNode, *referenceNode.Right());
+        {
+          // Restore the right traversal info.
+          rule.TraversalInfo() = traversalInfo;
+          Traverse(queryNode, *referenceNode.Right(), rightScore);
+        }
         else
           ++numPrunes;
       }
@@ -134,36 +163,51 @@
   {
     // We have to recurse down both query and reference nodes.  Because the
     // query descent order does not matter, we will go to the left query child
-    // first.
+    // first.  Before recursing, we have to set the traversal information
+    // correctly.
     double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
+    typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
+    rule.TraversalInfo() = traversalInfo;
     double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
+    typename RuleType::TraversalInfoType rightInfo;
     numScores += 2;
 
     if (leftScore < rightScore)
     {
-      // Recurse to the left.
-      Traverse(*queryNode.Left(), *referenceNode.Left());
+      // Recurse to the left.  Restore the left traversal info.  Store the right
+      // traversal info.
+      rightInfo = rule.TraversalInfo();
+      rule.TraversalInfo() = leftInfo;
+      Traverse(*queryNode.Left(), *referenceNode.Left(), leftScore);
 
       // Is it still valid to recurse to the right?
       rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
           rightScore);
 
       if (rightScore != DBL_MAX)
-        Traverse(*queryNode.Left(), *referenceNode.Right());
+      {
+        // Restore the right traversal info.
+        rule.TraversalInfo() = rightInfo;
+        Traverse(*queryNode.Left(), *referenceNode.Right(), rightScore);
+      }
       else
         ++numPrunes;
     }
     else if (rightScore < leftScore)
     {
       // Recurse to the right.
-      Traverse(*queryNode.Left(), *referenceNode.Right());
+      Traverse(*queryNode.Left(), *referenceNode.Right(), rightScore);
 
       // Is it still valid to recurse to the left?
       leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
           leftScore);
 
       if (leftScore != DBL_MAX)
-        Traverse(*queryNode.Left(), *referenceNode.Left());
+      {
+        // Restore the left traversal info.
+        rule.TraversalInfo() = leftInfo;
+        Traverse(*queryNode.Left(), *referenceNode.Left(), leftScore);
+      }
       else
         ++numPrunes;
     }
@@ -175,7 +219,10 @@
       }
       else
       {
-        // Choose the left first.
+        // Choose the left first.  Restore the left traversal info and store the
+        // right traversal info.
+        rightInfo = rule.TraversalInfo();
+        rule.TraversalInfo() = leftInfo;
         Traverse(*queryNode.Left(), *referenceNode.Left());
 
         // Is it still valid to recurse to the right?
@@ -183,42 +230,62 @@
             rightScore);
 
         if (rightScore != DBL_MAX)
+        {
+          // Restore the right traversal information.
+          rule.TraversalInfo() = rightInfo;
           Traverse(*queryNode.Left(), *referenceNode.Right());
+        }
         else
           ++numPrunes;
       }
     }
 
+    // Restore the main traversal information.
+    rule.TraversalInfo() = traversalInfo;
+
     // Now recurse down the right query node.
     leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
+    leftInfo = rule.TraversalInfo();
+    rule.TraversalInfo() = traversalInfo;
     rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
     numScores += 2;
 
     if (leftScore < rightScore)
     {
-      // Recurse to the left.
-      Traverse(*queryNode.Right(), *referenceNode.Left());
+      // Recurse to the left.  Restore the left traversal info.  Store the right
+      // traversal info.
+      rightInfo = rule.TraversalInfo();
+      rule.TraversalInfo() = leftInfo;
+      Traverse(*queryNode.Right(), *referenceNode.Left(), leftScore);
 
       // Is it still valid to recurse to the right?
       rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
           rightScore);
 
       if (rightScore != DBL_MAX)
-        Traverse(*queryNode.Right(), *referenceNode.Right());
+      {
+        // Restore the right traversal info.
+        rule.TraversalInfo() = rightInfo;
+        Traverse(*queryNode.Right(), *referenceNode.Right(), rightScore);
+      }
       else
         ++numPrunes;
     }
     else if (rightScore < leftScore)
     {
       // Recurse to the right.
-      Traverse(*queryNode.Right(), *referenceNode.Right());
+      Traverse(*queryNode.Right(), *referenceNode.Right(), rightScore);
 
       // Is it still valid to recurse to the left?
       leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
           leftScore);
 
       if (leftScore != DBL_MAX)
-        Traverse(*queryNode.Right(), *referenceNode.Left());
+      {
+        // Restore the left traversal info.
+        rule.TraversalInfo() = leftInfo;
+        Traverse(*queryNode.Right(), *referenceNode.Left(), leftScore);
+      }
       else
         ++numPrunes;
     }
@@ -230,15 +297,22 @@
       }
       else
       {
-        // Choose the left first.
-        Traverse(*queryNode.Right(), *referenceNode.Left());
+        // Choose the left first.  Restore the left traversal info.  Store the
+        // right traversal info.
+        rightInfo = rule.TraversalInfo();
+        rule.TraversalInfo() = leftInfo;
+        Traverse(*queryNode.Right(), *referenceNode.Left(), leftScore);
 
         // Is it still valid to recurse to the right?
         rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
             rightScore);
 
         if (rightScore != DBL_MAX)
-          Traverse(*queryNode.Right(), *referenceNode.Right());
+        {
+          // Restore the right traversal info.
+          rule.TraversalInfo() = rightInfo;
+          Traverse(*queryNode.Right(), *referenceNode.Right(), rightScore);
+        }
         else
           ++numPrunes;
       }
@@ -250,4 +324,3 @@
 }; // namespace mlpack
 
 #endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
-



More information about the mlpack-svn mailing list