[mlpack-git] master: Use a priority queue (heap) to store the list of candidates while searching. (fdeeb88)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Jul 26 21:22:08 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ef51b032f275266f781d42b9bd0aa50aa26a3077...8522b04c3d9a82fb7e964bafd72e70f0cd30bf4b
>---------------------------------------------------------------
commit fdeeb88bea641cce93f83556b0b1b146f55c1383
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Thu Jul 21 22:14:12 2016 -0300
Use a priority queue (heap) to store the list of candidates while searching.
>---------------------------------------------------------------
fdeeb88bea641cce93f83556b0b1b146f55c1383
src/mlpack/methods/lsh/lsh_search.hpp | 48 ++++++++-------
src/mlpack/methods/lsh/lsh_search_impl.hpp | 98 ++++++++++++++----------------
2 files changed, 72 insertions(+), 74 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 4e6cc97..45284ba 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -41,6 +41,7 @@
#include <mlpack/core.hpp>
#include <vector>
#include <string>
+#include <queue>
#include <mlpack/core/metrics/lmetric.hpp>
#include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp>
@@ -298,11 +299,13 @@ class LSHSearch
* @param queryIndex The index of the query in question
* @param referenceIndices The vector of indices of candidate neighbors for
* the query.
+ * @param k Number of neighbors to search for.
* @param neighbors Matrix holding output neighbors.
* @param distances Matrix holding output distances.
*/
void BaseCase(const size_t queryIndex,
const arma::uvec& referenceIndices,
+ const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances) const;
@@ -315,38 +318,19 @@ class LSHSearch
* @param queryIndex The index of the query in question
* @param referenceIndices The vector of indices of candidate neighbors for
* the query.
+ * @param k Number of neighbors to search for.
* @param querySet Set of query points.
* @param neighbors Matrix holding output neighbors.
* @param distances Matrix holding output distances.
*/
void BaseCase(const size_t queryIndex,
const arma::uvec& referenceIndices,
+ const size_t k,
const arma::mat& querySet,
arma::Mat<size_t>& neighbors,
arma::mat& distances) const;
/**
- * This is a helper function that efficiently inserts better neighbor
- * candidates into an existing set of neighbor candidates. This function is
- * only called by the 'BaseCase' function.
- *
- * @param distances Matrix holding output distances.
- * @param neighbors Matrix holding output neighbors.
- * @param queryIndex This is the index of the query being processed currently
- * @param pos The position of the neighbor candidate in the current list of
- * neighbor candidates.
- * @param neighbor The neighbor candidate that is being inserted into the list
- * of the best 'k' candidates for the query in question.
- * @param distance The distance of the query to the neighbor candidate.
- */
- void InsertNeighbor(arma::mat& distances,
- arma::Mat<size_t>& neighbors,
- const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance) const;
-
- /**
* This function implements the core idea behind Multiprobe LSH. It is called
* by ReturnIndicesFromTables when T > 0. Given a query's code and its
* projection location, GetAdditionalProbingBins will calculate the T most
@@ -444,6 +428,28 @@ class LSHSearch
//! The number of distance evaluations.
size_t distanceEvaluations;
+ //! Candidate represents a possible candidate neighbor (from the reference
+ // set).
+ struct Candidate
+ {
+ //! Distance between the reference point and the query point.
+ double dist;
+ //! Index of the reference point.
+ size_t index;
+ //! Trivial constructor.
+ Candidate(double d, size_t i) :
+ dist(d),
+ index(i)
+ {};
+ //! Compare the distance of two candidates.
+ friend bool operator<(const Candidate& l, const Candidate& r)
+ {
+ return !SortPolicy::IsBetter(r.dist, l.dist);
+ };
+ };
+
+ //! Use a priority queue to represent the list of candidate neighbors.
+ typedef std::priority_queue<Candidate> CandidateList;
}; // class LSHSearch
} // namespace neighbor
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index bcb5795..d0b53ae 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -262,40 +262,22 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
<< std::endl;
}
-template<typename SortPolicy>
-void LSHSearch<SortPolicy>::InsertNeighbor(arma::mat& distances,
- arma::Mat<size_t>& neighbors,
- const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance) const
-{
- // We only memmove() if there is actually a need to shift something.
- if (pos < (distances.n_rows - 1))
- {
- const size_t len = (distances.n_rows - 1) - pos;
- memmove(distances.colptr(queryIndex) + (pos + 1),
- distances.colptr(queryIndex) + pos,
- sizeof(double) * len);
- memmove(neighbors.colptr(queryIndex) + (pos + 1),
- neighbors.colptr(queryIndex) + pos,
- sizeof(size_t) * len);
- }
-
- // Now put the new information in the right index.
- distances(pos, queryIndex) = distance;
- neighbors(pos, queryIndex) = neighbor;
-}
-
// Base case where the query set is the reference set. (So, we can't return
// ourselves as the nearest neighbor.)
template<typename SortPolicy>
inline force_inline
void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
const arma::uvec& referenceIndices,
+ const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances) const
{
+ // Let's build the list of candidate neighbors for the given query point.
+ // It will be initialized with k candidates:
+ // (WorstDistance, referenceSet->n_cols)
+ const Candidate def(SortPolicy::WorstDistance(), referenceSet->n_cols);
+ std::vector<Candidate> vect(k, def);
+ CandidateList pqueue(std::less<Candidate>(), std::move(vect));
for (size_t j = 0; j < referenceIndices.n_elem; ++j)
{
@@ -308,17 +290,20 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
referenceSet->unsafe_col(queryIndex),
referenceSet->unsafe_col(referenceIndex));
- // If this distance is better than any of the current candidates, the
- // SortDistance() function will give us the position to insert it into.
- arma::vec queryDist = distances.unsafe_col(queryIndex);
- arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
- size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
- distance);
-
- // SortDistance() returns (size_t() - 1) if we shouldn't add it.
- if (insertPosition != (size_t() - 1))
- InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
- referenceIndex, distance);
+ Candidate c(distance, referenceIndex);
+ // If this distance is better than the worst candidate, let's insert it.
+ if (c < pqueue.top())
+ {
+ pqueue.pop();
+ pqueue.push(c);
+ }
+ }
+
+ for (size_t j = 1; j <= k; j++)
+ {
+ neighbors(k - j, queryIndex) = pqueue.top().index;
+ distances(k - j, queryIndex) = pqueue.top().dist;
+ pqueue.pop();
}
}
@@ -327,10 +312,18 @@ template<typename SortPolicy>
inline force_inline
void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
const arma::uvec& referenceIndices,
+ const size_t k,
const arma::mat& querySet,
arma::Mat<size_t>& neighbors,
arma::mat& distances) const
{
+ // Let's build the list of candidate neighbors for the given query point.
+ // It will be initialized with k candidates:
+ // (WorstDistance, referenceSet->n_cols)
+ const Candidate def(SortPolicy::WorstDistance(), referenceSet->n_cols);
+ std::vector<Candidate> vect(k, def);
+ CandidateList pqueue(std::less<Candidate>(), std::move(vect));
+
for (size_t j = 0; j < referenceIndices.n_elem; ++j)
{
const size_t referenceIndex = referenceIndices[j];
@@ -338,20 +331,23 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
querySet.unsafe_col(queryIndex),
referenceSet->unsafe_col(referenceIndex));
- // If this distance is better than any of the current candidates, the
- // SortDistance() function will give us the position to insert it into.
- arma::vec queryDist = distances.unsafe_col(queryIndex);
- arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
- size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
- distance);
-
- // SortDistance() returns (size_t() - 1) if we shouldn't add it.
- if (insertPosition != (size_t() - 1))
- InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
- referenceIndex, distance);
+ Candidate c(distance, referenceIndex);
+ // If this distance is better than the worst candidate, let's insert it.
+ if (c < pqueue.top())
+ {
+ pqueue.pop();
+ pqueue.push(c);
+ }
+ }
+ for (size_t j = 1; j <= k; j++)
+ {
+ neighbors(k - j, queryIndex) = pqueue.top().index;
+ distances(k - j, queryIndex) = pqueue.top().dist;
+ pqueue.pop();
}
}
+
template<typename SortPolicy>
inline force_inline
double LSHSearch<SortPolicy>::PerturbationScore(
@@ -794,8 +790,6 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
// Set the size of the neighbor and distance matrices.
resultingNeighbors.set_size(k, querySet.n_cols);
distances.set_size(k, querySet.n_cols);
- distances.fill(SortPolicy::WorstDistance());
- resultingNeighbors.fill(referenceSet->n_cols);
// If the user asked for 0 nearest neighbors... uh... we're done.
if (k == 0)
@@ -854,7 +848,7 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
// Sequentially go through all the candidates and save the best 'k'
// candidates.
- BaseCase(i, refIndices, querySet, resultingNeighbors, distances);
+ BaseCase(i, refIndices, k, querySet, resultingNeighbors, distances);
}
Timer::Stop("computing_neighbors");
@@ -877,8 +871,6 @@ Search(const size_t k,
// This is monochromatic search; the query set is the reference set.
resultingNeighbors.set_size(k, referenceSet->n_cols);
distances.set_size(k, referenceSet->n_cols);
- distances.fill(SortPolicy::WorstDistance());
- resultingNeighbors.fill(referenceSet->n_cols);
// If the user requested more than the available number of additional probing
// bins, set Teffective to maximum T. Maximum T is 2^numProj - 1
@@ -933,7 +925,7 @@ Search(const size_t k,
// Sequentially go through all the candidates and save the best 'k'
// candidates.
- BaseCase(i, refIndices, resultingNeighbors, distances);
+ BaseCase(i, refIndices, k, resultingNeighbors, distances);
}
Timer::Stop("computing_neighbors");
More information about the mlpack-git mailing list