[mlpack-git] master: Merge branch 'lshopt' of https://github.com/rcurtin/mlpack into rcurtin-lshopt (9f162b6)

gitdub at mlpack.org gitdub at mlpack.org
Tue Jun 14 19:19:49 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/8d7e5db0bed8fc236407bdc5dee00d716d72a5ab...ae6c9e63b56c1ed1faa9aef9352854bbeb826a2f

>---------------------------------------------------------------

commit 9f162b600834323d684ad7cd59a670d1209573ac
Merge: 8d7e5db 9c28c08
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Jun 14 19:19:49 2016 -0400

    Merge branch 'lshopt' of https://github.com/rcurtin/mlpack into rcurtin-lshopt


>---------------------------------------------------------------

9f162b600834323d684ad7cd59a670d1209573ac
 src/mlpack/methods/lsh/lsh_main.cpp        |   3 +-
 src/mlpack/methods/lsh/lsh_search.hpp      |  59 +++++---
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 231 ++++++++++++++++++-----------
 src/mlpack/tests/serialization_test.cpp    |  12 +-
 4 files changed, 196 insertions(+), 109 deletions(-)

diff --cc src/mlpack/methods/lsh/lsh_search_impl.hpp
index ad698e1,98acad1..b34d22c
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@@ -377,86 -369,25 +369,75 @@@ void LSHSearch<SortPolicy>::ReturnIndic
  
    Log::Assert(hashVec.n_elem == numTablesToSearch);
  
-   // Count number of points hashed in the same bucket as the query
 -  // For all the buckets that the query is hashed into, sequentially
 -  // collect the indices in those buckets.
 -  arma::Col<size_t> refPointsConsidered;
 -  refPointsConsidered.zeros(referenceSet->n_cols);
 -
 -  for (size_t i = 0; i < hashVec.n_elem; i++) // For all tables.
++  // Count number of points hashed in the same bucket as the query.
 +  size_t maxNumPoints = 0;
-   for (size_t i = 0; i < numTablesToSearch; ++i) //For all tables
++  for (size_t i = 0; i < numTablesToSearch; ++i)
    {
-     size_t hashInd = (size_t) hashVec[i]; //find query's bucket
-     maxNumPoints += bucketContentSize[hashInd]; //count bucket contents
+     const size_t hashInd = (size_t) hashVec[i];
+     const size_t tableRow = bucketRowInHashTable[hashInd];
++    if (tableRow != secondHashSize)
++      maxNumPoints += bucketContentSize[tableRow];
 +  }
 +
- 
 +  // There are two ways to proceed here:
 +  // Either allocate a maxNumPoints-size vector, place all candidates, and run
 +  // unique on the vector to discard duplicates.
 +  // Or allocate a referenceSet->n_cols size vector (i.e. number of reference
 +  // points) of zeros, and mark found indices as 1.
 +  // Option 1 runs faster for small maxNumPoints but worse for larger values, so
 +  // we choose based on a heuristic.
 +  const float cutoff = 0.1;
 +  const float selectivity = static_cast<float>(maxNumPoints) /
 +      static_cast<float>(referenceSet->n_cols);
 +
 +  if (selectivity > cutoff)
 +  {
 +    // Heuristic: larger maxNumPoints means we should use find() because it
 +    // should be faster.
 +    // Reference points hashed in the same bucket as the query are set to >0.
 +    arma::Col<size_t> refPointsConsidered;
 +    refPointsConsidered.zeros(referenceSet->n_cols);
  
 -    if ((tableRow != secondHashSize) && (bucketContentSize[tableRow] > 0))
 +    for (size_t i = 0; i < hashVec.n_elem; ++i)
      {
-       size_t hashInd = (size_t) hashVec[i];
++      const size_t hashInd = (size_t) hashVec[i];
++      const size_t tableRow = bucketRowInHashTable[hashInd];
 +
-       if (bucketContentSize[hashInd] > 0)
-       {
-         // Pick the indices in the bucket corresponding to hashInd.
-         size_t tableRow = bucketRowInHashTable[hashInd];
-         assert(tableRow < secondHashSize);
-         assert(tableRow < secondHashTable.n_rows);
- 
-         for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
-           refPointsConsidered[secondHashTable(tableRow, j)]++;
-       }
+       // Pick the indices in the bucket corresponding to 'hashInd'.
 -      for (size_t j = 0; j < bucketContentSize[tableRow]; j++)
 -        refPointsConsidered[secondHashTable[tableRow](j)]++;
++      if (tableRow != secondHashSize)
++        for (size_t j = 0; j < bucketContentSize[tableRow]; j++)
++          refPointsConsidered[secondHashTable[tableRow](j)]++;
      }
 +
 +    // Only keep reference points found in at least one bucket.
 +    referenceIndices = arma::find(refPointsConsidered > 0);
 +    return;
    }
 +  else
 +  {
 +    // Heuristic: smaller maxNumPoints means we should use unique() because it
 +    // should be faster.
 +    // Allocate space for the query's potential neighbors.
 +    arma::uvec refPointsConsideredSmall;
 +    refPointsConsideredSmall.zeros(maxNumPoints);
 +
 +    // Retrieve candidates.
 +    size_t start = 0;
 +    for (size_t i = 0; i < numTablesToSearch; ++i) // For all tables
 +    {
-       size_t hashInd = (size_t) hashVec[i]; // Find the query's bucket.
++      const size_t hashInd = (size_t) hashVec[i]; // Find the query's bucket.
++      const size_t tableRow = bucketRowInHashTable[hashInd];
 +
-       if (bucketContentSize[hashInd] > 0)
-       {
-         // tableRow hash indices corresponding to query.
-         size_t tableRow = bucketRowInHashTable[hashInd];
-         assert(tableRow < secondHashSize);
-         assert(tableRow < secondHashTable.n_rows);
- 
-         // This for-loop could be replaced with a vector slice (TODO).
-         // Store all secondHashTable points in the candidates set.
-         for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
-           refPointsConsideredSmall(start++) = secondHashTable(tableRow, j);
-       }
++      // Store all secondHashTable points in the candidates set.
++      if (tableRow != secondHashSize)
++        for (size_t j = 0; j < bucketContentSize[tableRow]; ++j)
++          refPointsConsideredSmall(start++) = secondHashTable[tableRow][j];
 +    }
  
 -  referenceIndices = arma::find(refPointsConsidered > 0);
 +    // Only keep unique candidates.
 +    referenceIndices = arma::unique(refPointsConsideredSmall);
 +    return;
 +  }
  }
  
  // Search for nearest neighbors in a given query set.




More information about the mlpack-git mailing list