[mlpack-git] master: Implementation of Multiprobe LSH, version 1 (8dd409d)

gitdub at mlpack.org gitdub at mlpack.org
Thu Jun 30 15:11:40 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eaa7182ebed8cce3fd6191dc1f8170546ea297da...812048c7c6bee0b6c8d936677f23bbb5930c6cfc

>---------------------------------------------------------------

commit 8dd409d2d7e954b0c92200b32d075a4ebd3f4902
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Wed Jun 8 19:51:04 2016 +0300

    Implementation of Multiprobe LSH, version 1


>---------------------------------------------------------------

8dd409d2d7e954b0c92200b32d075a4ebd3f4902
 src/mlpack/methods/lsh/lsh_main.cpp        |   7 +-
 src/mlpack/methods/lsh/lsh_search.hpp      |  17 +-
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 375 ++++++++++++++++++++++++++---
 src/mlpack/prereqs.hpp                     |   1 +
 4 files changed, 363 insertions(+), 37 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_main.cpp b/src/mlpack/methods/lsh/lsh_main.cpp
index 2894411..d291b72 100644
--- a/src/mlpack/methods/lsh/lsh_main.cpp
+++ b/src/mlpack/methods/lsh/lsh_main.cpp
@@ -63,6 +63,8 @@ PARAM_INT("tables", "The number of hash tables to be used.", "L", 30);
 PARAM_DOUBLE("hash_width", "The hash width for the first-level hashing in the "
     "LSH preprocessing. By default, the LSH class automatically estimates a "
     "hash width for its use.", "H", 0.0);
+PARAM_INT("num_probes", "Number of additional probes for Multiprobe LSH."
+    " If 0, traditional LSH is used.", "T", 0);
 PARAM_INT("second_hash_size", "The size of the second level hash table.", "S",
     99901);
 PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", "B",
@@ -135,6 +137,7 @@ int main(int argc, char *argv[])
   const size_t numProj = CLI::GetParam<int>("projections");
   const size_t numTables = CLI::GetParam<int>("tables");
   const double hashWidth = CLI::GetParam<double>("hash_width");
+  const size_t numProbes = CLI::GetParam<int>("num_probes");
 
   arma::Mat<size_t> neighbors;
   arma::mat distances;
@@ -178,11 +181,11 @@ int main(int argc, char *argv[])
         Log::Info << "Loaded query data from '" << queryFile << "' ("
             << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
       }
-      allkann.Search(queryData, k, neighbors, distances);
+      allkann.Search(queryData, k, neighbors, distances, numProbes);
     }
     else
     {
-      allkann.Search(k, neighbors, distances);
+      allkann.Search(k, neighbors, distances, numProbes);
     }
   }
 
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index b42bb7a..b38dc4f 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -143,7 +143,8 @@ class LSHSearch
               const size_t k,
               arma::Mat<size_t>& resultingNeighbors,
               arma::mat& distances,
-              const size_t numTablesToSearch = 0);
+              const size_t numTablesToSearch = 0,
+              size_t T = 0);
 
   /**
    * Compute the nearest neighbors and store the output in the given matrices.
@@ -166,7 +167,8 @@ class LSHSearch
   void Search(const size_t k,
               arma::Mat<size_t>& resultingNeighbors,
               arma::mat& distances,
-              const size_t numTablesToSearch = 0);
+              const size_t numTablesToSearch = 0,
+              size_t T = 0);
 
   /**
    * Serialize the LSH model.
@@ -229,7 +231,8 @@ class LSHSearch
   template<typename VecType>
   void ReturnIndicesFromTable(const VecType& queryPoint,
                               arma::uvec& referenceIndices,
-                              size_t numTablesToSearch) const;
+                              size_t numTablesToSearch,
+                              const size_t T) const;
 
   /**
    * This is a helper function that computes the distance of the query to the
@@ -286,6 +289,14 @@ class LSHSearch
                       const size_t neighbor,
                       const double distance) const;
 
+  /** Fill additionalProbingBins (numProj x T) with the codes of the T lowest-
+   *  score probing bins around the query (no-op if T == 0). NOTE(review):
+   *  queryCode presumably floored; queryCodeNotFloored = projection + offset. */
+  void GetAdditionalProbingBins(const arma::vec &queryCode,
+                            const arma::vec &queryCodeNotFloored,
+                            const size_t T,
+                            arma::mat &additionalProbingBins) const;
+
   //! Reference dataset.
   const arma::mat* referenceSet;
   //! If true, we own the reference set.
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index ad698e1..027309a 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -339,12 +339,243 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
         referenceIndex, distance);
 }
 
+
+// Comparator for (score, index) pairs that orders by score only. Used with
+// std::priority_queue, it turns the default max-heap into a min-heap.
+class CompareGreater
+{
+  public:
+    bool operator()(
+        const std::pair<double, size_t>& p1,
+        const std::pair<double, size_t>& p2) const {
+      return p1.first > p2.first; // only compare the double values
+    }
+};
+
+// Score of the perturbation vector generated by perturbation set A: the sum
+// of scores[j] over every action index j in A. The probing sequence prefers
+// the smallest sums (see GetAdditionalProbingBins).
+inline double perturbationScore(
+              const std::vector<size_t> &A,
+              const arma::vec &scores)
+{
+  double total = 0.0;
+  for (const size_t action : A)
+    total += scores[action];
+  return total;
+}
+
+// Shift perturbation set A: increment its largest element in place (its
+// first occurrence if the maximum repeats). A must be non-empty.
+inline void perturbationShift(std::vector<size_t> &A)
+{
+  // Index of the largest element seen so far.
+  size_t best = 0;
+  for (size_t j = 1; j < A.size(); ++j)
+  {
+    // Strict comparison keeps the earliest maximum.
+    if (A[j] > A[best])
+      best = j;
+  }
+
+  ++A[best];
+}
+
+// Expand perturbation set A: append (largest element of A) + 1. A must be
+// non-empty.
+inline void perturbationExpand(std::vector<size_t> &A)
+{
+  size_t largest = A.front();
+  for (const size_t a : A)
+    largest = (a > largest) ? a : largest;
+  A.push_back(largest + 1);
+}
+
+// Return true if perturbation set A is valid. A is a set of indices into the
+// 2 * numProj actions sorted by score (see GetAdditionalProbingBins). It is
+// invalid if it contains an index of 2 * numProj or more, or two (or more)
+// actions for the same dimension. Because limHigh = hashWidth - limLow, the
+// sorted order is [smaller score of each dimension, ascending] followed by
+// [larger score of each dimension, in reverse order], so sorted indices j and
+// 2 * numProj - 1 - j act on the same dimension (exact unless scores tie).
+// Indices j and j + numProj only share a dimension before sorting.
+inline bool perturbationValid(
+    const std::vector<size_t> &A,
+    const size_t numProj)
+{
+  // check[s] marks slot s (one per dimension) as used; freed on any return.
+  std::vector<bool> check(numProj, false);
+
+  for (size_t i = 0; i < A.size(); ++i)
+  {
+    // Action indices must lie in [0, 2 * numProj).
+    if (A[i] >= 2 * numProj)
+      return false;
+
+    // Fold each index onto its mirror so both actions of a dimension share
+    // one slot; seeing a slot twice means a dimension is perturbed twice.
+    const size_t mirror = 2 * numProj - 1 - A[i];
+    const size_t slot = (A[i] < mirror) ? A[i] : mirror;
+    if (check[slot])
+      return false;
+    check[slot] = true;
+  }
+  return true;
+}
+
+
+// Compute the T additional probing bins for a query (Multiprobe LSH).
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
+                            const arma::vec &queryCode,
+                            const arma::vec &queryCodeNotFloored,
+                            const size_t T,
+                            arma::mat &additionalProbingBins) const
+{
+  if (T == 0)
+    return;
+
+  // Bin codes are integers. Flooring is a no-op for an already floored code
+  // and gives the bin code if the caller passed the unfloored quotient.
+  const arma::vec code = arma::floor(queryCode);
+
+  // Each column of additionalProbingBins is the code of a bin. Copy the
+  // query's code, then add/subtract according to perturbations.
+  additionalProbingBins.set_size(numProj, T);
+  for (size_t c = 0; c < T; ++c)
+    additionalProbingBins.col(c) = code;
+
+  // Calculate the position of the lower limit of the query's bin.
+  const arma::vec projection = code * hashWidth;
+
+  // Use projection to calculate query's distance from hash limits
+  const arma::vec limLow = queryCodeNotFloored - projection;
+  const arma::vec limHigh = hashWidth - limLow;
+
+  // calculate scores = distances^2
+  arma::vec scores(2 * numProj);
+  scores.rows(0, numProj - 1) = arma::square(limLow);
+  scores.rows(numProj, 2 * numProj - 1) = arma::square(limHigh);
+
+  // actions(j) is the change (-1 or +1) that action j applies to a coordinate
+  arma::Col<short int> actions(2 * numProj); // will be [-1 ... 1 ...]
+  actions.rows(0, numProj - 1) =
+    -1 * arma::ones< arma::Col<short int> > (numProj);
+  actions.rows(numProj, 2 * numProj - 1) =
+    arma::ones< arma::Col<short int> > (numProj);
+
+  // positions(j) is the coordinate that action j acts on
+  arma::Col<size_t> positions(2 * numProj); // will be [0 1 2 ... 0 1 2 ...]
+  positions.rows(0, numProj - 1) =
+    arma::linspace< arma::Col<size_t> >(0, numProj - 1, numProj);
+  positions.rows(numProj, 2 * numProj - 1) =
+    arma::linspace< arma::Col<size_t> >(0, numProj - 1, numProj);
+
+  // Sort everything in increasing order of score. sort_index() yields
+  // arma::uvec, the index type that element access expects on all platforms.
+  const arma::uvec sortidx = arma::sort_index(scores);
+  scores = scores(sortidx);
+  actions = actions(sortidx);
+  positions = positions(sortidx);
+
+  // Theory (from the Multi-Probe LSH paper): the probing sequence is the T
+  // bins where the query's neighbors are most likely to be. A bin's likelihood
+  // depends only on its score, the sum of the scores of its actions, so we
+  // need the T smallest sums of scores that are not conflicting.
+  //
+  // Method: a perturbation set is an increasing std::vector of indices into
+  // the sorted vectors. A minheap pops sets in order of score. The shifted
+  // and expanded children of every popped set are pushed whenever their
+  // indices are in range -- even if they conflict, since a valid set can
+  // descend from a conflicting one -- and conflicting sets are skipped. Each
+  // valid set becomes a bin code by adding its actions to the query's code.
+
+  // A set is valid if it perturbs each coordinate at most once. Coordinates
+  // come from positions: after sorting, j and j + numProj need not match.
+  std::vector<bool> dimUsed;
+  auto isValid = [&](const std::vector<size_t>& set) -> bool
+  {
+    dimUsed.assign(numProj, false);
+    for (size_t i = 0; i < set.size(); ++i)
+    {
+      if (dimUsed[positions(set[i])])
+        return false;
+      dimUsed[positions(set[i])] = true;
+    }
+    return true;
+  };
+
+  std::vector<size_t> Ao;
+  Ao.push_back(0); // initial perturbation holds smallest score (0 if sorted)
+
+  std::vector< std::vector<size_t> > perturbationSets;
+  perturbationSets.push_back(Ao); // storage of perturbation sets
+
+  // Minheap of (score, index into perturbationSets) pairs.
+  std::priority_queue<std::pair<double, size_t>,
+                      std::vector< std::pair<double, size_t> >,
+                      mlpack::neighbor::CompareGreater> minHeap;
+
+  // Start by adding the lowest scoring set to the minheap
+  minHeap.push(std::pair<double, size_t>(perturbationScore(Ao, scores), 0));
+
+  double prevScore = 0; // score of the last accepted set (for the assert)
+  // Loop invariant: after pvec iterations, additionalProbingBins contains pvec
+  // valid codes of the lowest-scoring (most likely) bins.
+  for (size_t pvec = 0; pvec < T; ++pvec)
+  {
+    std::vector<size_t> Ai;
+    do
+    {
+      // Defensive: fewer than T valid sets exist. The remaining columns
+      // keep the query's own code instead of reading an empty heap.
+      if (minHeap.empty())
+        return;
+
+      // get the perturbation set corresponding to the minimum score
+      Ai = perturbationSets[minHeap.top().second];
+      minHeap.pop(); // .top() returns, .pop() removes
+
+      // Sets are increasing, so back() is the largest index of each child.
+      std::vector<size_t> As = Ai;
+      perturbationShift(As);
+      if (As.back() < 2 * numProj)
+      {
+        perturbationSets.push_back(As); // add shifted set to sets
+        minHeap.push(std::pair<double, size_t>(
+            perturbationScore(As, scores), perturbationSets.size() - 1));
+      }
+
+      std::vector<size_t> Ae = Ai;
+      perturbationExpand(Ae);
+      if (Ae.back() < 2 * numProj)
+      {
+        perturbationSets.push_back(Ae); // add expanded set to sets
+        minHeap.push(std::pair<double, size_t>(
+            perturbationScore(Ae, scores), perturbationSets.size() - 1));
+      }
+    } while (!isValid(Ai)); // Discard conflicting perturbations.
+
+    // a valid perturbation must have higher score than previous valid ones,
+    // meaning the bin it corresponds to is less likely to hold neighbors
+    const double score = perturbationScore(Ai, scores);
+    assert(score >= prevScore);
+    prevScore = score;
+
+    // add perturbation vector to probing sequence
+    for (size_t i = 0; i < Ai.size(); ++i)
+      additionalProbingBins(positions(Ai[i]), pvec) += actions(Ai[i]);
+  }
+}
+
+
 template<typename SortPolicy>
 template<typename VecType>
 void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
     const VecType& queryPoint,
     arma::uvec& referenceIndices,
-    size_t numTablesToSearch) const
+    size_t numTablesToSearch,
+    const size_t T) const
 {
   // Decide on the number of tables to look into.
   if (numTablesToSearch == 0) // If no user input is given, search all.
@@ -362,10 +593,10 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
 
   // Compute the projection of the query in each table.
   arma::mat allProjInTables(numProj, numTablesToSearch);
+  arma::mat queryCodesNotFloored(numProj, numTablesToSearch);
   for (size_t i = 0; i < numTablesToSearch; i++)
-    //allProjInTables.unsafe_col(i) = projections[i].t() * queryPoint;
-    allProjInTables.unsafe_col(i) = projections.slice(i).t() * queryPoint;
-  allProjInTables += offsets.cols(0, numTablesToSearch - 1);
+    queryCodesNotFloored.unsafe_col(i) = projections.slice(i).t() * queryPoint;
+  queryCodesNotFloored += offsets.cols(0, numTablesToSearch - 1);
   allProjInTables /= hashWidth;
 
   // Compute the hash value of each key of the query into a bucket of the
@@ -377,12 +608,47 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
 
   Log::Assert(hashVec.n_elem == numTablesToSearch);
 
+  // Compute hashVectors of additional probing bins
+  arma::mat hashMat;
+  if (T > 0)
+  {
+    hashMat.set_size(T, numTablesToSearch);
+
+    for (size_t i = 0; i < numTablesToSearch; ++i)
+    {
+      // construct this table's probing sequence of length T
+      arma::mat additionalProbingBins;
+      GetAdditionalProbingBins(allProjInTables.unsafe_col(i),
+                                queryCodesNotFloored.unsafe_col(i),
+                                T,
+                                additionalProbingBins);
+
+      // map the probing bin to second hash table bins
+      hashMat.col(i) = additionalProbingBins.t() * secondHashWeights;
+      for (size_t p = 0; p < T; ++p)
+        hashMat(p, i) = (double) ((size_t) hashMat(p, i) % secondHashSize);
+    }
+
+    // top row of hashMat is primary bins for each table
+    hashMat = arma::join_vert(hashVec, hashMat);
+  }
+  else
+  {
+    // if not multiprobe, hashMat is only hashVec's elements
+    hashMat.set_size(1, numTablesToSearch);
+    hashMat.row(0) = hashVec;
+  }
+
+
   // Count number of points hashed in the same bucket as the query
   size_t maxNumPoints = 0;
   for (size_t i = 0; i < numTablesToSearch; ++i) //For all tables
   {
-    size_t hashInd = (size_t) hashVec[i]; //find query's bucket
-    maxNumPoints += bucketContentSize[hashInd]; //count bucket contents
+    for (size_t p = 0; p < T + 1; ++p)
+    {
+      size_t hashInd = (size_t) hashMat(p, i); //find query's bucket
+      maxNumPoints += bucketContentSize[hashInd]; //count bucket contents
+    }
   }
 
 
@@ -405,19 +671,24 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
     arma::Col<size_t> refPointsConsidered;
     refPointsConsidered.zeros(referenceSet->n_cols);
 
-    for (size_t i = 0; i < hashVec.n_elem; ++i)
+    for (size_t i = 0; i < numTablesToSearch; ++i) // for all tables
     {
-      size_t hashInd = (size_t) hashVec[i];
-
-      if (bucketContentSize[hashInd] > 0)
+      for (size_t p = 0; p < T + 1; ++p) // for entire probing sequence
       {
-        // Pick the indices in the bucket corresponding to hashInd.
-        size_t tableRow = bucketRowInHashTable[hashInd];
-        assert(tableRow < secondHashSize);
-        assert(tableRow < secondHashTable.n_rows);
 
-        for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
-          refPointsConsidered[secondHashTable(tableRow, j)]++;
+        // get the sequence code
+        size_t hashInd = (size_t) hashMat(p, i);
+
+        if (bucketContentSize[hashInd] > 0)
+        {
+          // Pick the indices in the bucket corresponding to hashInd.
+          size_t tableRow = bucketRowInHashTable[hashInd];
+          assert(tableRow < secondHashSize);
+          assert(tableRow < secondHashTable.n_rows);
+
+          for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
+            refPointsConsidered[secondHashTable(tableRow, j)]++;
+        }
       }
     }
 
@@ -437,19 +708,22 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
     size_t start = 0;
     for (size_t i = 0; i < numTablesToSearch; ++i) // For all tables
     {
-      size_t hashInd = (size_t) hashVec[i]; // Find the query's bucket.
-
-      if (bucketContentSize[hashInd] > 0)
+      for (size_t p = 0; p < T + 1; ++p)
       {
-        // tableRow hash indices corresponding to query.
-        size_t tableRow = bucketRowInHashTable[hashInd];
-        assert(tableRow < secondHashSize);
-        assert(tableRow < secondHashTable.n_rows);
-
-        // This for-loop could be replaced with a vector slice (TODO).
-        // Store all secondHashTable points in the candidates set.
-        for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
-          refPointsConsideredSmall(start++) = secondHashTable(tableRow, j);
+        size_t hashInd = (size_t) hashVec[i]; // Find the query's bucket.
+
+        if (bucketContentSize[hashInd] > 0)
+        {
+          // tableRow hash indices corresponding to query.
+          size_t tableRow = bucketRowInHashTable[hashInd];
+          assert(tableRow < secondHashSize);
+          assert(tableRow < secondHashTable.n_rows);
+
+          // This for-loop could be replaced with a vector slice (TODO).
+          // Store all secondHashTable points in the candidates set.
+          for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
+            refPointsConsideredSmall(start++) = secondHashTable(tableRow, j);
+        }
       }
     }
 
@@ -465,7 +739,8 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
                                    const size_t k,
                                    arma::Mat<size_t>& resultingNeighbors,
                                    arma::mat& distances,
-                                   const size_t numTablesToSearch)
+                                   const size_t numTablesToSearch,
+                                   size_t T)
 {
   // Ensure the dimensionality of the query set is correct.
   if (querySet.n_rows != referenceSet->n_rows)
@@ -496,6 +771,23 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
   if (k == 0)
     return;
 
+  // If the user requested more than the available number of additional probing
+  // bins, set Teffective to maximum T. Maximum T is 2^numProj - 1
+  size_t Teffective = T;
+  if (T > ( (size_t) ( (1 << numProj) - 1) ) )
+  {
+    Teffective = ( 1 << numProj ) - 1;
+    Log::Warn << "Requested " << T << 
+      " additional bins are more than theoretical maximum. Using " <<
+      Teffective << " instead." << std::endl;
+  }
+
+  // If the user set multiprobe, log it
+  if (T > 0)
+    Log::Info << "Running multiprobe LSH with " << Teffective <<
+      " additional probing bins per table per query."<< std::endl;
+
+
   size_t avgIndicesReturned = 0;
 
   Timer::Start("computing_neighbors");
@@ -506,7 +798,8 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
     // Hash every query into every hash table and eventually into the
     // 'secondHashTable' to obtain the neighbor candidates.
     arma::uvec refIndices;
-    ReturnIndicesFromTable(querySet.col(i), refIndices, numTablesToSearch);
+    ReturnIndicesFromTable(querySet.col(i), refIndices, numTablesToSearch, 
+        Teffective);
 
     // An informative book-keeping for the number of neighbor candidates
     // returned on average.
@@ -533,7 +826,8 @@ void LSHSearch<SortPolicy>::
 Search(const size_t k,
        arma::Mat<size_t>& resultingNeighbors,
        arma::mat& distances,
-       const size_t numTablesToSearch)
+       const size_t numTablesToSearch,
+       size_t T)
 {
   // This is monochromatic search; the query set is the reference set.
   resultingNeighbors.set_size(k, referenceSet->n_cols);
@@ -541,6 +835,22 @@ Search(const size_t k,
   distances.fill(SortPolicy::WorstDistance());
   resultingNeighbors.fill(referenceSet->n_cols);
 
+  // If the user requested more than the available number of additional probing
+  // bins, set Teffective to maximum T. Maximum T is 2^numProj - 1
+  size_t Teffective = T;
+  if (T > ( (size_t) ( (1 << numProj) - 1) ) )
+  {
+    Teffective = ( 1 << numProj ) - 1;
+    Log::Warn << "Requested " << T << 
+      " additional bins are more than theoretical maximum. Using " <<
+      Teffective << " instead." << std::endl;
+  }
+
+  // If the user set multiprobe, log it
+  if (T > 0)
+    Log::Info << "Running multiprobe LSH with " << Teffective <<
+      " additional probing bins per table per query."<< std::endl;
+  
   size_t avgIndicesReturned = 0;
 
   Timer::Start("computing_neighbors");
@@ -551,7 +861,8 @@ Search(const size_t k,
     // Hash every query into every hash table and eventually into the
     // 'secondHashTable' to obtain the neighbor candidates.
     arma::uvec refIndices;
-    ReturnIndicesFromTable(referenceSet->col(i), refIndices, numTablesToSearch);
+    ReturnIndicesFromTable(referenceSet->col(i), refIndices, numTablesToSearch,
+        Teffective);
 
     // An informative book-keeping for the number of neighbor candidates
     // returned on average.
diff --git a/src/mlpack/prereqs.hpp b/src/mlpack/prereqs.hpp
index 3852a6b..4950282 100644
--- a/src/mlpack/prereqs.hpp
+++ b/src/mlpack/prereqs.hpp
@@ -24,6 +24,7 @@
 #include <iostream>
 #include <stdexcept>
 #include <tuple>
+#include <queue>
 
 // Defining _USE_MATH_DEFINES should set M_PI.
 #define _USE_MATH_DEFINES




More information about the mlpack-git mailing list