[mlpack-svn] r12718 - mlpack/trunk/src/mlpack/methods/maxip

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu May 17 23:25:21 EDT 2012


Author: rcurtin
Date: 2012-05-17 23:25:21 -0400 (Thu, 17 May 2012)
New Revision: 12718

Modified:
   mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp
Log:
Avoid double-evaluation of base case.


Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp	2012-05-17 19:09:24 UTC (rev 12717)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp	2012-05-18 03:25:21 UTC (rev 12718)
@@ -28,19 +28,7 @@
 void MaxIPRules<MetricType>::BaseCase(const size_t queryIndex,
                                       const size_t referenceIndex)
 {
-  const double eval = MetricType::Kernel::Evaluate(querySet.col(queryIndex),
-      referenceSet.col(referenceIndex));
-
-  if (eval > products(products.n_rows - 1, queryIndex))
-  {
-    size_t insertPosition;
-    for (insertPosition = 0; insertPosition < indices.n_rows; ++insertPosition)
-      if (eval > products(insertPosition, queryIndex))
-        break;
-
-    // We are guaranteed insertPosition is in the valid range.
-    InsertNeighbor(queryIndex, insertPosition, referenceIndex, eval);
-  }
+  // Should be optimized out...
 }
 
 template<typename MetricType>
@@ -52,10 +40,22 @@
   // and since we are using cover trees, p_0 is the point referred to by the
   // node, and R_p will be the expansion constant to the power of the scale plus
   // one.
-  double maxProduct = MetricType::Kernel::Evaluate(querySet.col(queryIndex),
+  const double eval = MetricType::Kernel::Evaluate(querySet.col(queryIndex),
       referenceSet.col(referenceNode.Point()));
 
-  maxProduct += std::pow(referenceNode.ExpansionConstant(),
+  // See if base case can be added.
+  if (eval > products(products.n_rows - 1, queryIndex))
+  {
+    size_t insertPosition;
+    for (insertPosition = 0; insertPosition < indices.n_rows; ++insertPosition)
+      if (eval > products(insertPosition, queryIndex))
+        break;
+
+    // We are guaranteed insertPosition is in the valid range.
+    InsertNeighbor(queryIndex, insertPosition, referenceNode.Point(), eval);
+  }
+
+  double maxProduct = eval + std::pow(referenceNode.ExpansionConstant(),
       referenceNode.Scale() + 1) *
       sqrt(MetricType::Kernel::Evaluate(querySet.col(queryIndex),
       querySet.col(queryIndex)));




More information about the mlpack-svn mailing list