[mlpack-svn] r17077 - in mlpack/trunk/src/mlpack: core/tree/rectangle_tree methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Aug 19 23:21:24 EDT 2014


Author: andrewmw94
Date: Tue Aug 19 23:21:23 2014
New Revision: 17077

Log:
Dual tree traverser.

Modified:
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp

Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp	Tue Aug 19 23:21:23 2014
@@ -61,6 +61,20 @@
   size_t& NumBaseCases() { return numBaseCases; }
 
  private:
+   
+  // We use this class and this function to make the sorting and scoring easy and efficient:
+  class NodeAndScore {
+  public:
+    RectangleTree<SplitType, DescentType, StatisticType, MatType>* node;
+    double score;
+  };
+
+  static bool nodeComparator(const NodeAndScore& obj1,
+                      const NodeAndScore& obj2)
+  {
+    return obj1.score < obj2.score;
+  }
+  
   //! Reference to the rules with which the trees will be traversed.
   RuleType& rule;
 

Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp	Tue Aug 19 23:21:23 2014
@@ -25,7 +25,10 @@
 RectangleTree<SplitType, DescentType, StatisticType, MatType>::
 DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
     rule(rule),
-    numPrunes(0)
+    numPrunes(0),
+    numVisited(0),
+    numScores(0),
+    numBaseCases(0)
 { /* Nothing to do */ }
 
 template<typename SplitType,
@@ -34,13 +37,117 @@
          typename MatType>
 template<typename RuleType>
 void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
-DualTreeTraverser<RuleType>::Traverse(RectangleTree<SplitType, DescentType, StatisticType, MatType>& queryNode,
-		RectangleTree<SplitType, DescentType, StatisticType, MatType>& referenceNode)
+DualTreeTraverser<RuleType>::Traverse(
+    RectangleTree<SplitType, DescentType, StatisticType, MatType>& queryNode,
+    RectangleTree<SplitType, DescentType, StatisticType, MatType>& referenceNode)
 {
-  //Do nothing.  Just here to prevent warnings.
-  if(queryNode.NumDescendants() > referenceNode.NumDescendants())
-    return;
-  return;
+  // Increment the visit counter.
+  ++numVisited;
+  
+  // Store the current traversal info.
+  traversalInfo = rule.TraversalInfo();
+  
+  // We now have four options.
+  // 1)  Both nodes are leaf nodes.
+  // 2)  Only the reference node is a leaf node.
+  // 3)  Only the query node is a leaf node.
+  // 4)  Neither node is a leaf node.
+  // We go through those options in that order.
+  
+  if(queryNode.IsLeaf() && referenceNode.IsLeaf())
+  {
+    // Evaluate the base case.  Iterate over the query points in the outer loop so that
+    // we can possibly prune the reference node for that particular query point.
+    for(size_t query = 0; query < queryNode.Count(); ++query)
+    {
+      // Restore the traversal information.
+      rule.TraversalInfo() = traversalInfo;
+      const double childScore = rule.Score(queryNode.Points()[query], referenceNode);
+      
+      if(childScore == DBL_MAX)
+        continue;  // This point doesn't require a search in this reference node.
+      
+      for(size_t ref = 0; ref < referenceNode.Count(); ++ref)
+        rule.BaseCase(queryNode.Points()[query], referenceNode.Points()[ref]);
+      
+      numBaseCases += referenceNode.Count();
+    }
+  }
+  else if(!queryNode.IsLeaf() && referenceNode.IsLeaf())
+  {
+    // We only need to traverse down the query node.  Order doesn't matter here.
+    ++numScores;
+    if(rule.Score(queryNode.Child(0), referenceNode) < DBL_MAX)
+      Traverse(queryNode.Child(0), referenceNode);
+    else
+      numPrunes++;
+    
+    for(size_t i = 1; i < queryNode.NumChildren(); ++i)
+    {
+      // Before recursing, we have to set the traversal information correctly.
+      rule.TraversalInfo() = traversalInfo;
+      ++numScores;
+      if(rule.Score(queryNode.Child(i), referenceNode) < DBL_MAX)
+        Traverse(queryNode.Child(i), referenceNode);
+      else
+        numPrunes++;
+    }       
+  }
+  else if(queryNode.IsLeaf() && !referenceNode.IsLeaf())
+  {
+    // We only need to traverse down the reference node.  Order does matter here.
+    
+    // We sort the children of the reference node by their scores.    
+    std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
+    for(int i = 0; i < referenceNode.NumChildren(); i++)
+    {
+      nodesAndScores[i].node = referenceNode.Children()[i];
+      nodesAndScores[i].score = rule.Score(queryNode, *(nodesAndScores[i].node));
+    }
+    std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
+    numScores += nodesAndScores.size();
+    
+    for(int i = 0; i < nodesAndScores.size(); i++)
+    {
+      if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
+        rule.TraversalInfo() = traversalInfo;
+        Traverse(queryNode, *(nodesAndScores[i].node));
+      } else {
+        numPrunes += nodesAndScores.size() - i;
+        break;
+      }
+    }
+  }
+  else
+  {
+    // We need to traverse down both the reference and the query trees.
+    // We loop through all of the query nodes, and for each of them, we
+    // loop through the reference nodes to see where we need to descend.
+    
+    for(int j = 0; j < queryNode.NumChildren(); j++)
+    {
+      // We sort the children of the reference node by their scores.    
+      std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
+      for(int i = 0; i < referenceNode.NumChildren(); i++)
+      {
+        nodesAndScores[i].node = referenceNode.Children()[i];
+        nodesAndScores[i].score = rule.Score(queryNode, *nodesAndScores[i].node);
+      }
+      std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
+      numScores += nodesAndScores.size();
+    
+      for(int i = 0; i < nodesAndScores.size(); i++)
+      {
+        if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
+          rule.TraversalInfo() = traversalInfo;
+          Traverse(queryNode, *(nodesAndScores[i].node));
+        } else {
+          numPrunes += nodesAndScores.size() - i;
+          break;
+        }
+      }
+    }
+  }
 }
 
 }; // namespace tree

Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp	Tue Aug 19 23:21:23 2014
@@ -44,10 +44,12 @@
   //! Modify the number of prunes.
   size_t& NumPrunes() { return numPrunes; }
 
-  // We use this struct and this function to make the sorting and scoring easy
+ private:
+  
+  // We use this class and this function to make the sorting and scoring easy
   // and efficient:
-  struct NodeAndScore
-  {
+  class NodeAndScore {
+   public:
     RectangleTree<SplitType, DescentType, StatisticType, MatType>* node;
     double score;
   };
@@ -57,7 +59,6 @@
     return obj1.score < obj2.score;
   }
 
- private:
   //! Reference to the rules with which the tree will be traversed.
   RuleType& rule;
 

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	Tue Aug 19 23:21:23 2014
@@ -131,8 +131,8 @@
     Log::Warn << "--cover_tree overrides --r_tree." << endl;
   } else if (!singleMode && CLI::HasParam("r_tree"))  // R_tree requires single mode.
   {
-    Log::Warn << "--single_mode assumed because --r_tree is present." << endl;
-    singleMode = true;
+//     Log::Warn << "--single_mode assumed because --r_tree is present." << endl;
+//     singleMode = true;
   }
   
   if (naive)



More information about the mlpack-svn mailing list