[mlpack-git] master: Use a priority queue (heap) to store the list of candidates while searching fastmks. (bccf3e0)

gitdub at mlpack.org gitdub at mlpack.org
Tue Jul 26 21:22:08 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ef51b032f275266f781d42b9bd0aa50aa26a3077...8522b04c3d9a82fb7e964bafd72e70f0cd30bf4b

>---------------------------------------------------------------

commit bccf3e0d442ba554a2f0276822f102fcf2a2218a
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Fri Jul 22 01:39:22 2016 -0300

    Use a priority queue (heap) to store the list of candidates while searching fastmks.


>---------------------------------------------------------------

bccf3e0d442ba554a2f0276822f102fcf2a2218a
 src/mlpack/methods/fastmks/fastmks.hpp            |  30 ++++--
 src/mlpack/methods/fastmks/fastmks_impl.hpp       | 107 +++++++++-------------
 src/mlpack/methods/fastmks/fastmks_rules.hpp      |  59 ++++++++++--
 src/mlpack/methods/fastmks/fastmks_rules_impl.hpp | 104 ++++++++++++---------
 4 files changed, 174 insertions(+), 126 deletions(-)

diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index e849635..796f0db 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -12,6 +12,7 @@
 #include <mlpack/core/metrics/ip_metric.hpp>
 #include "fastmks_stat.hpp"
 #include <mlpack/core/tree/cover_tree.hpp>
+#include <queue>
 
 namespace mlpack {
 namespace fastmks /** Fast max-kernel search. */ {
@@ -250,13 +251,28 @@ class FastMKS
   //! The instantiated inner-product metric induced by the given kernel.
   metric::IPMetric<KernelType> metric;
 
-  //! Utility function.  Copied too many times from too many places.
-  void InsertNeighbor(arma::Mat<size_t>& indices,
-                      arma::mat& products,
-                      const size_t queryIndex,
-                      const size_t pos,
-                      const size_t neighbor,
-                      const double distance);
+  //! Candidate point from the reference set.
+  struct Candidate
+  {
+    //! Kernel value calculated between a reference point and the query point.
+    double product;
+    //! Index of the reference point.
+    size_t index;
+    //! Trivial constructor.
+    Candidate(double p, size_t i) :
+        product(p),
+        index(i)
+    {};
+    //! Compare two candidates.
+    friend bool operator>(const Candidate& l, const Candidate& r)
+    {
+      return l.product > r.product;
+    };
+  };
+
+  //! Use a priority queue to represent the list of candidate points.
+  typedef std::priority_queue<Candidate, std::vector<Candidate>,
+      std::greater<Candidate>> CandidateList;
 };
 
 } // namespace fastmks
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index de85cdc..f0f321f 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -13,7 +13,6 @@
 #include "fastmks_rules.hpp"
 
 #include <mlpack/core/kernels/gaussian_kernel.hpp>
-#include <queue>
 
 namespace mlpack {
 namespace fastmks {
@@ -221,25 +220,31 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
   // Naive implementation.
   if (naive)
   {
-    // Fill kernels.
-    kernels.fill(-DBL_MAX);
-
     // Simple double loop.  Stupid, slow, but a good benchmark.
     for (size_t q = 0; q < querySet.n_cols; ++q)
     {
+      const Candidate def(-DBL_MAX, size_t() - 1);
+      std::vector<Candidate> cList(k, def);
+      CandidateList pqueue(std::greater<Candidate>(), std::move(cList));
+
       for (size_t r = 0; r < referenceSet->n_cols; ++r)
       {
         const double eval = metric.Kernel().Evaluate(querySet.col(q),
                                                      referenceSet->col(r));
 
-        size_t insertPosition;
-        for (insertPosition = 0; insertPosition < indices.n_rows;
-            ++insertPosition)
-          if (eval > kernels(insertPosition, q))
-            break;
+        Candidate c(eval, r);
+        if (c > pqueue.top())
+        {
+          pqueue.pop();
+          pqueue.push(c);
+        }
+      }
 
-        if (insertPosition < indices.n_rows)
-          InsertNeighbor(indices, kernels, q, insertPosition, r, eval);
+      for (size_t j = 1; j <= k; j++)
+      {
+        indices(k - j, q) = pqueue.top().index;
+        kernels(k - j, q) = pqueue.top().product;
+        pqueue.pop();
       }
     }
 
@@ -251,13 +256,10 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
   // Single-tree implementation.
   if (singleMode)
   {
-    // Fill kernels.
-    kernels.fill(-DBL_MAX);
-
     // Create rules object (this will store the results).  This constructor
     // precalculates each self-kernel value.
     typedef FastMKSRules<KernelType, Tree> RuleType;
-    RuleType rules(*referenceSet, querySet, indices, kernels, metric.Kernel());
+    RuleType rules(*referenceSet, querySet, k, metric.Kernel());
 
     typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
@@ -267,6 +269,8 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
     Log::Info << rules.BaseCases() << " base cases." << std::endl;
     Log::Info << rules.Scores() << " scores." << std::endl;
 
+    rules.GetResults(indices, kernels);
+
     Timer::Stop("computing_products");
     return;
   }
@@ -310,12 +314,10 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
   // No remapping will be necessary because we are using the cover tree.
   indices.set_size(k, queryTree->Dataset().n_cols);
   kernels.set_size(k, queryTree->Dataset().n_cols);
-  kernels.fill(-DBL_MAX);
 
   Timer::Start("computing_products");
   typedef FastMKSRules<KernelType, Tree> RuleType;
-  RuleType rules(*referenceSet, queryTree->Dataset(), indices, kernels,
-      metric.Kernel());
+  RuleType rules(*referenceSet, queryTree->Dataset(), k, metric.Kernel());
 
   typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
 
@@ -324,6 +326,8 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
   Log::Info << rules.BaseCases() << " base cases." << std::endl;
   Log::Info << rules.Scores() << " scores." << std::endl;
 
+  rules.GetResults(indices, kernels);
+
   Timer::Stop("computing_products");
 }
 
@@ -341,7 +345,6 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
   Timer::Start("computing_products");
   indices.set_size(k, referenceSet->n_cols);
   kernels.set_size(k, referenceSet->n_cols);
-  kernels.fill(-DBL_MAX);
 
   // Naive implementation.
   if (naive)
@@ -349,6 +352,10 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
     // Simple double loop.  Stupid, slow, but a good benchmark.
     for (size_t q = 0; q < referenceSet->n_cols; ++q)
     {
+      const Candidate def(-DBL_MAX, size_t() - 1);
+      std::vector<Candidate> cList(k, def);
+      CandidateList pqueue(std::greater<Candidate>(), std::move(cList));
+
       for (size_t r = 0; r < referenceSet->n_cols; ++r)
       {
         if (q == r)
@@ -357,14 +364,19 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
         const double eval = metric.Kernel().Evaluate(referenceSet->col(q),
                                                      referenceSet->col(r));
 
-        size_t insertPosition;
-        for (insertPosition = 0; insertPosition < indices.n_rows;
-            ++insertPosition)
-          if (eval > kernels(insertPosition, q))
-            break;
+        Candidate c(eval, r);
+        if (c > pqueue.top())
+        {
+          pqueue.pop();
+          pqueue.push(c);
+        }
+      }
 
-        if (insertPosition < indices.n_rows)
-          InsertNeighbor(indices, kernels, q, insertPosition, r, eval);
+      for (size_t j = 1; j <= k; j++)
+      {
+        indices(k - j, q) = pqueue.top().index;
+        kernels(k - j, q) = pqueue.top().product;
+        pqueue.pop();
       }
     }
 
@@ -379,8 +391,7 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
     // Create rules object (this will store the results).  This constructor
     // precalculates each self-kernel value.
     typedef FastMKSRules<KernelType, Tree> RuleType;
-    RuleType rules(*referenceSet, *referenceSet, indices, kernels,
-        metric.Kernel());
+    RuleType rules(*referenceSet, *referenceSet, k, metric.Kernel());
 
     typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
@@ -395,6 +406,8 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
     Log::Info << rules.BaseCases() << " base cases." << std::endl;
     Log::Info << rules.Scores() << " scores." << std::endl;
 
+    rules.GetResults(indices, kernels);
+
     Timer::Stop("computing_products");
     return;
   }
@@ -405,44 +418,6 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
   Search(referenceTree, k, indices, kernels);
 }
 
-/**
- * Helper function to insert a point into the neighbors and distances matrices.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
-template<typename KernelType,
-         typename MatType,
-         template<typename TreeMetricType,
-                  typename TreeStatType,
-                  typename TreeMatType> class TreeType>
-void FastMKS<KernelType, MatType, TreeType>::InsertNeighbor(
-    arma::Mat<size_t>& indices,
-    arma::mat& products,
-    const size_t queryIndex,
-    const size_t pos,
-    const size_t neighbor,
-    const double distance)
-{
-  // We only memmove() if there is actually a need to shift something.
-  if (pos < (products.n_rows - 1))
-  {
-    int len = (products.n_rows - 1) - pos;
-    memmove(products.colptr(queryIndex) + (pos + 1),
-        products.colptr(queryIndex) + pos,
-        sizeof(double) * len);
-    memmove(indices.colptr(queryIndex) + (pos + 1),
-        indices.colptr(queryIndex) + pos,
-        sizeof(size_t) * len);
-  }
-
-  // Now put the new information in the right index.
-  products(pos, queryIndex) = distance;
-  indices(pos, queryIndex) = neighbor;
-}
-
 //! Serialize the model.
 template<typename KernelType,
          typename MatType,
diff --git a/src/mlpack/methods/fastmks/fastmks_rules.hpp b/src/mlpack/methods/fastmks/fastmks_rules.hpp
index 0f4ad34..13be5f5 100644
--- a/src/mlpack/methods/fastmks/fastmks_rules.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_rules.hpp
@@ -10,6 +10,7 @@
 #include <mlpack/core.hpp>
 #include <mlpack/core/tree/cover_tree/cover_tree.hpp>
 #include <mlpack/core/tree/traversal_info.hpp>
+#include <vector>
 
 namespace mlpack {
 namespace fastmks {
@@ -23,10 +24,17 @@ class FastMKSRules
  public:
   FastMKSRules(const typename TreeType::Mat& referenceSet,
                const typename TreeType::Mat& querySet,
-               arma::Mat<size_t>& indices,
-               arma::mat& products,
+               const size_t k,
                KernelType& kernel);
 
+  /**
+   * Store the list of candidates for each query point in the given matrices.
+   *
+   * @param indices Matrix storing lists of candidate points for each query point.
+   * @param products Matrix storing kernel value for each candidate.
+   */
+  void GetResults(arma::Mat<size_t>& indices, arma::mat& products);
+
   //! Compute the base case (kernel value) between two points.
   double BaseCase(const size_t queryIndex, const size_t referenceIndex);
 
@@ -101,10 +109,36 @@ class FastMKSRules
   //! The query dataset.
   const typename TreeType::Mat& querySet;
 
-  //! The indices of the maximum kernel results.
-  arma::Mat<size_t>& indices;
-  //! The maximum kernels.
-  arma::mat& products;
+  //! Candidate point from the reference set.
+  struct Candidate
+  {
+    //! Kernel value calculated between a reference point and the query point.
+    double product;
+    //! Index of the reference point.
+    size_t index;
+    //! Trivial constructor.
+    Candidate(double p, size_t i) :
+        product(p),
+        index(i)
+    {};
+    //! Compare two candidates.
+    friend bool operator>(const Candidate& l, const Candidate& r)
+    {
+      return l.product > r.product;
+    };
+  };
+
+  //! Use a min heap to represent the list of candidate points.
+  //! We will use a vector and the std functions push_heap() and pop_heap().
+  //! We cannot use a priority queue because we need to iterate over all the
+  //! candidates and std::priority_queue doesn't provide that interface.
+  typedef std::vector<Candidate> CandidateList;
+
+  //! Set of candidates for each point.
+  std::vector<CandidateList> candidates;
+
+  //! Number of points to search for.
+  const size_t k;
 
   //! Cached query set self-kernels (|| q || for each q).
   arma::vec queryKernels;
@@ -124,11 +158,16 @@ class FastMKSRules
   //! Calculate the bound for a given query node.
   double CalculateBound(TreeType& queryNode) const;
 
-  //! Utility function to insert neighbor into list of results.
+  /**
+   * Helper function to insert a point into the list of candidate points.
+   *
+   * @param queryIndex Index of point whose neighbors we are inserting into.
+   * @param index Index of reference point which is being inserted.
+   * @param product Kernel value for given candidate.
+   */
   void InsertNeighbor(const size_t queryIndex,
-                      const size_t pos,
-                      const size_t neighbor,
-                      const double distance);
+                      const size_t index,
+                      const double product);
 
   //! For benchmarking.
   size_t baseCases;
diff --git a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
index 27abacf..a5cf681 100644
--- a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
@@ -9,6 +9,7 @@
 
 // In case it hasn't already been included.
 #include "fastmks_rules.hpp"
+#include <algorithm>
 
 namespace mlpack {
 namespace fastmks {
@@ -17,13 +18,11 @@ template<typename KernelType, typename TreeType>
 FastMKSRules<KernelType, TreeType>::FastMKSRules(
     const typename TreeType::Mat& referenceSet,
     const typename TreeType::Mat& querySet,
-    arma::Mat<size_t>& indices,
-    arma::mat& products,
+    const size_t k,
     KernelType& kernel) :
     referenceSet(referenceSet),
     querySet(querySet),
-    indices(indices),
-    products(products),
+    k(k),
     kernel(kernel),
     lastQueryIndex(-1),
     lastReferenceIndex(-1),
@@ -46,6 +45,41 @@ FastMKSRules<KernelType, TreeType>::FastMKSRules(
   // dereference null pointers.
   traversalInfo.LastQueryNode() = (TreeType*) this;
   traversalInfo.LastReferenceNode() = (TreeType*) this;
+
+  // Let's build the list of candidate points for each query point.
+  // It will be initialized with k candidates: (-DBL_MAX, size_t() - 1)
+  // The list of candidates will be updated when visiting new points with the
+  // BaseCase() method.
+  const Candidate def(-DBL_MAX, size_t() - 1);
+
+  CandidateList cList(k, def);
+  std::vector<CandidateList> tmp(querySet.n_cols, cList);
+  candidates.swap(tmp);
+}
+
+template<typename KernelType, typename TreeType>
+void FastMKSRules<KernelType, TreeType>::GetResults(
+    arma::Mat<size_t>& indices,
+    arma::mat& products)
+{
+  indices.set_size(k, querySet.n_cols);
+  products.set_size(k, querySet.n_cols);
+
+  for (size_t i = 0; i < querySet.n_cols; i++)
+  {
+    CandidateList& pqueue = candidates[i];
+    std::greater<Candidate> greater;
+    typedef typename CandidateList::iterator Iterator;
+
+    for (Iterator end = pqueue.end(); end != pqueue.begin(); --end)
+      std::pop_heap(pqueue.begin(), end, greater);
+
+    for (size_t j = 0; j < k; j++)
+    {
+      indices(j, i) = pqueue[j].index;
+      products(j, i) = pqueue[j].product;
+    }
+  }
 }
 
 template<typename KernelType, typename TreeType>
@@ -83,16 +117,7 @@ double FastMKSRules<KernelType, TreeType>::BaseCase(
   if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
     return kernelEval;
 
-  // If this is a better candidate, insert it into the list.
-  if (kernelEval < products(products.n_rows - 1, queryIndex))
-    return kernelEval;
-
-  size_t insertPosition = 0;
-  for ( ; insertPosition < products.n_rows; ++insertPosition)
-    if (kernelEval >= products(insertPosition, queryIndex))
-      break;
-
-  InsertNeighbor(queryIndex, insertPosition, referenceIndex, kernelEval);
+  InsertNeighbor(queryIndex, referenceIndex, kernelEval);
 
   return kernelEval;
 }
@@ -102,7 +127,7 @@ double FastMKSRules<KernelType, TreeType>::Score(const size_t queryIndex,
                                                  TreeType& referenceNode)
 {
   // Compare with the current best.
-  const double bestKernel = products(products.n_rows - 1, queryIndex);
+  const double bestKernel = candidates[queryIndex].front().product;
 
   // See if we can perform a parent-child prune.
   const double furthestDist = referenceNode.FurthestDescendantDistance();
@@ -385,7 +410,7 @@ double FastMKSRules<KernelType, TreeType>::Rescore(const size_t queryIndex,
                                                    TreeType& /*referenceNode*/,
                                                    const double oldScore) const
 {
-  const double bestKernel = products(products.n_rows - 1, queryIndex);
+  const double bestKernel = candidates[queryIndex].front().product;
 
   return ((1.0 / oldScore) >= bestKernel) ? oldScore : DBL_MAX;
 }
@@ -432,10 +457,11 @@ double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
   for (size_t i = 0; i < queryNode.NumPoints(); ++i)
   {
     const size_t point = queryNode.Point(i);
-    if (products(products.n_rows - 1, point) < worstPointKernel)
-      worstPointKernel = products(products.n_rows - 1, point);
+    const CandidateList& candidatesPoints = candidates[point];
+    if (candidatesPoints.front().product < worstPointKernel)
+      worstPointKernel = candidatesPoints.front().product;
 
-    if (products(products.n_rows - 1, point) == -DBL_MAX)
+    if (candidatesPoints.front().product == -DBL_MAX)
       continue; // Avoid underflow.
 
     // This should be (queryDescendantDistance + centroidDistance) for any tree
@@ -450,10 +476,10 @@ double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
     // where p_j^*(p_q) is the j'th kernel candidate for query point p_q and
     // k_j^*(p_q) is K(p_q, p_j^*(p_q)).
     double worstPointCandidateKernel = DBL_MAX;
-    for (size_t j = 0; j < products.n_rows; ++j)
+    for (size_t j = 0; j < candidatesPoints.size(); ++j)
     {
-      const double candidateKernel = products(j, point) -
-          queryDescendantDistance * referenceKernels[indices(j, point)];
+      const double candidateKernel = candidatesPoints[j].product -
+          queryDescendantDistance * referenceKernels[candidatesPoints[j].index];
       if (candidateKernel < worstPointCandidateKernel)
         worstPointCandidateKernel = candidateKernel;
     }
@@ -488,34 +514,26 @@ double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
 }
 
 /**
- * Helper function to insert a point into the neighbors and distances matrices.
+ * Helper function to insert a point into the list of candidate points.
  *
  * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
+ * @param index Index of reference point which is being inserted.
+ * @param product Kernel value for given candidate.
  */
 template<typename KernelType, typename TreeType>
-void FastMKSRules<KernelType, TreeType>::InsertNeighbor(const size_t queryIndex,
-                                                        const size_t pos,
-                                                        const size_t neighbor,
-                                                        const double distance)
+inline void FastMKSRules<KernelType, TreeType>::InsertNeighbor(
+    const size_t queryIndex,
+    const size_t index,
+    const double product)
 {
-  // We only memmove() if there is actually a need to shift something.
-  if (pos < (products.n_rows - 1))
+  Candidate c(product, index);
+  CandidateList& pqueue = candidates[queryIndex];
+  if (c > pqueue.front())
   {
-    int len = (products.n_rows - 1) - pos;
-    memmove(products.colptr(queryIndex) + (pos + 1),
-        products.colptr(queryIndex) + pos,
-        sizeof(double) * len);
-    memmove(indices.colptr(queryIndex) + (pos + 1),
-        indices.colptr(queryIndex) + pos,
-        sizeof(size_t) * len);
+    std::pop_heap(pqueue.begin(), pqueue.end(), std::greater<Candidate>());
+    pqueue.back() = c;
+    std::push_heap(pqueue.begin(), pqueue.end(), std::greater<Candidate>());
   }
-
-  // Now put the new information in the right index.
-  products(pos, queryIndex) = distance;
-  indices(pos, queryIndex) = neighbor;
 }
 
 } // namespace fastmks




More information about the mlpack-git mailing list