[mlpack-svn] r12706 - mlpack/trunk/src/mlpack/methods/maxip

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu May 17 12:55:16 EDT 2012


Author: rcurtin
Date: 2012-05-17 12:55:16 -0400 (Thu, 17 May 2012)
New Revision: 12706

Added:
   mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp
   mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp
Log:
Add rules for the MaxIP single-tree search.


Added: mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp	2012-05-17 16:55:16 UTC (rev 12706)
@@ -0,0 +1,52 @@
+/**
+ * @file max_ip_rules.hpp
+ * @author Ryan Curtin
+ *
+ * Rules for the single or dual tree traversal for the maximum inner product
+ * search.
+ */
+#ifndef __MLPACK_METHODS_MAXIP_MAX_IP_RULES_HPP
+#define __MLPACK_METHODS_MAXIP_MAX_IP_RULES_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
+
+namespace mlpack {
+namespace maxip {
+
+/**
+ * Traversal rules for maximum inner product search over a cover tree.  The
+ * object holds references to the reference set, the query set, and the two
+ * result matrices, and updates the results as the traversal calls BaseCase().
+ *
+ * @tparam MetricType Metric used by the cover tree.  It must expose a nested
+ *     Kernel type callable as MetricType::Kernel::Evaluate(a, b).
+ */
+template<typename MetricType>
+class MaxIPRules
+{
+ public:
+  /**
+   * Bind the rules to the data sets and result matrices.  Only references are
+   * stored, so the referenced objects must outlive this object.
+   *
+   * @param referenceSet Reference points, one point per column.
+   * @param querySet Query points, one point per column.
+   * @param indices Result matrix of reference indices; one column per query.
+   * @param products Result matrix of kernel values; one column per query.
+   */
+  MaxIPRules(const arma::mat& referenceSet,
+             const arma::mat& querySet,
+             arma::Mat<size_t>& indices,
+             arma::mat& products);
+
+  /**
+   * Evaluate the kernel between one query point and one reference point and,
+   * if the value is greater than the last entry of the query's column in
+   * products, insert the reference point into the results.
+   *
+   * @param queryIndex Column of the query point in the query set.
+   * @param referenceIndex Column of the reference point in the reference set.
+   */
+  void BaseCase(const size_t queryIndex, const size_t referenceIndex);
+
+  /**
+   * Check whether a reference node can be pruned for the given query point.
+   *
+   * @param queryIndex Column of the query point in the query set.
+   * @param referenceNode Cover tree node to test.
+   * @return True if the upper bound on the kernel value for the node is not
+   *     greater than the last entry of the query's column in products.
+   */
+  bool CanPrune(const size_t queryIndex,
+                tree::CoverTree<MetricType>& referenceNode);
+
+ private:
+  // Reference points; columns are indexed by referenceIndex / node points.
+  const arma::mat& referenceSet;
+
+  // Query points; columns are indexed by queryIndex.
+  const arma::mat& querySet;
+
+  // Reference indices of the current best candidates, one column per query.
+  arma::Mat<size_t>& indices;
+
+  // Kernel values matching the entries of indices.  NOTE(review): each column
+  // appears to be kept in descending order with the smallest kept value in the
+  // last row; initialization happens outside this class -- confirm with caller.
+  arma::mat& products;
+
+  // Insert a candidate at row pos of the query's result column, shifting the
+  // following rows down by one.  The distance argument receives the kernel
+  // value computed in BaseCase().
+  void InsertNeighbor(const size_t queryIndex,
+                      const size_t pos,
+                      const size_t neighbor,
+                      const double distance);
+};
+
+}; // namespace maxip
+}; // namespace mlpack
+
+// Include implementation.
+#include "max_ip_rules_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp	2012-05-17 16:55:16 UTC (rev 12706)
@@ -0,0 +1,103 @@
+/**
+ * @file max_ip_rules_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of MaxIPRules for cover tree search.
+ */
+#ifndef __MLPACK_METHODS_MAXIP_MAX_IP_RULES_IMPL_HPP
+#define __MLPACK_METHODS_MAXIP_MAX_IP_RULES_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "max_ip_rules.hpp"
+
+namespace mlpack {
+namespace maxip {
+
+/**
+ * Construct the rules object by binding it to the data sets and the result
+ * matrices.  Only references are stored; nothing is copied or modified here.
+ *
+ * @param referenceSet Set of reference points.
+ * @param querySet Set of query points.
+ * @param indices Matrix receiving the indices of the best reference points.
+ * @param products Matrix receiving the corresponding kernel values.
+ */
+template<typename MetricType>
+MaxIPRules<MetricType>::MaxIPRules(
+    const arma::mat& referenceSet,
+    const arma::mat& querySet,
+    arma::Mat<size_t>& indices,
+    arma::mat& products)
+  : referenceSet(referenceSet),
+    querySet(querySet),
+    indices(indices),
+    products(products)
+{
+  // Every member is a reference bound above; there is no further setup.
+}
+
+/**
+ * Score a single query/reference pair with the kernel and, when the score
+ * beats the last entry of the query's column in the products matrix, insert
+ * the reference point into the results at the first row it beats.
+ *
+ * @param queryIndex Column of the query point in the query set.
+ * @param referenceIndex Column of the reference point in the reference set.
+ */
+template<typename MetricType>
+void MaxIPRules<MetricType>::BaseCase(const size_t queryIndex,
+                                      const size_t referenceIndex)
+{
+  const double kernelValue = MetricType::Kernel::Evaluate(
+      querySet.col(queryIndex), referenceSet.col(referenceIndex));
+
+  // Nothing to record unless the candidate beats the last kept product.
+  if (!(kernelValue > products(products.n_rows - 1, queryIndex)))
+    return;
+
+  // Walk down the column until the first entry the candidate beats.
+  size_t slot = 0;
+  while (slot < indices.n_rows &&
+         !(kernelValue > products(slot, queryIndex)))
+    ++slot;
+
+  // Because the candidate beats the last row, the scan stops on a valid row
+  // (provided indices and products have the same number of rows).
+  InsertNeighbor(queryIndex, slot, referenceIndex, kernelValue);
+}
+
+/**
+ * Decide whether a reference node can be skipped for the given query point.
+ *
+ * @param queryIndex Column of the query point in the query set.
+ * @param referenceNode Cover tree node being considered.
+ * @return True if the largest kernel value the node could yield is not greater
+ *     than the last entry of the query's column in the products matrix.
+ */
+template<typename MetricType>
+bool MaxIPRules<MetricType>::CanPrune(const size_t queryIndex,
+    tree::CoverTree<MetricType>& referenceNode)
+{
+  // The maximum possible inner product is given by
+  //   <q, p_0> + R_p || q ||
+  // and since we are using cover trees, p_0 is the point referred to by the
+  // node, and R_p will be the expansion constant to the power of the scale plus
+  // one.
+
+  // <q, p_0>: kernel between the query and the node's point.
+  const double centerProduct = MetricType::Kernel::Evaluate(
+      querySet.col(queryIndex), referenceSet.col(referenceNode.Point()));
+
+  // || q ||: norm of the query induced by the kernel, sqrt(K(q, q)).
+  const double queryNorm = std::sqrt(MetricType::Kernel::Evaluate(
+      querySet.col(queryIndex), querySet.col(queryIndex)));
+
+  // R_p: expansion constant raised to (scale + 1).
+  const double nodeRadius = std::pow(referenceNode.ExpansionConstant(),
+      referenceNode.Scale() + 1);
+
+  const double maxProduct = centerProduct + nodeRadius * queryNorm;
+
+  // Prune exactly when the bound cannot beat the last kept product.  The
+  // comparison is written as a negated '>' so that a NaN bound still prunes.
+  return !(maxProduct > products(products.n_rows - 1, queryIndex));
+}
+
+/**
+ * Helper function to insert a point into the indices and products matrices.
+ * Entries from the insertion row onward are shifted down by one row, so the
+ * previous last entry of the column is overwritten.
+ *
+ * @param queryIndex Index of point whose neighbors we are inserting into.
+ * @param pos Position in list to insert into.
+ * @param neighbor Index of reference point which is being inserted.
+ * @param distance Value stored in the products matrix for this pair.  Despite
+ *     the name, BaseCase() passes the kernel evaluation here, not a distance.
+ */
+template<typename MetricType>
+void MaxIPRules<MetricType>::InsertNeighbor(const size_t queryIndex,
+                                            const size_t pos,
+                                            const size_t neighbor,
+                                            const double distance)
+{
+  // We only memmove() if there is actually a need to shift something.
+  if (pos < (products.n_rows - 1))
+  {
+    // Number of entries to shift.  Kept as size_t so the byte counts below are
+    // computed without narrowing the unsigned row count to a signed int.
+    const size_t len = (products.n_rows - 1) - pos;
+
+    // Source and destination overlap, hence memmove() rather than memcpy().
+    memmove(products.colptr(queryIndex) + (pos + 1),
+        products.colptr(queryIndex) + pos,
+        sizeof(double) * len);
+    memmove(indices.colptr(queryIndex) + (pos + 1),
+        indices.colptr(queryIndex) + pos,
+        sizeof(size_t) * len);
+  }
+
+  // Now put the new information in the right index.
+  products(pos, queryIndex) = distance;
+  indices(pos, queryIndex) = neighbor;
+}
+
+}; // namespace maxip
+}; // namespace mlpack
+
+#endif




More information about the mlpack-svn mailing list