[mlpack-svn] r13340 - mlpack/trunk/src/mlpack/core/tree/binary_space_tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Aug 6 14:18:21 EDT 2012


Author: rcurtin
Date: 2012-08-06 14:18:21 -0400 (Mon, 06 Aug 2012)
New Revision: 13340

Modified:
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp
Log:
Refactor BinarySpaceTree traversal to use Score() and Rescore() and thereby
avoid unnecessary repeated calculations (so I hope).


Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp	2012-08-06 18:17:55 UTC (rev 13339)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp	2012-08-06 18:18:21 UTC (rev 13340)
@@ -30,13 +30,6 @@
     BinarySpaceTree<BoundType, StatisticType, MatType>& queryNode,
     BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
 {
-  // Check if pruning can occur.
-  if (rule.CanPrune(queryNode, referenceNode))
-  {
-    ++numPrunes;
-    return;
-  }
-
   // If both are leaves, we must evaluate the base case.
   if (queryNode.IsLeaf() && referenceNode.IsLeaf())
   {
@@ -47,56 +40,175 @@
   }
   else if ((!queryNode.IsLeaf()) && referenceNode.IsLeaf())
   {
-    // We have to recurse down the query node.
-    if (rule.LeftFirst(referenceNode, queryNode))
-    {
-      Traverse(*queryNode.Left(), referenceNode);
-      Traverse(*queryNode.Right(), referenceNode);
-    }
-    else
-    {
-      Traverse(*queryNode.Right(), referenceNode);
-      Traverse(*queryNode.Left(), referenceNode);
-    }
+    // We have to recurse down the query node.  In this case the recursion order
+    // does not matter.
+    Traverse(*queryNode.Left(), referenceNode);
+    Traverse(*queryNode.Right(), referenceNode);
   }
   else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
   {
-    // We have to recurse down the reference node.
-    if (rule.LeftFirst(queryNode, referenceNode))
+    // We have to recurse down the reference node.  In this case the recursion
+    // order does matter.
+    double leftScore = rule.Score(queryNode, *referenceNode.Left());
+    double rightScore = rule.Score(queryNode, *referenceNode.Right());
+
+    if (leftScore < rightScore)
     {
+      // Recurse to the left.
       Traverse(queryNode, *referenceNode.Left());
-      Traverse(queryNode, *referenceNode.Right());
+
+      // Is it still valid to recurse to the right?
+      rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
+
+      if (rightScore != DBL_MAX)
+        Traverse(queryNode, *referenceNode.Right());
+      else
+        ++numPrunes;
     }
-    else
+    else if (rightScore < leftScore)
     {
+      // Recurse to the right.
       Traverse(queryNode, *referenceNode.Right());
-      Traverse(queryNode, *referenceNode.Left());
+
+      // Is it still valid to recurse to the left?
+      leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
+
+      if (leftScore != DBL_MAX)
+        Traverse(queryNode, *referenceNode.Left());
+      else
+        ++numPrunes;
     }
+    else // leftScore is equal to rightScore.
+    {
+      if (leftScore == DBL_MAX)
+      {
+        numPrunes += 2;
+      }
+      else
+      {
+        // Choose the left first.
+        Traverse(queryNode, *referenceNode.Left());
+
+        rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
+            rightScore);
+
+        if (rightScore != DBL_MAX)
+          Traverse(queryNode, *referenceNode.Right());
+        else
+          ++numPrunes;
+      }
+    }
   }
   else
   {
-    // We have to recurse down both.
-    if (rule.LeftFirst(*queryNode.Left(), referenceNode))
+    // We have to recurse down both query and reference nodes.  Because the
+    // query descent order does not matter, we will go to the left query child
+    // first.
+    double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
+    double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
+
+    if (leftScore < rightScore)
     {
+      // Recurse to the left.
       Traverse(*queryNode.Left(), *referenceNode.Left());
+
+      // Is it still valid to recurse to the right?
+      rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
+          rightScore);
+
+      if (rightScore != DBL_MAX)
+        Traverse(*queryNode.Left(), *referenceNode.Right());
+      else
+        ++numPrunes;
+    }
+    else if (rightScore < leftScore)
+    {
+      // Recurse to the right.
       Traverse(*queryNode.Left(), *referenceNode.Right());
+
+      // Is it still valid to recurse to the left?
+      leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
+          leftScore);
+
+      if (leftScore != DBL_MAX)
+        Traverse(*queryNode.Left(), *referenceNode.Left());
+      else
+        ++numPrunes;
     }
     else
     {
-      Traverse(*queryNode.Left(), *referenceNode.Right());
-      Traverse(*queryNode.Left(), *referenceNode.Left());
+      if (leftScore == DBL_MAX)
+      {
+        numPrunes += 2;
+      }
+      else
+      {
+        // Choose the left first.
+        Traverse(*queryNode.Left(), *referenceNode.Left());
+
+        // Is it still valid to recurse to the right?
+        rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
+            rightScore);
+
+        if (rightScore != DBL_MAX)
+          Traverse(*queryNode.Left(), *referenceNode.Right());
+        else
+          ++numPrunes;
+      }
     }
 
-    // Now recurse to the right query child.
-    if (rule.LeftFirst(*queryNode.Right(), referenceNode))
+    // Now recurse down the right query node.
+    leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
+    rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
+
+    if (leftScore < rightScore)
     {
+      // Recurse to the left.
       Traverse(*queryNode.Right(), *referenceNode.Left());
+
+      // Is it still valid to recurse to the right?
+      rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
+          rightScore);
+
+      if (rightScore != DBL_MAX)
+        Traverse(*queryNode.Right(), *referenceNode.Right());
+      else
+        ++numPrunes;
+    }
+    else if (rightScore < leftScore)
+    {
+      // Recurse to the right.
       Traverse(*queryNode.Right(), *referenceNode.Right());
+
+      // Is it still valid to recurse to the left?
+      leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
+          leftScore);
+
+      if (leftScore != DBL_MAX)
+        Traverse(*queryNode.Right(), *referenceNode.Left());
+      else
+        ++numPrunes;
     }
     else
     {
-      Traverse(*queryNode.Right(), *referenceNode.Right());
-      Traverse(*queryNode.Right(), *referenceNode.Left());
+      if (leftScore == DBL_MAX)
+      {
+        numPrunes += 2; // Scores tie at DBL_MAX: neither child is visited.
+      }
+      else
+      {
+        // Scores tie: visit the left reference child first (was Right() twice).
+        Traverse(*queryNode.Right(), *referenceNode.Left());
+
+        // Rescore the right reference child now that the left has been visited.
+        rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
+            rightScore);
+
+        if (rightScore != DBL_MAX)
+          Traverse(*queryNode.Right(), *referenceNode.Right());
+        else
+          ++numPrunes;
+      }
     }
   }
 

Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp	2012-08-06 18:17:55 UTC (rev 13339)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp	2012-08-06 18:18:21 UTC (rev 13340)
@@ -32,42 +32,63 @@
     const size_t queryIndex,
     BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
 {
-  // This is a non-recursive implementation (which should be faster).
+  // If we are a leaf, run the base case as necessary.
+  if (referenceNode.IsLeaf())
+  {
+    for (size_t i = referenceNode.Begin(); i < referenceNode.End(); ++i)
+      rule.BaseCase(queryIndex, i);
+  }
+  else
+  {
+    // If either score is DBL_MAX, we do not recurse into that node.
+    double leftScore = rule.Score(queryIndex, *referenceNode.Left());
+    double rightScore = rule.Score(queryIndex, *referenceNode.Right());
 
-  // The stack of points to look at.
-  std::stack<BinarySpaceTree<BoundType, StatisticType, MatType>*> pointStack;
-  pointStack.push(&referenceNode);
+    if (leftScore < rightScore)
+    {
+      // Recurse to the left.
+      Traverse(queryIndex, *referenceNode.Left());
 
-  while (!pointStack.empty())
-  {
-    BinarySpaceTree<BoundType, StatisticType, MatType>* node = pointStack.top();
-    pointStack.pop();
+      // Is it still valid to recurse to the right?
+      rightScore = rule.Rescore(queryIndex, *referenceNode.Right(), rightScore);
 
-    // Check if we can prune this node.
-    if (rule.CanPrune(queryIndex, *node))
+      if (rightScore != DBL_MAX)
+        Traverse(queryIndex, *referenceNode.Right()); // Recurse to the right.
+      else
+        ++numPrunes;
+    }
+    else if (rightScore < leftScore)
     {
-      ++numPrunes;
-      continue;
-    }
+      // Recurse to the right.
+      Traverse(queryIndex, *referenceNode.Right());
 
-    // If we are a leaf, run the base case as necessary.
-    if (node->IsLeaf())
-    {
-      for (size_t i = node->Begin(); i < node->End(); ++i)
-        rule.BaseCase(queryIndex, i);
+      // Is it still valid to recurse to the left?
+      leftScore = rule.Rescore(queryIndex, *referenceNode.Left(), leftScore);
+
+      if (leftScore != DBL_MAX)
+        Traverse(queryIndex, *referenceNode.Left()); // Recurse to the left.
+      else
+        ++numPrunes;
     }
-    else
+    else // leftScore is equal to rightScore.
     {
-      // Otherwise recurse by distance.
-      if (rule.LeftFirst(queryIndex, *node))
+      if (leftScore == DBL_MAX)
       {
-        pointStack.push(node->Right());
-        pointStack.push(node->Left());
+        numPrunes += 2; // Pruned both left and right.
       }
       else
       {
-        pointStack.push(node->Left());
-        pointStack.push(node->Right());
+        // Choose the left first.
+        Traverse(queryIndex, *referenceNode.Left());
+
+        // Is it still valid to recurse to the right?
+        rightScore = rule.Rescore(queryIndex, *referenceNode.Right(),
+            rightScore);
+
+        if (rightScore != DBL_MAX)
+          Traverse(queryIndex, *referenceNode.Right());
+        else
+          ++numPrunes;
       }
     }
   }




More information about the mlpack-svn mailing list