[mlpack-svn] r13384 - in mlpack/trunk/src/mlpack/core/tree: . cover_tree

fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Aug 10 16:10:21 EDT 2012


Author: rcurtin
Date: 2012-08-10 16:10:21 -0400 (Fri, 10 Aug 2012)
New Revision: 13384

Removed:
   mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_map_entry.hpp
Modified:
   mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
   mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
   mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
   mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
Log:
Rewrite dual-tree cover tree traverser.  The use of UpdateAfterRecursion() is
kludgey but it seems to work okay for now.  Also, the MapEntryType couldn't be
abstracted across both traversers because the DualTreeTraverser needs more
information.


Modified: mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt	2012-08-10 19:52:46 UTC (rev 13383)
+++ mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt	2012-08-10 20:10:21 UTC (rev 13384)
@@ -14,7 +14,6 @@
   bounds.hpp
   cover_tree/cover_tree.hpp
   cover_tree/cover_tree_impl.hpp
-  cover_tree/cover_tree_map_entry.hpp
   cover_tree/first_point_is_root.hpp
   cover_tree/single_tree_traverser.hpp
   cover_tree/single_tree_traverser_impl.hpp

Deleted: mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_map_entry.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_map_entry.hpp	2012-08-10 19:52:46 UTC (rev 13383)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_map_entry.hpp	2012-08-10 20:10:21 UTC (rev 13384)
@@ -1,37 +0,0 @@
-/**
- * @file cover_tree_map_entry.hpp
- * @author Ryan Curtin
- *
- * Definition of a simple struct which is used in cover tree traversal to
- * represent the data associated with a single cover tree node.
- */
-#ifndef __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_MAP_ENTRY_HPP
-#define __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_MAP_ENTRY_HPP
-
-namespace mlpack {
-namespace tree {
-
-//! This is the structure the cover tree map will use for traversal.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-struct CoverTreeMapEntry
-{
-  //! The node this entry refers to.
-  CoverTree<MetricType, RootPointPolicy, StatisticType>* node;
-  //! The score of the node.
-  double score;
-  //! The index of the parent node.
-  size_t parent;
-  //! The base case evaluation.
-  double baseCase;
-
-  //! Comparison operator.
-  bool operator<(const CoverTreeMapEntry& other) const
-  {
-    return (score < other.score);
-  }
-};
-
-}; // namespace tree
-}; // namespace mlpack
-
-#endif

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp	2012-08-10 19:52:46 UTC (rev 13383)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp	2012-08-10 20:10:21 UTC (rev 13384)
@@ -13,7 +13,11 @@
 namespace mlpack {
 namespace tree {
 
+//! Forward declaration of struct to be used for traversal.
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+struct DualCoverTreeMapEntry;
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
 template<typename RuleType>
 class CoverTree<MetricType, RootPointPolicy, StatisticType>::DualTreeTraverser
 {
@@ -35,8 +39,9 @@
    * Helper function for traversal of the two trees.
    */
   void Traverse(CoverTree& queryNode,
-                CoverTree& referenceNode,
-                const size_t parent);
+                std::map<int, std::vector<DualCoverTreeMapEntry<
+                    MetricType, RootPointPolicy, StatisticType> > >&
+                    referenceNode);
 
   //! Get the number of pruned nodes.
   size_t NumPrunes() const { return numPrunes; }

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp	2012-08-10 19:52:46 UTC (rev 13383)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp	2012-08-10 20:10:21 UTC (rev 13384)
@@ -13,7 +13,29 @@
 namespace mlpack {
 namespace tree {
 
+//! The object placed in the map for tree traversal.
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+struct DualCoverTreeMapEntry
+{
+  //! The node this entry refers to.
+  CoverTree<MetricType, RootPointPolicy, StatisticType>* referenceNode;
+  //! The score of the node.
+  double score;
+  //! The index of the parent reference node.
+  size_t referenceParent;
+  //! The base case evaluation.
+  double baseCase;
+  //! The query node used for the base case evaluation (and/or score).
+  CoverTree<MetricType, RootPointPolicy, StatisticType>* parentQueryNode;
+
+  //! Comparison operator, for sorting within the map.
+  bool operator<(const DualCoverTreeMapEntry& other) const
+  {
+    return (score < other.score);
+  }
+};
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
 template<typename RuleType>
 CoverTree<MetricType, RootPointPolicy, StatisticType>::
 DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
@@ -28,8 +50,23 @@
     CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
     CoverTree<MetricType, RootPointPolicy, StatisticType>& referenceNode)
 {
-  // Start traversal with an invalid parent index.
-  Traverse(queryNode, referenceNode, size_t() - 1);
+  typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+      MapEntryType;
+
+  // Start by creating a map and adding the reference node to it.
+  std::map<int, std::vector<MapEntryType> > refMap;
+
+  MapEntryType rootRefEntry;
+
+  rootRefEntry.referenceNode = &referenceNode;
+  rootRefEntry.score = 0.0; // Must recurse into.
+  rootRefEntry.referenceParent = (size_t() - 1); // Invalid index.
+  rootRefEntry.baseCase = 0.0; // Not evaluated.
+  rootRefEntry.parentQueryNode = &queryNode; // No query node was used yet.
+
+  refMap[referenceNode.Scale()].push_back(rootRefEntry);
+
+  Traverse(queryNode, refMap);
 }
 
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
@@ -37,59 +74,182 @@
 void CoverTree<MetricType, RootPointPolicy, StatisticType>::
 DualTreeTraverser<RuleType>::Traverse(
   CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
-  CoverTree<MetricType, RootPointPolicy, StatisticType>& referenceNode,
-  const size_t parent)
+  std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
+      StatisticType> > >& referenceMap)
 {
-  std::queue<CoverTree<MetricType, RootPointPolicy, StatisticType>*>
-      referenceQueue;
-  std::queue<size_t> referenceParents;
+//  Log::Debug << "Recursed into query node " << queryNode.Point() << ", scale "
+//      << queryNode.Scale() << "\n";
 
-  referenceQueue.push(&referenceNode);
-  referenceParents.push(parent);
+  // Convenience typedef.
+  typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+      MapEntryType;
 
-  while (!referenceQueue.empty())
+  // First, reduce the maximum scale in the reference map down to the scale of
+  // the query node.
+  while ((*referenceMap.rbegin()).first > queryNode.Scale())
   {
-    CoverTree<MetricType, RootPointPolicy, StatisticType>& reference =
-        *referenceQueue.front();
-    referenceQueue.pop();
+    // Get a reference to the current largest scale.
+    std::vector<MapEntryType>& scaleVector = (*referenceMap.rbegin()).second;
 
-    size_t refParent = referenceParents.front();
-    referenceParents.pop();
+    // Before traversing all the points in this scale, sort by score.
+    std::sort(scaleVector.begin(), scaleVector.end());
 
-    // Do the base case, if we need to.
-    if (refParent != reference.Point())
-      rule.BaseCase(queryNode.Point(), reference.Point());
+    // Now loop over each element.
+    for (size_t i = 0; i < scaleVector.size(); ++i)
+    {
+      // Get a reference to the current element.
+      const MapEntryType& frame = scaleVector.at(i);
 
-    if (((queryNode.Scale() < reference.Scale()) &&
-         (reference.NumChildren() != 0)) ||
-         (queryNode.NumChildren() == 0))
-    {
-      // We must descend the reference node.  Pruning happens here.
-      for (size_t i = 0; i < reference.NumChildren(); ++i)
+      CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
+          frame.referenceNode;
+      CoverTree<MetricType, RootPointPolicy, StatisticType>* parentQueryNode =
+          frame.parentQueryNode;
+      const double score = frame.score;
+      const size_t refParent = frame.referenceParent;
+      const size_t refPoint = refNode->Point();
+      const size_t parentQueryPoint = parentQueryNode->Point();
+      double baseCase = frame.baseCase;
+
+//      Log::Debug << "Currently inspecting reference node " << refNode->Point()
+//          << " scale " << refNode->Scale() << " parentQueryPoint " <<
+//          parentQueryPoint << std::endl;
+
+//      Log::Debug << "Old score " << score << " with refParent " << refParent
+//          << " and queryParent " << parentQueryNode->Point() << " scale " <<
+//          parentQueryNode->Scale() << "\n";
+
+      // First we recalculate the score of this node to find if we can prune it.
+      if (rule.Rescore(queryNode, *refNode, score) == DBL_MAX)
       {
-        // Can we prune?
-        if (!rule.CanPrune(queryNode, reference.Child(i)))
-        {
-          referenceQueue.push(&(reference.Child(i)));
-          referenceParents.push(reference.Point());
-        }
-        else
-        {
-          ++numPrunes;
-        }
+//        Log::Warn << "Pruned after rescore\n";
+        ++numPrunes;
+        continue;
       }
+
+      // If this is a self-child, the base case has already been evaluated.
+      // We also must ensure that the base case was evaluated with this query
+      // point.
+      if ((refPoint != refParent) || (queryNode.Point() != parentQueryPoint))
+      {
+//        Log::Warn << "Must evaluate base case\n";
+        baseCase = rule.BaseCase(queryNode.Point(), refPoint);
+//        Log::Debug << "Base case " << baseCase << std::endl;
+        rule.UpdateAfterRecursion(queryNode, *refNode); // Kludgey.
+//       Log::Debug << "Bound for point " << queryNode.Point() << " scale " <<
+//            queryNode.Scale() << " is now " << queryNode.Stat().Bound() <<
+//            std::endl;
+      }
+
+      // Create the score for the children.
+      const double childScore = rule.Score(queryNode, *refNode, baseCase);
+
+      // Now if this childScore is DBL_MAX we can prune all children.  In this
+      // recursion setup pruning is all or nothing for children.
+      if (childScore == DBL_MAX)
+      {
+//        Log::Warn << "Pruned all children.\n";
+        numPrunes += refNode->NumChildren();
+        continue;
+      }
+
+      // In the dual recursion we must add the self-leaf (as compared to the
+      // single recursion); in this case we have potentially more than one point
+      // under the query node, so we cannot prune the self-leaf.
+      for (size_t j = 0; j < refNode->NumChildren(); ++j)
+      {
+        MapEntryType newFrame;
+        newFrame.referenceNode = &refNode->Child(j);
+        newFrame.score = childScore;
+        newFrame.baseCase = baseCase;
+        newFrame.referenceParent = refPoint;
+        newFrame.parentQueryNode = &queryNode;
+
+//        Log::Debug << "Push onto map child " << refNode->Child(j).Point() <<
+//            " scale " << refNode->Child(j).Scale() << std::endl;
+
+        referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
+      }
     }
-    else
+
+    // Now clear the memory for this scale; it isn't needed anymore.
+    referenceMap.erase((*referenceMap.rbegin()).first);
+  }
+
+  // Now, reduce the scale of the query node by recursing.  But we can't recurse
+  // if the query node is a leaf node.
+  if ((queryNode.Scale() != INT_MIN) &&
+      (queryNode.Scale() >= (*referenceMap.rbegin()).first))
+  {
+    // Recurse into the non-self-children first.
+    for (size_t i = 1; i < queryNode.NumChildren(); ++i)
     {
-      // We must descend the query node.  No pruning happens here.  For the
-      // self-child, we trick the recursion into thinking that the base case
-      // has already been done (which it has).
-      if (queryNode.NumChildren() >= 1)
-        Traverse(queryNode.Child(0), reference, reference.Point());
+      std::map<int, std::vector<MapEntryType> > childMap(referenceMap);
+//      Log::Debug << "Recurse into query child " << i << ": " <<
+//          queryNode.Child(i).Point() << " scale " << queryNode.Child(i).Scale()
+//          << "; this parent is " << queryNode.Point() << " scale " <<
+//          queryNode.Scale() << std::endl;
+      Traverse(queryNode.Child(i), childMap);
+    }
 
-      for (size_t i = 1; i < queryNode.NumChildren(); ++i)
-        Traverse(queryNode.Child(i), reference, size_t() - 1);
+    // Now we can use the existing map (without a copy) for the self-child.
+//    Log::Debug << "Recurse into query self-child " << queryNode.Child(0).Point()
+//        << " scale " << queryNode.Child(0).Scale() << "; this parent is "
+//        << queryNode.Point() << " scale " << queryNode.Scale() << std::endl;
+    Traverse(queryNode.Child(0), referenceMap);
+  }
+
+  if (queryNode.Scale() != INT_MIN)
+    return; // No need to evaluate base cases at this level.  It's all done.
+
+  // If we have made it this far, all we have is a bunch of base case
+  // evaluations to do.
+  Log::Assert((*referenceMap.begin()).first == INT_MIN);
+  Log::Assert(queryNode.Scale() == INT_MIN);
+  std::vector<MapEntryType>& pointVector = (*referenceMap.begin()).second;
+//  Log::Debug << "Onto base case evaluations\n";
+  for (size_t i = 0; i < pointVector.size(); ++i)
+  {
+    // Get a reference to the frame.
+    const MapEntryType& frame = pointVector[i];
+
+    CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
+        frame.referenceNode;
+    CoverTree<MetricType, RootPointPolicy, StatisticType>* parentQueryNode =
+        frame.parentQueryNode;
+    const double oldScore = frame.score;
+    const size_t refParent = frame.referenceParent;
+//    Log::Debug << "Consider query " << queryNode.Point() << ", reference "
+//        << refNode->Point() << "\n";
+//    Log::Debug << "Old score " << oldScore << " with refParent " << refParent
+//        << " and parent query node " << parentQueryNode->Point() << " scale "
+//        << parentQueryNode->Scale() << std::endl;
+
+    // First, ensure that we have not already calculated the base case.
+    if ((parentQueryNode->Point() == queryNode.Point()) &&
+        (refParent == refNode->Point()))
+    {
+//      Log::Debug << "Pruned because we already did the base case.\n";
+      ++numPrunes;
+      continue;
     }
+
+    // Now, check if we can prune it.
+    const double rescore = rule.Rescore(queryNode, *refNode, oldScore);
+
+    if (rescore == DBL_MAX)
+    {
+//      Log::Debug << "Pruned after rescoring\n";
+      ++numPrunes;
+      continue;
+    }
+
+    // If not, compute the base case.
+//    Log::Debug << "Not pruned, performing base case\n";
+    rule.BaseCase(queryNode.Point(), pointVector[i].referenceNode->Point());
+    rule.UpdateAfterRecursion(queryNode, *pointVector[i].referenceNode);
+//    Log::Debug << "Bound for point " << queryNode.Point() << " scale " <<
+//        queryNode.Scale() << " is now " << queryNode.Stat().Bound() <<
+//        std::endl;
   }
 }
 

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp	2012-08-10 19:52:46 UTC (rev 13383)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp	2012-08-10 20:10:21 UTC (rev 13384)
@@ -11,13 +11,32 @@
 // In case it hasn't been included yet.
 #include "single_tree_traverser.hpp"
 
-#include "cover_tree_map_entry.hpp"
 #include <queue>
 
 namespace mlpack {
 namespace tree {
 
+//! This is the structure the cover tree map will use for traversal.
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+struct CoverTreeMapEntry
+{
+  //! The node this entry refers to.
+  CoverTree<MetricType, RootPointPolicy, StatisticType>* node;
+  //! The score of the node.
+  double score;
+  //! The index of the parent node.
+  size_t parent;
+  //! The base case evaluation.
+  double baseCase;
+
+  //! Comparison operator.
+  bool operator<(const CoverTreeMapEntry& other) const
+  {
+    return (score < other.score);
+  }
+};
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
 template<typename RuleType>
 CoverTree<MetricType, RootPointPolicy, StatisticType>::
 SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
@@ -35,7 +54,7 @@
   // This is a non-recursive implementation (which should be faster than a
   // recursive implementation).
   typedef CoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
-      QueueType;
+      MapEntryType;
 
   // We will use this map as a priority queue.  Each key represents the scale,
   // and then the vector is all the nodes in that scale which need to be
@@ -43,7 +62,7 @@
   // scale, we know that the vector for each scale is final when we get to it.
   // In addition, map is organized in such a way that rbegin() will return the
   // largest scale.
-  std::map<int, std::vector<QueueType> > mapQueue;
+  std::map<int, std::vector<MapEntryType> > mapQueue;
 
   // Manually add the children of the first node.  These cannot be pruned
   // anyway.
@@ -68,7 +87,7 @@
 
     for (/* i was set above. */; i < referenceNode.NumChildren(); ++i)
     {
-      QueueType newFrame;
+      MapEntryType newFrame;
       newFrame.node = &referenceNode.Child(i);
       newFrame.score = rootChildScore;
       newFrame.baseCase = rootBaseCase;
@@ -80,23 +99,23 @@
   }
 
   // Now begin the iteration through the map.
-  typename std::map<int, std::vector<QueueType> >::reverse_iterator rit =
+  typename std::map<int, std::vector<MapEntryType> >::reverse_iterator rit =
       mapQueue.rbegin();
 
   // We will treat the leaves differently (below).
   while ((*rit).first != INT_MIN)
   {
     // Get a reference to the current scale.
-    std::vector<QueueType>& scaleVector = (*rit).second;
+    std::vector<MapEntryType>& scaleVector = (*rit).second;
 
-    // Before beginning all the points in this scale, sort by score.
+    // Before traversing all the points in this scale, sort by score.
     std::sort(scaleVector.begin(), scaleVector.end());
 
     // Now loop over each element.
     for (size_t i = 0; i < scaleVector.size(); ++i)
     {
       // Get a reference to the current element.
-      const QueueType& frame = scaleVector.at(i);
+      const MapEntryType& frame = scaleVector.at(i);
 
       CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
       const double score = frame.score;
@@ -136,7 +155,7 @@
 
       for (/* j is already set. */; j < node->NumChildren(); ++j)
       {
-        QueueType newFrame;
+        MapEntryType newFrame;
         newFrame.node = &node->Child(j);
         newFrame.score = childScore;
         newFrame.baseCase = baseCase;
@@ -153,7 +172,7 @@
   // Now deal with the leaves.
   for (size_t i = 0; i < mapQueue[INT_MIN].size(); ++i)
   {
-    const QueueType& frame = mapQueue[INT_MIN].at(i);
+    const MapEntryType& frame = mapQueue[INT_MIN].at(i);
 
     CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
     const double score = frame.score;




More information about the mlpack-svn mailing list