[mlpack-git] master: Implementation of Multiprobe LSH, version 1 (8dd409d)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Jun 30 15:11:40 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eaa7182ebed8cce3fd6191dc1f8170546ea297da...812048c7c6bee0b6c8d936677f23bbb5930c6cfc
>---------------------------------------------------------------
commit 8dd409d2d7e954b0c92200b32d075a4ebd3f4902
Author: Yannis Mentekidis <mentekid at gmail.com>
Date: Wed Jun 8 19:51:04 2016 +0300
Implementation of Multiprobe LSH, version 1
>---------------------------------------------------------------
8dd409d2d7e954b0c92200b32d075a4ebd3f4902
src/mlpack/methods/lsh/lsh_main.cpp | 7 +-
src/mlpack/methods/lsh/lsh_search.hpp | 17 +-
src/mlpack/methods/lsh/lsh_search_impl.hpp | 375 ++++++++++++++++++++++++++---
src/mlpack/prereqs.hpp | 1 +
4 files changed, 363 insertions(+), 37 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_main.cpp b/src/mlpack/methods/lsh/lsh_main.cpp
index 2894411..d291b72 100644
--- a/src/mlpack/methods/lsh/lsh_main.cpp
+++ b/src/mlpack/methods/lsh/lsh_main.cpp
@@ -63,6 +63,8 @@ PARAM_INT("tables", "The number of hash tables to be used.", "L", 30);
PARAM_DOUBLE("hash_width", "The hash width for the first-level hashing in the "
"LSH preprocessing. By default, the LSH class automatically estimates a "
"hash width for its use.", "H", 0.0);
+// Number of additional bins probed per table by multiprobe LSH; 0 disables it.
+PARAM_INT("num_probes", "Number of additional probes for Multiprobe LSH."
+ " If 0, traditional LSH is used.", "T", 0);
PARAM_INT("second_hash_size", "The size of the second level hash table.", "S",
99901);
PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", "B",
@@ -135,6 +137,7 @@ int main(int argc, char *argv[])
const size_t numProj = CLI::GetParam<int>("projections");
const size_t numTables = CLI::GetParam<int>("tables");
const double hashWidth = CLI::GetParam<double>("hash_width");
+ const size_t numProbes = CLI::GetParam<int>("num_probes");
arma::Mat<size_t> neighbors;
arma::mat distances;
@@ -178,11 +181,11 @@ int main(int argc, char *argv[])
Log::Info << "Loaded query data from '" << queryFile << "' ("
<< queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
}
- allkann.Search(queryData, k, neighbors, distances);
+ allkann.Search(queryData, k, neighbors, distances, numProbes);
}
else
{
- allkann.Search(k, neighbors, distances);
+ allkann.Search(k, neighbors, distances, numProbes);
}
}
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index b42bb7a..b38dc4f 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -143,7 +143,8 @@ class LSHSearch
const size_t k,
arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances,
- const size_t numTablesToSearch = 0);
+ const size_t numTablesToSearch = 0,
+ size_t T = 0);
/**
* Compute the nearest neighbors and store the output in the given matrices.
@@ -166,7 +167,8 @@ class LSHSearch
void Search(const size_t k,
arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances,
- const size_t numTablesToSearch = 0);
+ const size_t numTablesToSearch = 0,
+ size_t T = 0);
/**
* Serialize the LSH model.
@@ -229,7 +231,8 @@ class LSHSearch
template<typename VecType>
void ReturnIndicesFromTable(const VecType& queryPoint,
arma::uvec& referenceIndices,
- size_t numTablesToSearch) const;
+ size_t numTablesToSearch,
+ const size_t T) const;
/**
* This is a helper function that computes the distance of the query to the
@@ -286,6 +289,14 @@ class LSHSearch
const size_t neighbor,
const double distance) const;
+ /**
+ * Compute the codes of the additional bins to probe for one query in one
+ * table (multiprobe LSH). If T is 0 the function returns immediately and
+ * additionalProbingBins is left untouched.
+ *
+ * @param queryCode The code of the query's own bin; every additional bin
+ * starts as a copy of this code and has some coordinates moved by -1 or
+ * +1.
+ * @param queryCodeNotFloored The query's code before flooring; together with
+ * queryCode and the hash width it determines how close the query lies to
+ * each boundary of its bin, which is used to score candidate bins.
+ * @param T Number of additional bins to generate.
+ * @param additionalProbingBins Output matrix, resized to numProj x T; each
+ * column is the code of one additional bin.
+ */
+ void GetAdditionalProbingBins(const arma::vec &queryCode,
+ const arma::vec &queryCodeNotFloored,
+ const size_t T,
+ arma::mat &additionalProbingBins) const;
+
//! Reference dataset.
const arma::mat* referenceSet;
//! If true, we own the reference set.
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index ad698e1..027309a 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -339,12 +339,243 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
referenceIndex, distance);
}
+
+/**
+ * Ordering functor for (score, index) pairs, used in GetAdditionalProbingBins.
+ *
+ * Only the score (the first member of each pair) is compared; the index is
+ * ignored. std::priority_queue keeps the element that orders last on top, so
+ * using this "greater than" functor as its comparator yields a min-heap on the
+ * score.
+ */
+class CompareGreater
+{
+ public:
+  //! Return true if the score of p1 is strictly greater than the score of p2.
+  bool operator()(const std::pair<double, size_t>& p1,
+                  const std::pair<double, size_t>& p2) const
+  {
+    return p1.first > p2.first;
+  }
+};
+
+// Compute the score of the perturbation vector generated by perturbation set
+// A. The score of a perturbation set is the sum of the scores of the actions
+// it contains; every element of A is used as an index into scores.
+inline double perturbationScore(
+    const std::vector<size_t>& A,
+    const arma::vec& scores)
+{
+  double total = 0.0;
+  for (std::vector<size_t>::const_iterator it = A.begin(); it != A.end(); ++it)
+    total += scores[*it];
+  return total;
+}
+
+// Replace the largest element of perturbation set A with that element plus one
+// (the "shift" operation). If the largest value occurs more than once, only
+// its first occurrence is incremented. An empty set is left unchanged.
+inline void perturbationShift(std::vector<size_t>& A)
+{
+  // Nothing to shift; reading A[0] of an empty vector would be undefined.
+  if (A.empty())
+    return;
+
+  size_t maxPos = 0;
+  for (size_t i = 1; i < A.size(); ++i)
+  {
+    // The strict comparison keeps the first occurrence of the maximum.
+    if (A[i] > A[maxPos])
+      maxPos = i;
+  }
+  ++A[maxPos];
+}
+
+// Append (largest element + 1) to perturbation set A (the "expand" operation).
+// An empty set has no largest element and is left unchanged.
+inline void perturbationExpand(std::vector<size_t>& A)
+{
+  // Reading A[0] of an empty vector would be undefined.
+  if (A.empty())
+    return;
+
+  size_t largest = A[0];
+  for (size_t i = 1; i < A.size(); ++i)
+  {
+    if (A[i] > largest)
+      largest = A[i];
+  }
+  A.push_back(largest + 1);
+}
+
+// Return true if perturbation set A is valid. A perturbation set is invalid if
+// it contains an element that is not smaller than 2 * numProj, or if two (or
+// more) of its elements refer to the same dimension, where element a refers to
+// dimension (a % numProj). An empty set is valid.
+inline bool perturbationValid(
+    const std::vector<size_t>& A,
+    const size_t numProj)
+{
+  // seen[d] becomes nonzero once an action on dimension d has been encountered.
+  // A std::vector releases its storage on every return path, so no manual
+  // cleanup is needed.
+  std::vector<char> seen(numProj, 0);
+
+  for (size_t i = 0; i < A.size(); ++i)
+  {
+    // Out-of-range action (also guards the modulo below when numProj is 0).
+    if (A[i] >= 2 * numProj)
+      return false;
+
+    // Each dimension may be acted upon at most once.
+    const size_t dim = A[i] % numProj;
+    if (seen[dim])
+      return false;
+    seen[dim] = 1;
+  }
+  return true;
+}
+
+
+// Compute the additional probing bins for a query (multiprobe LSH).
+//
+// queryCode is the code of the query's own bin and queryCodeNotFloored is the
+// query's code before flooring. additionalProbingBins is resized to
+// numProj x T and each of its columns receives the code of one additional bin:
+// a copy of queryCode with some coordinates moved by -1 or +1. Bins are
+// emitted in order of increasing score. If fewer than T perturbation sets can
+// be generated, the remaining columns are left equal to queryCode. If T is 0,
+// additionalProbingBins is not touched at all.
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
+    const arma::vec& queryCode,
+    const arma::vec& queryCodeNotFloored,
+    const size_t T,
+    arma::mat& additionalProbingBins) const
+{
+  // No additional bins requested.
+  if (T == 0)
+    return;
+
+  // Each column of additionalProbingBins is the code of a bin. Start every
+  // column as a copy of the query's code, then add/subtract according to the
+  // perturbations chosen below.
+  additionalProbingBins.set_size(numProj, T);
+  for (size_t c = 0; c < T; ++c)
+    additionalProbingBins.col(c) = queryCode;
+
+  // Calculate the query point's projection position.
+  const arma::vec projection = queryCode * hashWidth;
+
+  // Use the projection to calculate the query's distance from the hash limits.
+  const arma::vec limLow = queryCodeNotFloored - projection;
+  const arma::vec limHigh = hashWidth - limLow;
+
+  // Scores are squared distances: the first numProj entries belong to the
+  // "-1" actions, the last numProj entries to the "+1" actions.
+  arma::vec scores(2 * numProj);
+  scores.rows(0, numProj - 1) = arma::pow(limLow, 2);
+  scores.rows(numProj, 2 * numProj - 1) = arma::pow(limHigh, 2);
+
+  // The actions vector shows what transformation to apply to a coordinate.
+  arma::Col<short int> actions(2 * numProj); // will be [-1 ... 1 ...]
+  actions.rows(0, numProj - 1) = // first numProj rows
+      -1 * arma::ones< arma::Col<short int> >(numProj); // -1s
+  actions.rows(numProj, 2 * numProj - 1) = // last numProj rows
+      arma::ones< arma::Col<short int> >(numProj); // 1s
+
+  // The positions vector shows which coordinate to transform according to the
+  // action stored at the same index of the actions vector.
+  arma::Col<size_t> positions(2 * numProj); // will be [0 1 2 ... 0 1 2 ...]
+  positions.rows(0, numProj - 1) =
+      arma::linspace< arma::Col<size_t> >(0, numProj - 1, numProj);
+  positions.rows(numProj, 2 * numProj - 1) =
+      arma::linspace< arma::Col<size_t> >(0, numProj - 1, numProj);
+
+  // Sort everything in order of increasing score. sort_index() yields an
+  // arma::uvec, which is also the index type element access expects.
+  const arma::uvec sortidx = arma::sort_index(scores);
+  scores = scores(sortidx);
+  actions = actions(sortidx);
+  positions = positions(sortidx);
+
+  // Theory:
+  // From the paper: this is the part that creates the probing sequence.
+  // A probing sequence is a sequence of T probing bins where the query's
+  // neighbors are most likely to be. Likelihood is dependent only on a bin's
+  // score, which is the sum of scores of all dimension-action pairs, so we
+  // need to calculate the T smallest sums of scores that are not conflicting.
+  //
+  // Method:
+  // Store each perturbation set (indices of (dimension, action) pairs in the
+  // sorted vectors) in a std::vector. Create a minheap of scores, with each
+  // node pointing to its relevant perturbation set. Each perturbation set
+  // popped from the minheap is the next most likely perturbation set.
+  // Transform a perturbation set to a perturbation vector by setting the
+  // dimensions specified by the set to queryCode + action (action is -1 or 1).
+  //
+  // NOTE(review): the sets hold indices into the *sorted* vectors, while
+  // perturbationValid() derives the dimension as (index % numProj) rather than
+  // positions(index) -- confirm this is the intended conflict test.
+
+  std::vector<size_t> Ao;
+  Ao.push_back(0); // The initial perturbation holds the smallest score.
+
+  std::vector< std::vector<size_t> > perturbationSets; // storage of sets
+  perturbationSets.push_back(Ao);
+
+  // Define a priority queue with CompareGreater as a minheap.
+  std::priority_queue<
+      std::pair<double, size_t>, // contents: pairs of (score, index)
+      std::vector< // container: vector of pairs
+          std::pair<double, size_t>
+      >,
+      mlpack::neighbor::CompareGreater // comparator of pairs (compare scores)
+  > minHeap; // our minheap
+
+  // Start by adding the lowest scoring set to the minheap.
+  std::pair<double, size_t> pair0(perturbationScore(Ao, scores), 0);
+  minHeap.push(pair0);
+
+  double prevScore = 0; // score of the last emitted set (for the assert)
+
+  // Loop invariant: after pvec iterations, additionalProbingBins contains pvec
+  // valid codes of the lowest-scoring (most likely) bins.
+  for (size_t pvec = 0; pvec < T; ++pvec)
+  {
+    std::vector<size_t> Ai;
+    bool found = false;
+
+    // Pop sets until a valid one is found. Only valid sets are ever pushed,
+    // so the heap may run dry before T sets have been produced; calling
+    // top()/pop() on an empty heap would be undefined, hence the check.
+    while (!minHeap.empty())
+    {
+      // Get the perturbation set corresponding to the minimum score.
+      Ai = perturbationSets[minHeap.top().second];
+      minHeap.pop(); // .top() returns, .pop() removes
+
+      // Modify Ai (shift).
+      std::vector<size_t> As = Ai;
+      perturbationShift(As);
+      if (perturbationValid(As, numProj))
+      {
+        perturbationSets.push_back(As); // add shifted set to sets
+        std::pair<double, size_t> shifted(
+            perturbationScore(As, scores),
+            perturbationSets.size() - 1); // (score, position) pair for shift
+        minHeap.push(shifted);
+      }
+
+      // Modify Ai (expand).
+      std::vector<size_t> Ae = Ai;
+      perturbationExpand(Ae);
+      if (perturbationValid(Ae, numProj))
+      {
+        perturbationSets.push_back(Ae); // add expanded set to sets
+        std::pair<double, size_t> expanded(
+            perturbationScore(Ae, scores),
+            perturbationSets.size() - 1); // (score, position) pair for expand
+        minHeap.push(expanded);
+      }
+
+      // Discard invalid perturbations.
+      if (perturbationValid(Ai, numProj))
+      {
+        found = true;
+        break;
+      }
+    }
+
+    // No perturbation sets left: the remaining columns keep the query's code.
+    if (!found)
+      break;
+
+    // A valid perturbation must not score lower than the previously emitted
+    // ones, meaning the bin it corresponds to is no more likely to hold
+    // neighbors.
+    assert(perturbationScore(Ai, scores) >= prevScore);
+    prevScore = perturbationScore(Ai, scores);
+
+    // Apply the perturbation set to column pvec of the probing sequence.
+    for (size_t i = 0; i < Ai.size(); ++i)
+      additionalProbingBins(positions(Ai[i]), pvec) += actions(Ai[i]);
+  }
+}
+
+
template<typename SortPolicy>
template<typename VecType>
void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
const VecType& queryPoint,
arma::uvec& referenceIndices,
- size_t numTablesToSearch) const
+ size_t numTablesToSearch,
+ const size_t T) const
{
// Decide on the number of tables to look into.
if (numTablesToSearch == 0) // If no user input is given, search all.
@@ -362,10 +593,10 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
// Compute the projection of the query in each table.
arma::mat allProjInTables(numProj, numTablesToSearch);
+ arma::mat queryCodesNotFloored(numProj, numTablesToSearch);
for (size_t i = 0; i < numTablesToSearch; i++)
- //allProjInTables.unsafe_col(i) = projections[i].t() * queryPoint;
- allProjInTables.unsafe_col(i) = projections.slice(i).t() * queryPoint;
- allProjInTables += offsets.cols(0, numTablesToSearch - 1);
+ queryCodesNotFloored.unsafe_col(i) = projections.slice(i).t() * queryPoint;
+ queryCodesNotFloored += offsets.cols(0, numTablesToSearch - 1);
allProjInTables /= hashWidth;
// Compute the hash value of each key of the query into a bucket of the
@@ -377,12 +608,47 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
Log::Assert(hashVec.n_elem == numTablesToSearch);
+ // Compute hashVectors of additional probing bins
+ arma::mat hashMat;
+ if (T > 0)
+ {
+ hashMat.set_size(T, numTablesToSearch);
+
+ for (size_t i = 0; i < numTablesToSearch; ++i)
+ {
+ // construct this table's probing sequence of length T
+ arma::mat additionalProbingBins;
+ GetAdditionalProbingBins(allProjInTables.unsafe_col(i),
+ queryCodesNotFloored.unsafe_col(i),
+ T,
+ additionalProbingBins);
+
+ // map the probing bin to second hash table bins
+ hashMat.col(i) = additionalProbingBins.t() * secondHashWeights;
+ for (size_t p = 0; p < T; ++p)
+ hashMat(p, i) = (double) ((size_t) hashMat(p, i) % secondHashSize);
+ }
+
+ // top row of hashMat is primary bins for each table
+ hashMat = arma::join_vert(hashVec, hashMat);
+ }
+ else
+ {
+ // if not multiprobe, hashMat is only hashVec's elements
+ hashMat.set_size(1, numTablesToSearch);
+ hashMat.row(0) = hashVec;
+ }
+
+
// Count number of points hashed in the same bucket as the query
size_t maxNumPoints = 0;
for (size_t i = 0; i < numTablesToSearch; ++i) //For all tables
{
- size_t hashInd = (size_t) hashVec[i]; //find query's bucket
- maxNumPoints += bucketContentSize[hashInd]; //count bucket contents
+ for (size_t p = 0; p < T + 1; ++p)
+ {
+ size_t hashInd = (size_t) hashMat(p, i); //find query's bucket
+ maxNumPoints += bucketContentSize[hashInd]; //count bucket contents
+ }
}
@@ -405,19 +671,24 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
arma::Col<size_t> refPointsConsidered;
refPointsConsidered.zeros(referenceSet->n_cols);
- for (size_t i = 0; i < hashVec.n_elem; ++i)
+ for (size_t i = 0; i < numTablesToSearch; ++i) // for all tables
{
- size_t hashInd = (size_t) hashVec[i];
-
- if (bucketContentSize[hashInd] > 0)
+ for (size_t p = 0; p < T + 1; ++p) // for entire probing sequence
{
- // Pick the indices in the bucket corresponding to hashInd.
- size_t tableRow = bucketRowInHashTable[hashInd];
- assert(tableRow < secondHashSize);
- assert(tableRow < secondHashTable.n_rows);
- for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
- refPointsConsidered[secondHashTable(tableRow, j)]++;
+ // get the sequence code
+ size_t hashInd = (size_t) hashMat(p, i);
+
+ if (bucketContentSize[hashInd] > 0)
+ {
+ // Pick the indices in the bucket corresponding to hashInd.
+ size_t tableRow = bucketRowInHashTable[hashInd];
+ assert(tableRow < secondHashSize);
+ assert(tableRow < secondHashTable.n_rows);
+
+ for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
+ refPointsConsidered[secondHashTable(tableRow, j)]++;
+ }
}
}
@@ -437,19 +708,22 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
size_t start = 0;
for (size_t i = 0; i < numTablesToSearch; ++i) // For all tables
{
- size_t hashInd = (size_t) hashVec[i]; // Find the query's bucket.
-
- if (bucketContentSize[hashInd] > 0)
+ for (size_t p = 0; p < T + 1; ++p) // For the entire probing sequence.
+ {
+   // Row p of hashMat holds the p-th bucket probed in table i (row 0 is the
+   // query's primary bucket). Reading hashVec[i] here would revisit the
+   // primary bucket T + 1 times and never collect the additional bins.
+   size_t hashInd = (size_t) hashMat(p, i);
+
+   if (bucketContentSize[hashInd] > 0)
+   {
+     // tableRow hash indices corresponding to query.
+     size_t tableRow = bucketRowInHashTable[hashInd];
+     assert(tableRow < secondHashSize);
+     assert(tableRow < secondHashTable.n_rows);
+
+     // This for-loop could be replaced with a vector slice (TODO).
+     // Store all secondHashTable points in the candidates set.
+     for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
+       refPointsConsideredSmall(start++) = secondHashTable(tableRow, j);
+   }
+ }
}
@@ -465,7 +739,8 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
const size_t k,
arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances,
- const size_t numTablesToSearch)
+ const size_t numTablesToSearch,
+ size_t T)
{
// Ensure the dimensionality of the query set is correct.
if (querySet.n_rows != referenceSet->n_rows)
@@ -496,6 +771,23 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
if (k == 0)
return;
+ // If the user requested more than the available number of additional probing
+ // bins, set Teffective to maximum T. Maximum T is taken to be 2^numProj - 1.
+ // The bound is computed in size_t: the int expression (1 << numProj) is
+ // undefined once numProj reaches the width of int. If 2^numProj does not fit
+ // in size_t either, no value of T can exceed the bound.
+ const size_t sizeTBits = 8 * sizeof(size_t);
+ const size_t maxAdditionalBins = (numProj >= sizeTBits) ?
+     (size_t) -1 : (((size_t) 1) << numProj) - 1;
+ size_t Teffective = T;
+ if (T > maxAdditionalBins)
+ {
+   Teffective = maxAdditionalBins;
+   Log::Warn << "Requested " << T <<
+       " additional bins are more than theoretical maximum. Using " <<
+       Teffective << " instead." << std::endl;
+ }
+
+ // If the user set multiprobe, log it
+ if (T > 0)
+   Log::Info << "Running multiprobe LSH with " << Teffective <<
+       " additional probing bins per table per query."<< std::endl;
+
+
size_t avgIndicesReturned = 0;
Timer::Start("computing_neighbors");
@@ -506,7 +798,8 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
// Hash every query into every hash table and eventually into the
// 'secondHashTable' to obtain the neighbor candidates.
arma::uvec refIndices;
- ReturnIndicesFromTable(querySet.col(i), refIndices, numTablesToSearch);
+ ReturnIndicesFromTable(querySet.col(i), refIndices, numTablesToSearch,
+ Teffective);
// An informative book-keeping for the number of neighbor candidates
// returned on average.
@@ -533,7 +826,8 @@ void LSHSearch<SortPolicy>::
Search(const size_t k,
arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances,
- const size_t numTablesToSearch)
+ const size_t numTablesToSearch,
+ size_t T)
{
// This is monochromatic search; the query set is the reference set.
resultingNeighbors.set_size(k, referenceSet->n_cols);
@@ -541,6 +835,22 @@ Search(const size_t k,
distances.fill(SortPolicy::WorstDistance());
resultingNeighbors.fill(referenceSet->n_cols);
+ // If the user requested more than the available number of additional probing
+ // bins, set Teffective to maximum T. Maximum T is taken to be 2^numProj - 1.
+ // The bound is computed in size_t: the int expression (1 << numProj) is
+ // undefined once numProj reaches the width of int. If 2^numProj does not fit
+ // in size_t either, no value of T can exceed the bound.
+ const size_t sizeTBits = 8 * sizeof(size_t);
+ const size_t maxAdditionalBins = (numProj >= sizeTBits) ?
+     (size_t) -1 : (((size_t) 1) << numProj) - 1;
+ size_t Teffective = T;
+ if (T > maxAdditionalBins)
+ {
+   Teffective = maxAdditionalBins;
+   Log::Warn << "Requested " << T <<
+       " additional bins are more than theoretical maximum. Using " <<
+       Teffective << " instead." << std::endl;
+ }
+
+ // If the user set multiprobe, log it
+ if (T > 0)
+   Log::Info << "Running multiprobe LSH with " << Teffective <<
+       " additional probing bins per table per query."<< std::endl;
+
size_t avgIndicesReturned = 0;
Timer::Start("computing_neighbors");
@@ -551,7 +861,8 @@ Search(const size_t k,
// Hash every query into every hash table and eventually into the
// 'secondHashTable' to obtain the neighbor candidates.
arma::uvec refIndices;
- ReturnIndicesFromTable(referenceSet->col(i), refIndices, numTablesToSearch);
+ ReturnIndicesFromTable(referenceSet->col(i), refIndices, numTablesToSearch,
+ Teffective);
// An informative book-keeping for the number of neighbor candidates
// returned on average.
diff --git a/src/mlpack/prereqs.hpp b/src/mlpack/prereqs.hpp
index 3852a6b..4950282 100644
--- a/src/mlpack/prereqs.hpp
+++ b/src/mlpack/prereqs.hpp
@@ -24,6 +24,7 @@
#include <iostream>
#include <stdexcept>
#include <tuple>
+#include <queue>
// Defining _USE_MATH_DEFINES should set M_PI.
#define _USE_MATH_DEFINES
More information about the mlpack-git
mailing list