[mlpack-git] master: Use a priority queue (heap) to store the list of candidates while searching. (fdeeb88)

gitdub at mlpack.org gitdub at mlpack.org
Tue Jul 26 21:22:08 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ef51b032f275266f781d42b9bd0aa50aa26a3077...8522b04c3d9a82fb7e964bafd72e70f0cd30bf4b

>---------------------------------------------------------------

commit fdeeb88bea641cce93f83556b0b1b146f55c1383
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Thu Jul 21 22:14:12 2016 -0300

    Use a priority queue (heap) to store the list of candidates while searching.


>---------------------------------------------------------------

fdeeb88bea641cce93f83556b0b1b146f55c1383
 src/mlpack/methods/lsh/lsh_search.hpp      | 48 ++++++++-------
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 98 ++++++++++++++----------------
 2 files changed, 72 insertions(+), 74 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 4e6cc97..45284ba 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -41,6 +41,7 @@
 #include <mlpack/core.hpp>
 #include <vector>
 #include <string>
+#include <queue>
 
 #include <mlpack/core/metrics/lmetric.hpp>
 #include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp>
@@ -298,11 +299,13 @@ class LSHSearch
    * @param queryIndex The index of the query in question
    * @param referenceIndices The vector of indices of candidate neighbors for
    *    the query.
+   * @param k Number of neighbors to search for.
    * @param neighbors Matrix holding output neighbors.
    * @param distances Matrix holding output distances.
    */
   void BaseCase(const size_t queryIndex,
                 const arma::uvec& referenceIndices,
+                const size_t k,
                 arma::Mat<size_t>& neighbors,
                 arma::mat& distances) const;
 
@@ -315,38 +318,19 @@ class LSHSearch
    * @param queryIndex The index of the query in question
    * @param referenceIndices The vector of indices of candidate neighbors for
    *    the query.
+   * @param k Number of neighbors to search for.
    * @param querySet Set of query points.
    * @param neighbors Matrix holding output neighbors.
    * @param distances Matrix holding output distances.
    */
   void BaseCase(const size_t queryIndex,
                 const arma::uvec& referenceIndices,
+                const size_t k,
                 const arma::mat& querySet,
                 arma::Mat<size_t>& neighbors,
                 arma::mat& distances) const;
 
   /**
-   * This is a helper function that efficiently inserts better neighbor
-   * candidates into an existing set of neighbor candidates. This function is
-   * only called by the 'BaseCase' function.
-   *
-   * @param distances Matrix holding output distances.
-   * @param neighbors Matrix holding output neighbors.
-   * @param queryIndex This is the index of the query being processed currently
-   * @param pos The position of the neighbor candidate in the current list of
-   *    neighbor candidates.
-   * @param neighbor The neighbor candidate that is being inserted into the list
-   *    of the best 'k' candidates for the query in question.
-   * @param distance The distance of the query to the neighbor candidate.
-   */
-  void InsertNeighbor(arma::mat& distances,
-                      arma::Mat<size_t>& neighbors,
-                      const size_t queryIndex,
-                      const size_t pos,
-                      const size_t neighbor,
-                      const double distance) const;
-
-  /**
    * This function implements the core idea behind Multiprobe LSH. It is called
    * by ReturnIndicesFromTables when T > 0. Given a query's code and its
    * projection location, GetAdditionalProbingBins will calculate the T most
@@ -444,6 +428,28 @@ class LSHSearch
   //! The number of distance evaluations.
   size_t distanceEvaluations;
 
+  //! Candidate represents a possible candidate neighbor (from the reference
+  //! set).
+  struct Candidate
+  {
+    //! Distance between the reference point and the query point.
+    double dist;
+    //! Index of the reference point.
+    size_t index;
+    //! Trivial constructor.
+    Candidate(double d, size_t i) :
+        dist(d),
+        index(i)
+    {};
+    //! Compare the distance of two candidates.
+    friend bool operator<(const Candidate& l, const Candidate& r)
+    {
+      return !SortPolicy::IsBetter(r.dist, l.dist);
+    };
+  };
+
+  //! Use a priority queue to represent the list of candidate neighbors.
+  typedef std::priority_queue<Candidate> CandidateList;
 }; // class LSHSearch
 
 } // namespace neighbor
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index bcb5795..d0b53ae 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -262,40 +262,22 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
             << std::endl;
 }
 
-template<typename SortPolicy>
-void LSHSearch<SortPolicy>::InsertNeighbor(arma::mat& distances,
-                                           arma::Mat<size_t>& neighbors,
-                                           const size_t queryIndex,
-                                           const size_t pos,
-                                           const size_t neighbor,
-                                           const double distance) const
-{
-  // We only memmove() if there is actually a need to shift something.
-  if (pos < (distances.n_rows - 1))
-  {
-    const size_t len = (distances.n_rows - 1) - pos;
-    memmove(distances.colptr(queryIndex) + (pos + 1),
-        distances.colptr(queryIndex) + pos,
-        sizeof(double) * len);
-    memmove(neighbors.colptr(queryIndex) + (pos + 1),
-        neighbors.colptr(queryIndex) + pos,
-        sizeof(size_t) * len);
-  }
-
-  // Now put the new information in the right index.
-  distances(pos, queryIndex) = distance;
-  neighbors(pos, queryIndex) = neighbor;
-}
-
 // Base case where the query set is the reference set.  (So, we can't return
 // ourselves as the nearest neighbor.)
 template<typename SortPolicy>
 inline force_inline
 void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
                                      const arma::uvec& referenceIndices,
+                                     const size_t k,
                                      arma::Mat<size_t>& neighbors,
                                      arma::mat& distances) const
 {
+  // Let's build the list of candidate neighbors for the given query point.
+  // It will be initialized with k candidates:
+  // (WorstDistance, referenceSet->n_cols)
+  const Candidate def(SortPolicy::WorstDistance(), referenceSet->n_cols);
+  std::vector<Candidate> vect(k, def);
+  CandidateList pqueue(std::less<Candidate>(), std::move(vect));
 
   for (size_t j = 0; j < referenceIndices.n_elem; ++j)
   {
@@ -308,17 +290,20 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
         referenceSet->unsafe_col(queryIndex),
         referenceSet->unsafe_col(referenceIndex));
 
-    // If this distance is better than any of the current candidates, the
-    // SortDistance() function will give us the position to insert it into.
-    arma::vec queryDist = distances.unsafe_col(queryIndex);
-    arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
-    size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
-        distance);
-
-    // SortDistance() returns (size_t() - 1) if we shouldn't add it.
-    if (insertPosition != (size_t() - 1))
-      InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
-          referenceIndex, distance);
+    Candidate c(distance, referenceIndex);
+    // If this distance is better than the worst candidate, let's insert it.
+    if (c < pqueue.top())
+    {
+      pqueue.pop();
+      pqueue.push(c);
+    }
+  }
+
+  for (size_t j = 1; j <= k; j++)
+  {
+    neighbors(k - j, queryIndex) = pqueue.top().index;
+    distances(k - j, queryIndex) = pqueue.top().dist;
+    pqueue.pop();
   }
 }
 
@@ -327,10 +312,18 @@ template<typename SortPolicy>
 inline force_inline
 void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
                                      const arma::uvec& referenceIndices,
+                                     const size_t k,
                                      const arma::mat& querySet,
                                      arma::Mat<size_t>& neighbors,
                                      arma::mat& distances) const
 {
+  // Let's build the list of candidate neighbors for the given query point.
+  // It will be initialized with k candidates:
+  // (WorstDistance, referenceSet->n_cols)
+  const Candidate def(SortPolicy::WorstDistance(), referenceSet->n_cols);
+  std::vector<Candidate> vect(k, def);
+  CandidateList pqueue(std::less<Candidate>(), std::move(vect));
+
   for (size_t j = 0; j < referenceIndices.n_elem; ++j)
   {
     const size_t referenceIndex = referenceIndices[j];
@@ -338,20 +331,23 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
         querySet.unsafe_col(queryIndex),
         referenceSet->unsafe_col(referenceIndex));
 
-    // If this distance is better than any of the current candidates, the
-    // SortDistance() function will give us the position to insert it into.
-    arma::vec queryDist = distances.unsafe_col(queryIndex);
-    arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
-    size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
-        distance);
-
-    // SortDistance() returns (size_t() - 1) if we shouldn't add it.
-    if (insertPosition != (size_t() - 1))
-      InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
-          referenceIndex, distance);
+    Candidate c(distance, referenceIndex);
+    // If this distance is better than the worst candidate, let's insert it.
+    if (c < pqueue.top())
+    {
+      pqueue.pop();
+      pqueue.push(c);
+    }
+  }
 
+  for (size_t j = 1; j <= k; j++)
+  {
+    neighbors(k - j, queryIndex) = pqueue.top().index;
+    distances(k - j, queryIndex) = pqueue.top().dist;
+    pqueue.pop();
   }
 }
+
 template<typename SortPolicy>
 inline force_inline
 double LSHSearch<SortPolicy>::PerturbationScore(
@@ -794,8 +790,6 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
   // Set the size of the neighbor and distance matrices.
   resultingNeighbors.set_size(k, querySet.n_cols);
   distances.set_size(k, querySet.n_cols);
-  distances.fill(SortPolicy::WorstDistance());
-  resultingNeighbors.fill(referenceSet->n_cols);
 
   // If the user asked for 0 nearest neighbors... uh... we're done.
   if (k == 0)
@@ -854,7 +848,7 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
 
     // Sequentially go through all the candidates and save the best 'k'
     // candidates.
-    BaseCase(i, refIndices, querySet, resultingNeighbors, distances);
+    BaseCase(i, refIndices, k, querySet, resultingNeighbors, distances);
   }
 
   Timer::Stop("computing_neighbors");
@@ -877,8 +871,6 @@ Search(const size_t k,
   // This is monochromatic search; the query set is the reference set.
   resultingNeighbors.set_size(k, referenceSet->n_cols);
   distances.set_size(k, referenceSet->n_cols);
-  distances.fill(SortPolicy::WorstDistance());
-  resultingNeighbors.fill(referenceSet->n_cols);
 
   // If the user requested more than the available number of additional probing
   // bins, set Teffective to maximum T. Maximum T is 2^numProj - 1
@@ -933,7 +925,7 @@ Search(const size_t k,
 
     // Sequentially go through all the candidates and save the best 'k'
     // candidates.
-    BaseCase(i, refIndices, resultingNeighbors, distances);
+    BaseCase(i, refIndices, k, resultingNeighbors, distances);
   }
 
   Timer::Stop("computing_neighbors");




More information about the mlpack-git mailing list