[mlpack-svn] r13367 - mlpack/trunk/src/mlpack/core/tree/cover_tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Aug 7 22:02:47 EDT 2012


Author: rcurtin
Date: 2012-08-07 22:02:47 -0400 (Tue, 07 Aug 2012)
New Revision: 13367

Modified:
   mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
Log:
Refactor single tree traverser again; don't use priority_queue, which is slow.
Instead, exploit some nice properties which allow us to use a map of vectors.
It seems as though this implementation is about twice as fast.


Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp	2012-08-08 01:55:23 UTC (rev 13366)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp	2012-08-08 02:02:47 UTC (rev 13367)
@@ -32,8 +32,7 @@
   //! Comparison operator.
   bool operator<(const CoverTreeQueueEntry& other) const
   {
-    return ((node->Scale() < other.node->Scale()) ||
-            ((node->Scale() == other.node->Scale()) && (score < other.score)));
+    return (score < other.score);
   }
 };
 
@@ -58,139 +57,140 @@
   // recursive implementation).
   typedef CoverTreeQueueEntry<MetricType, RootPointPolicy, StatisticType>
       QueueType;
-  std::priority_queue<QueueType> pointQueue;
 
-  // Unsorted list of leaves we have to look through.
-  std::queue<QueueType> leafQueue;
+  // We will use this map as a priority queue.  Each key represents the scale,
+  // and then the vector is all the nodes in that scale which need to be
+  // investigated.  Because no point in a scale can add a point in its own
+  // scale, we know that the vector for each scale is final when we get to it.
+  // In addition, map is organized in such a way that rbegin() will return the
+  // largest scale.
+  std::map<int, std::vector<QueueType> > mapQueue;
 
-  QueueType first;
-  first.node = &referenceNode;
-  first.score = 0.0;
-  first.parent = (size_t() - 1); // Invalid index.
-  first.baseCase = -1.0;
-  pointQueue.push(first);
+  // Manually add the children of the first node.  These cannot be pruned
+  // anyway.
+  double rootBaseCase = rule.BaseCase(queryIndex, referenceNode.Point());
+    
+  // Create the score for the children.
+  double rootChildScore = rule.Score(queryIndex, referenceNode, rootBaseCase);
 
-  Log::Warn << "Beginning recursion for " << queryIndex << std::endl;
-
-  while (!pointQueue.empty())
+  if (rootChildScore == DBL_MAX)
   {
-    QueueType frame = pointQueue.top();
+    numPrunes += referenceNode.NumChildren();
+  }
+  else
+  {
+    // Don't add the self-leaf.
+    size_t i = 0;
+    if (referenceNode.Child(0).NumChildren() == 0)
+    {
+      ++numPrunes;
+      i = 1;
+    }
 
-    CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
-    const double score = frame.score;
-    const size_t parent = frame.parent;
-    const size_t point = node->Point();
-    double baseCase = frame.baseCase;
+    for (/* i was set above. */; i < referenceNode.NumChildren(); ++i)
+    {
+      QueueType newFrame;
+      newFrame.node = &referenceNode.Child(i);
+      newFrame.score = rootChildScore;
+      newFrame.baseCase = rootBaseCase;
+      newFrame.parent = referenceNode.Point();
 
-    Log::Debug << "Current point is " << node->Point() << " and scale "
-        << node->Scale() << ".\n";
+      // Put it into the map.
+      mapQueue[newFrame.node->Scale()].push_back(newFrame);
+    }
+  }
 
-    pointQueue.pop();
+  // Now begin the iteration through the map.
+  typename std::map<int, std::vector<QueueType> >::reverse_iterator rit =
+      mapQueue.rbegin();
 
-    // First we (re)calculate the score of this node to find if we can prune it.
-    Log::Debug << "Before rescoring, score is " << score << " and base case of "
-        << "parent is " << baseCase << std::endl;
-    double actualScore = rule.Rescore(queryIndex, *node, score);
+  // We will treat the leaves differently (below).
+  while ((*rit).first != INT_MIN)
+  {
+    // Get a reference to the current scale.
+    std::vector<QueueType>& scaleVector = (*rit).second;
 
-//    Log::Debug << "Actual score is " << actualScore << ".\n";
+    // Before beginning all the points in this scale, sort by score.
+    std::sort(scaleVector.begin(), scaleVector.end());
 
-    if (actualScore == DBL_MAX)
+    // Now loop over each element.
+    for (size_t i = 0; i < scaleVector.size(); ++i)
     {
-      // Prune this node.
-      Log::Debug << "Pruning after re-scoring (original score " << score << ")."
-          << std::endl;
-      ++numPrunes;
-      continue; // Skip to next in queue.
-    }
+      // Get a reference to the current element.
+      const QueueType& frame = scaleVector.at(i);
 
-    // If we are a self-child, the base case has already been evaluated.
-    if (point != parent)
-    {
-      baseCase = rule.BaseCase(queryIndex, point);
-      Log::Debug << "Base case between " << queryIndex << " and " << point <<
-          " evaluates to " << baseCase << ".\n";
-    }
-    else
-    {
-      Log::Debug << "Base case between " << queryIndex << " and " << point <<
-          " already known to be " << baseCase << ".\n";
-    }
+      CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
+      const double score = frame.score;
+      const size_t parent = frame.parent;
+      const size_t point = node->Point();
+      double baseCase = frame.baseCase;
 
-    // Create the score for the children.
-    double childScore = rule.Score(queryIndex, *node, baseCase);
+      // First we recalculate the score of this node to find if we can prune it.
+      if (rule.Rescore(queryIndex, *node, score) == DBL_MAX)
+      {
+        ++numPrunes;
+        continue;
+      }
 
-    // Now if the childScore is DBL_MAX we can prune all children.  In this
-    // recursion setup pruning is all or nothing for children.
-    if (childScore == DBL_MAX)
-    {
-      Log::Debug << "Pruning all children.\n";
-      numPrunes += node->NumChildren();
-    }
-    else
-    {
-      for (size_t i = 0; i < node->NumChildren(); ++i)
+      // If we are a self-child, the base case has already been evaluated.
+      if (point != parent)
+        baseCase = rule.BaseCase(queryIndex, point);
+
+      // Create the score for the children.
+      const double childScore = rule.Score(queryIndex, *node, baseCase);
+
+      // Now if this childScore is DBL_MAX we can prune all children.  In this
+      // recursion setup pruning is all or nothing for children.
+      if (childScore == DBL_MAX)
       {
+        numPrunes += node->NumChildren();
+        continue;
+      }
+
+      // Don't add the self-leaf.
+      size_t j = 0;
+      if (node->Child(0).NumChildren() == 0)
+      {
+        ++numPrunes;
+        j = 1;
+      }
+
+      for (/* j is already set. */; j < node->NumChildren(); ++j)
+      {
         QueueType newFrame;
-        newFrame.node = &node->Child(i);
+        newFrame.node = &node->Child(j);
         newFrame.score = childScore;
         newFrame.baseCase = baseCase;
         newFrame.parent = point;
 
-        // Put it into the regular priority queue if it has children.
-        if (newFrame.node->NumChildren() > 0)
-        {
-          Log::Debug << "Push back child " << i << ": point " <<
-              newFrame.node->Point() << ", scale " << newFrame.node->Scale()
-              << ".\n";
-          pointQueue.push(newFrame);
-        }
-        else if ((newFrame.node->NumChildren() == 0) && (i > 0))
-        {
-          // We don't add the self-leaf to the leaf queue (it can't possibly
-          // help).
-          Log::Debug << "Push back child " << i << ": point " <<
-              newFrame.node->Point() << ", scale " << newFrame.node->Scale()
-              << ".\n";
-          leafQueue.push(newFrame);
-        }
-        else
-        {
-          Log::Debug << "Prune self-leaf point " << point << ".\n";
-          ++numPrunes;
-        }
+        mapQueue[newFrame.node->Scale()].push_back(newFrame);
       }
     }
+
+    // Now clear the memory for this scale; it isn't needed anymore.
+    mapQueue.erase((*rit).first);
   }
 
-  // Now look through all the leaves.
-  while (!leafQueue.empty())
+  // Now deal with the leaves.
+  for (size_t i = 0; i < mapQueue[INT_MIN].size(); ++i)
   {
-    QueueType frame = leafQueue.front();
+    const QueueType& frame = mapQueue[INT_MIN].at(i);
 
     CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
     const double score = frame.score;
     const size_t point = node->Point();
 
-    Log::Debug << "Inspecting leaf " << point << " with score " << score <<
-      "\n";
-
-    leafQueue.pop();
-
     // First, recalculate the score of this node to find if we can prune it.
     double actualScore = rule.Rescore(queryIndex, *node, score);
 
     if (actualScore == DBL_MAX)
     {
-      Log::Debug << "Pruned before base case.\n";
       ++numPrunes;
       continue;
     }
 
-    // There are no self-leaves in this queue, so the only thing left to do is
-    // evaluate the base case.
+    // There are no self-leaves; evaluate the base case.
     const double baseCase = rule.BaseCase(queryIndex, point);
-    Log::Debug << "Base case between " << queryIndex << " and " << point <<
-        " evaluates to " << baseCase << ".\n";
   }
 }
 




More information about the mlpack-svn mailing list