[mlpack-svn] r12746 - mlpack/trunk/src/mlpack/methods/maxip

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon May 21 14:51:22 EDT 2012


Author: rcurtin
Date: 2012-05-21 14:51:21 -0400 (Mon, 21 May 2012)
New Revision: 12746

Modified:
   mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp
Log:
Revamp the breadth-first descent to consider tree levels, so that it is a
proper breadth-first descent.


Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp	2012-05-21 18:39:45 UTC (rev 12745)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp	2012-05-21 18:51:21 UTC (rev 12746)
@@ -19,10 +19,27 @@
 struct SearchFrame
 {
   TreeType* node;
-  size_t parent;
-  double parentEval;
+  double eval;
 };
 
+template<typename TreeType>
+class SearchFrameCompare
+{
+ public:
+  bool operator()(const SearchFrame<TreeType>& lhs,
+                  const SearchFrame<TreeType>& rhs)
+  {
+    // Compare scale.
+    if (lhs.node->Scale() != rhs.node->Scale())
+      return (lhs.node->Scale() < rhs.node->Scale());
+    else
+    {
+      // Now we have to compare by evaluation.
+      return (lhs.eval < rhs.eval);
+    }
+  }
+};
+
 template<typename KernelType>
 MaxIP<KernelType>::MaxIP(const arma::mat& referenceSet,
                          bool single,
@@ -117,7 +134,7 @@
     }
 
     Timer::Stop("computing_products");
-    
+
     Log::Info << "Kernel evaluations: " << kernelEvaluations << "." << std::endl;
     return;
   }
@@ -138,88 +155,116 @@
     // Screw the CoverTreeTraverser, we'll implement it by hand.
     for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
     {
-      std::queue<SearchFrame<tree::CoverTree<IPMetric<KernelType> > > >
+      // Use an array of priority queues?
+      std::priority_queue<
+          SearchFrame<tree::CoverTree<IPMetric<KernelType> > >,
+          std::vector<SearchFrame<tree::CoverTree<IPMetric<KernelType> > > >,
+          SearchFrameCompare<tree::CoverTree<IPMetric<KernelType> > > >
           frameQueue;
 
+      // Add initial frame.
       SearchFrame<tree::CoverTree<IPMetric<KernelType> > > nextFrame;
       nextFrame.node = referenceTree;
-      nextFrame.parent = size_t() - 1;
-      nextFrame.parentEval = 0;
+      nextFrame.eval = KernelType::Evaluate(querySet.unsafe_col(queryIndex),
+          referenceSet.unsafe_col(referenceTree->Point()));
 
+      // The initial evaluation will be the best so far.
+      indices(0, queryIndex) = referenceTree->Point();
+      products(0, queryIndex) = nextFrame.eval;
+
       frameQueue.push(nextFrame);
 
       tree::CoverTree<IPMetric<KernelType> >* referenceNode;
-      size_t currentParent;
-      double currentParentEval;
-      double eval; // Kernel evaluation.
+      double eval;
+      double maxProduct;
 
       while (!frameQueue.empty())
       {
         // Get the information for this node.
         const SearchFrame<tree::CoverTree<IPMetric<KernelType> > >& frame =
-            frameQueue.front();
+            frameQueue.top();
+
         referenceNode = frame.node;
-        currentParent = frame.parent;
-        currentParentEval = frame.parentEval;
+        eval = frame.eval;
 
-        frameQueue.pop();
-
-        // See if this has the same parent.
-        if (referenceNode->Point() == currentParent)
+        // Loop through the children, seeing if we can prune them; if not, add
+        // them to the queue.  The self-child is different -- it has the same
+        // parent (and therefore the same kernel evaluation).
+        if (referenceNode->NumChildren() > 0)
         {
-          // We don't have to evaluate the kernel again.
-          eval = currentParentEval;
-        }
-        else
-        {
-          // Evaluate the kernel.  Then see if it is a result to keep.
-          eval = KernelType::Evaluate(querySet.unsafe_col(queryIndex),
-              referenceSet.unsafe_col(referenceNode->Point()));
-          ++kernelEvaluations;
+          SearchFrame<tree::CoverTree<IPMetric<KernelType> > > childFrame;
 
-          // Is the result good enough to be saved?
-          if (eval > products(products.n_rows - 1, queryIndex))
+          // We must handle the self-child differently, to avoid adding it to
+          // the results twice.
+          childFrame.node = &(referenceNode->Child(0));
+          childFrame.eval = eval;
+
+          maxProduct = eval + std::pow(childFrame.node->ExpansionConstant(),
+              childFrame.node->Scale() + 1) * queryProducts[queryIndex];
+
+          // Add self-child if we can't prune it.
+          if (maxProduct > products(products.n_rows - 1, queryIndex))
           {
-            // Figure out where to insert.
-            size_t insertPosition = 0;
-            for ( ; insertPosition < products.n_rows - 1; ++insertPosition)
-              if (eval > products(insertPosition, queryIndex))
-                break;
-
-            // We are guaranteed that insertPosition is valid.
-            InsertNeighbor(indices, products, queryIndex, insertPosition,
-                referenceNode->Point(), eval);
+            // But only if it has children of its own.
+            if (childFrame.node->NumChildren() > 0)
+              frameQueue.push(childFrame);
           }
-        }
+          else
+            ++numPrunes;
 
-        // Now discover if we can prune this node or not.
-        double maxProduct = eval + std::pow(referenceNode->ExpansionConstant(),
-            referenceNode->Scale() + 1) * queryProducts[queryIndex];
-
-        if (maxProduct > products(products.n_rows - 1, queryIndex))
-        {
-          // We can't prune.  So add our children.
-          for (size_t i = 0; i < referenceNode->NumChildren(); ++i)
+          for (size_t i = 1; i < referenceNode->NumChildren(); ++i)
           {
-            SearchFrame<tree::CoverTree<IPMetric<KernelType> > > nextFrame;
+            // Evaluate child.
+            childFrame.node = &(referenceNode->Child(i));
+            childFrame.eval = KernelType::Evaluate(
+                querySet.unsafe_col(queryIndex),
+                referenceSet.unsafe_col(referenceNode->Child(i).Point()));
 
-            nextFrame.node = &(referenceNode->Child(i));
-            nextFrame.parent = referenceNode->Point();
-            nextFrame.parentEval = eval;
+            // Can we prune it?  If we can, we can avoid putting it in the queue
+            // (saves time).
+            double maxProduct = childFrame.eval +
+                std::pow(childFrame.node->ExpansionConstant(),
+                childFrame.node->Scale() + 1) * queryProducts[queryIndex];
 
-            frameQueue.push(nextFrame);
+            if (maxProduct > products(products.n_rows - 1, queryIndex))
+            {
+              // Good enough to recurse into.  While we're at it, check the
+              // actual evaluation and see if it's an improvement.
+              if (childFrame.eval > products(products.n_rows - 1, queryIndex))
+              {
+                // This is a better result.  Find out where to insert.
+                size_t insertPosition = 0;
+                for ( ; insertPosition < products.n_rows - 1; ++insertPosition)
+                  if (childFrame.eval > products(insertPosition, queryIndex))
+                    break;
+
+                // Insert into the correct position; we are guaranteed that
+                // insertPosition is valid.
+                InsertNeighbor(indices, products, queryIndex, insertPosition,
+                    childFrame.node->Point(), childFrame.eval);
+              }
+
+              // Now add this to the queue (if it has any children which may
+              // need to be recursed into).
+              if (childFrame.node->NumChildren() > 0)
+                frameQueue.push(childFrame);
+            }
+            else
+            {
+              ++numPrunes;
+            }
           }
         }
-        else
-        {
-          numPrunes++;
-        }
+
+        frameQueue.pop();
       }
     }
 
     Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
-    Log::Info << "Kernel evaluations: " << kernelEvaluations << "." << std::endl;
-    Log::Info << "Distance evaluations: " << distanceEvaluations << "." << std::endl;
+    Log::Info << "Kernel evaluations: " << kernelEvaluations << "."
+        << std::endl;
+    Log::Info << "Distance evaluations: " << distanceEvaluations << "."
+        << std::endl;
 
     Timer::Stop("computing_products");
     return;




More information about the mlpack-svn mailing list