[mlpack-svn] r12723 - mlpack/trunk/src/mlpack/methods/maxip

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri May 18 18:00:36 EDT 2012


Author: rcurtin
Date: 2012-05-18 18:00:35 -0400 (Fri, 18 May 2012)
New Revision: 12723

Modified:
   mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp
   mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp
   mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp
Log:
Reimplement single-tree search.  It's faster now.


Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp	2012-05-18 21:56:22 UTC (rev 12722)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp	2012-05-18 22:00:35 UTC (rev 12723)
@@ -10,7 +10,7 @@
 // In case it hasn't yet been included.
 #include "max_ip.hpp"
 
-#include <mlpack/core/tree/traversers/single_tree_breadth_first_traverser.hpp>
+#include <mlpack/core/tree/traversers/single_cover_tree_traverser.hpp>
 #include "max_ip_rules.hpp"
 
 namespace mlpack {
@@ -112,18 +112,91 @@
   // Single-tree implementation.
   if (single)
   {
-    MaxIPRules<IPMetric<KernelType> > rules(referenceSet, querySet, indices,
-        products);
+    // Calculate number of pruned nodes.
+    size_t numPrunes = 0;
 
-    tree::SingleTreeBreadthFirstTraverser<
-        tree::CoverTree<IPMetric<KernelType> >,
-        MaxIPRules<IPMetric<KernelType> > > traverser(rules);
+    // Precalculate query self-kernels K(q, q), i.e. || q ||^2, for all q.
+    arma::vec queryProducts(querySet.n_cols);
+    for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
+      queryProducts[queryIndex] = KernelType::Evaluate(
+          querySet.unsafe_col(queryIndex), querySet.unsafe_col(queryIndex));
 
-    for (size_t i = 0; i < querySet.n_cols; ++i)
-      traverser.Traverse(i, *referenceTree);
+    // Screw the CoverTreeTraverser, we'll implement it by hand.
+    for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
+    {
+      std::queue<tree::CoverTree<IPMetric<KernelType> >*> pointQueue;
+      std::queue<size_t> parentQueue;
+      std::queue<double> parentEvalQueue;
+      pointQueue.push(referenceTree);
+      parentQueue.push(size_t() - 1); // Has no parent.
+      parentEvalQueue.push(0); // No possible parent evaluation.
 
-    Log::Info << "Pruned " << traverser.NumPrunes() << " nodes." << std::endl;
+      tree::CoverTree<IPMetric<KernelType> >* referenceNode;
+      size_t currentParent;
+      double currentParentEval;
+      double eval; // Kernel evaluation.
 
+      while (!pointQueue.empty())
+      {
+        // Get the information for this node.
+        referenceNode = pointQueue.front();
+        currentParent = parentQueue.front();
+        currentParentEval = parentEvalQueue.front();
+
+        pointQueue.pop();
+        parentQueue.pop();
+        parentEvalQueue.pop();
+
+        // See if this has the same parent.
+        if (referenceNode->Point() == currentParent)
+        {
+          // We don't have to evaluate the kernel again.
+          eval = currentParentEval;
+        }
+        else
+        {
+          // Evaluate the kernel.  Then see if it is a result to keep.
+          eval = KernelType::Evaluate(querySet.unsafe_col(queryIndex),
+              referenceSet.unsafe_col(referenceNode->Point()));
+
+          // Is the result good enough to be saved?
+          if (eval > products(products.n_rows - 1, queryIndex))
+          {
+            // Figure out where to insert.
+            size_t insertPosition = 0;
+            for ( ; insertPosition < products.n_rows - 1; ++insertPosition)
+              if (eval > products(insertPosition, queryIndex))
+                break;
+
+            // We are guaranteed that insertPosition is valid.
+            InsertNeighbor(indices, products, queryIndex, insertPosition,
+                referenceNode->Point(), eval);
+          }
+        }
+
+        // Now discover if we can prune this node or not.
+        double maxProduct = eval + std::pow(referenceNode->ExpansionConstant(),
+            referenceNode->Scale() + 1) * queryProducts[queryIndex];
+
+        if (maxProduct > products(products.n_rows - 1, queryIndex))
+        {
+          // We can't prune.  So add our children.
+          for (size_t i = 0; i < referenceNode->NumChildren(); ++i)
+          {
+            pointQueue.push(&(referenceNode->Child(i)));
+            parentQueue.push(referenceNode->Point());
+            parentEvalQueue.push(eval);
+          }
+        }
+        else
+        {
+          numPrunes++;
+        }
+      }
+    }
+
+    Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
+
     Timer::Stop("computing_products");
     return;
   }

Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp	2012-05-18 21:56:22 UTC (rev 12722)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp	2012-05-18 22:00:35 UTC (rev 12723)
@@ -26,7 +26,8 @@
   void BaseCase(const size_t queryIndex, const size_t referenceIndex);
 
   bool CanPrune(const size_t queryIndex,
-                tree::CoverTree<MetricType>& referenceNode);
+                tree::CoverTree<MetricType>& referenceNode,
+                const size_t parentIndex);
 
  private:
   const arma::mat& referenceSet;
@@ -37,6 +38,8 @@
 
   arma::mat& products;
 
+  arma::vec queryKernels; // || q || for each q.
+
   void InsertNeighbor(const size_t queryIndex,
                       const size_t pos,
                       const size_t neighbor,

Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp	2012-05-18 21:56:22 UTC (rev 12722)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp	2012-05-18 22:00:35 UTC (rev 12723)
@@ -22,27 +22,26 @@
     querySet(querySet),
     indices(indices),
     products(products)
-{ /* Nothing left to do. */ }
-
-template<typename MetricType>
-void MaxIPRules<MetricType>::BaseCase(const size_t queryIndex,
-                                      const size_t referenceIndex)
 {
-  // Should be optimized out...
+  // Precompute each self-kernel.
+//  queryKernels.set_size(querySet.n_cols);
+//  for (size_t i = 0; i < querySet.n_cols; ++i)
+//    queryKernels[i] = sqrt(MetricType::Kernel::Evaluate(querySet.unsafe_col(i),
+//        querySet.unsafe_col(i)));
 }
 
 template<typename MetricType>
 bool MaxIPRules<MetricType>::CanPrune(const size_t queryIndex,
-    tree::CoverTree<MetricType>& referenceNode)
+    tree::CoverTree<MetricType>& referenceNode,
+    const size_t parentIndex)
 {
   // The maximum possible inner product is given by
   //   <q, p_0> + R_p || q ||
   // and since we are using cover trees, p_0 is the point referred to by the
   // node, and R_p will be the expansion constant to the power of the scale plus
   // one.
-  const double eval = MetricType::Kernel::Evaluate(
-      querySet.unsafe_col(queryIndex),
-      referenceSet.unsafe_col(referenceNode.Point()));
+  const double eval = MetricType::Kernel::Evaluate(querySet.col(queryIndex),
+      referenceSet.col(referenceNode.Point()));
 
   // See if base case can be added.
   if (eval > products(products.n_rows - 1, queryIndex))
@@ -58,8 +57,8 @@
 
   double maxProduct = eval + std::pow(referenceNode.ExpansionConstant(),
       referenceNode.Scale() + 1) *
-      sqrt(MetricType::Kernel::Evaluate(querySet.unsafe_col(queryIndex),
-      querySet.unsafe_col(queryIndex)));
+sqrt(MetricType::Kernel::Evaluate(querySet.col(queryIndex),
+querySet.col(queryIndex)));
 
   if (maxProduct > products(products.n_rows - 1, queryIndex))
     return false;




More information about the mlpack-svn mailing list