[mlpack-svn] r13364 - mlpack/trunk/src/mlpack/core/tree/cover_tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Aug 7 18:40:38 EDT 2012


Author: rcurtin
Date: 2012-08-07 18:40:37 -0400 (Tue, 07 Aug 2012)
New Revision: 13364

Modified:
   mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
Log:
Revamp the single-tree traverser (again).  For now, it includes debugging output.


Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp	2012-08-07 22:37:20 UTC (rev 13363)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp	2012-08-07 22:40:37 UTC (rev 13364)
@@ -16,7 +16,30 @@
 namespace mlpack {
 namespace tree {
 
+//! Entry stored in the std::priority_queue that drives the single-tree traversal.
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+struct CoverTreeQueueEntry
+{
+  //! The node this entry refers to.
+  CoverTree<MetricType, RootPointPolicy, StatisticType>* node;
+  //! Score computed on the parent node (shared by all its children); passed to Rescore() when popped.
+  double score;
+  //! Point index of the parent node, compared with node->Point() to detect self-children (size_t() - 1 for the root).
+  size_t parent;
+  //! Parent's base case result, reused instead of re-evaluating when this node is a self-child (-1.0 for the root).
+  double baseCase;
+
+  //! Max-heap order: larger Scale() pops first, then larger score. NOTE(review): DBL_MAX means prune, so ties may pop worst-first -- confirm intended.
+  bool operator<(const CoverTreeQueueEntry& other) const
+  {
+    return ((node->Scale() < other.node->Scale()) ||
+            ((node->Scale() == other.node->Scale()) && (score < other.score)));
+  }
+};
+
+
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
 template<typename RuleType>
 CoverTree<MetricType, RootPointPolicy, StatisticType>::
 SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
@@ -33,78 +56,141 @@
 {
   // This is a non-recursive implementation (which should be faster than a
   // recursive implementation).
-  std::queue<CoverTree<MetricType, RootPointPolicy, StatisticType>*> pointQueue;
-  std::queue<double> pointScores;
-  std::queue<double> baseCaseResults;
+  typedef CoverTreeQueueEntry<MetricType, RootPointPolicy, StatisticType>
+      QueueType;
+  std::priority_queue<QueueType> pointQueue;
 
-  pointQueue.push(&referenceNode);
-  pointScores.push(0.0); // Cannot be pruned.
+  // Unsorted list of leaves we have to look through.
+  std::queue<QueueType> leafQueue;
 
-  // Evaluate first base case.
-  baseCaseResults.push(rule.BaseCase(queryIndex, referenceNode.Point()));
+  QueueType first;
+  first.node = &referenceNode;
+  first.score = 0.0;
+  first.parent = (size_t() - 1); // Invalid index.
+  first.baseCase = -1.0;
+  pointQueue.push(first);
 
+  Log::Warn << "Beginning recursion for " << queryIndex << std::endl;
+
   while (!pointQueue.empty())
   {
-    CoverTree<MetricType, RootPointPolicy, StatisticType>* node =
-        pointQueue.front();
-    const double score = pointScores.front();
-    const double baseCase = baseCaseResults.front();
+    QueueType frame = pointQueue.top();
 
+    CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
+    const double score = frame.score;
+    const size_t parent = frame.parent;
+    const size_t point = node->Point();
+    double baseCase = frame.baseCase;
+
+    Log::Debug << "Current point is " << node->Point() << " and scale "
+        << node->Scale() << ".\n";
+
     pointQueue.pop();
-    pointScores.pop();
-    baseCaseResults.pop();
 
     // First we (re)calculate the score of this node to find if we can prune it.
+    Log::Debug << "Before rescoring, score is " << score << " and base case of "
+        << "parent is " << baseCase << std::endl;
     double actualScore = rule.Rescore(queryIndex, *node, score);
 
+//    Log::Debug << "Actual score is " << actualScore << ".\n";
+
     if (actualScore == DBL_MAX)
     {
       // Prune this node.
+      Log::Debug << "Pruning after re-scoring (original score " << score << ")."
+          << std::endl;
       ++numPrunes;
       continue; // Skip to next in queue.
     }
 
-    // The base case is already evaluated.  So now we need to find out how to
-    // recurse into the children.
-    arma::vec baseCases(node->NumChildren());
-    arma::vec childScores(node->NumChildren());
-
-    // We already know the base case for the self child (that's the same as the
-    // base case for us).
-    baseCases[0] = baseCase;
-    if (node->Child(0).NumChildren() == 0)
-      childScores[0] = DBL_MAX; // Do not recurse into leaves (unnecessary).
+    // If we are a self-child, the base case has already been evaluated.
+    if (point != parent)
+    {
+      baseCase = rule.BaseCase(queryIndex, point);
+      Log::Debug << "Base case between " << queryIndex << " and " << point <<
+          " evaluates to " << baseCase << ".\n";
+    }
     else
-      childScores[0] = rule.Score(queryIndex, node->Child(0), baseCase);
-
-    // Fill the rest of the children.
-    for (size_t i = 1; i < node->NumChildren(); ++i)
     {
-      baseCases[i] = rule.BaseCase(queryIndex, node->Child(i).Point());
-      if (node->Child(i).NumChildren() == 0)
-        childScores[i] = DBL_MAX; // Do not recurse into leaves (unnecessary).
-      else
-        childScores[i] = rule.Score(queryIndex, node->Child(i), baseCases[i]);
+      Log::Debug << "Base case between " << queryIndex << " and " << point <<
+          " already known to be " << baseCase << ".\n";
     }
 
-    // Now sort by score.
-    arma::uvec order = arma::sort_index(childScores);
+    // Create the score for the children.
+    double childScore = rule.Score(queryIndex, *node, baseCase);
 
-    // Now add each to the queue.
-    for (size_t i = 0; i < order.n_elem; ++i)
+    // Now if the childScore is DBL_MAX we can prune all children.  In this
+    // recursion setup pruning is all or nothing for children.
+    if (childScore == DBL_MAX)
     {
-      // Ensure we haven't hit the limit yet.
-      const double childScore = childScores[order[i]];
-      if (childScore == DBL_MAX)
+      Log::Debug << "Pruning all children.\n";
+      numPrunes += node->NumChildren();
+    }
+    else
+    {
+      for (size_t i = 0; i < node->NumChildren(); ++i)
       {
-        numPrunes += (order.n_elem - i); // Prune the rest of the children.
-        break; // Go on to next point.
+        QueueType newFrame;
+        newFrame.node = &node->Child(i);
+        newFrame.score = childScore;
+        newFrame.baseCase = baseCase;
+        newFrame.parent = point;
+
+        // Put it into the regular priority queue if it has children.
+        if (newFrame.node->NumChildren() > 0)
+        {
+          Log::Debug << "Push back child " << i << ": point " <<
+              newFrame.node->Point() << ", scale " << newFrame.node->Scale()
+              << ".\n";
+          pointQueue.push(newFrame);
+        }
+        else if ((newFrame.node->NumChildren() == 0) && (i > 0))
+        {
+          // We don't add the self-leaf to the leaf queue (it can't possibly
+          // help).
+          Log::Debug << "Push back child " << i << ": point " <<
+              newFrame.node->Point() << ", scale " << newFrame.node->Scale()
+              << ".\n";
+          leafQueue.push(newFrame);
+        }
+        else
+        {
+          Log::Debug << "Prune self-leaf point " << point << ".\n";
+          ++numPrunes;
+        }
       }
+    }
+  }
 
-      pointQueue.push(&node->Child(order[i]));
-      pointScores.push(childScore);
-      baseCaseResults.push(baseCases[order[i]]);
+  // Now look through all the leaves.
+  while (!leafQueue.empty())
+  {
+    QueueType frame = leafQueue.front();
+
+    CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
+    const double score = frame.score;
+    const size_t point = node->Point();
+
+    Log::Debug << "Inspecting leaf " << point << " with score " << score <<
+      "\n";
+
+    leafQueue.pop();
+
+    // First, recalculate the score of this node to find if we can prune it.
+    double actualScore = rule.Rescore(queryIndex, *node, score);
+
+    if (actualScore == DBL_MAX)
+    {
+      Log::Debug << "Pruned before base case.\n";
+      ++numPrunes;
+      continue;
     }
+
+    // There are no self-leaves in this queue, so the only thing left to do is
+    // evaluate the base case.
+    const double baseCase = rule.BaseCase(queryIndex, point);
+    Log::Debug << "Base case between " << queryIndex << " and " << point <<
+        " evaluates to " << baseCase << ".\n";
   }
 }
 




More information about the mlpack-svn mailing list