[mlpack-git] master: Implements hybrid find/unique solution. Code still slower than original (8462632)

gitdub at mlpack.org gitdub at mlpack.org
Sun Jun 5 15:02:48 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/56371c80f3165978e61f010fc0bd852d8d629266...80685304929965115306ba609504840a9f665066

>---------------------------------------------------------------

commit 8462632b70fcadb3334a1ac19788e1cecc3fbae1
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Mon May 9 13:49:07 2016 +0300

    Implements hybrid find/unique solution. Code still slower than original


>---------------------------------------------------------------

8462632b70fcadb3334a1ac19788e1cecc3fbae1
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 88 ++++++++++++++++++++++++------
 1 file changed, 70 insertions(+), 18 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index ea57564..109001a 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -218,7 +218,7 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
 
   Log::Assert(hashVec.n_elem == numTablesToSearch);
 
-  //Count number of points hashed in the same bucket as the query
+  // Count number of points hashed in the same bucket as the query
   size_t maxNumPoints = 0;
   for (size_t i = 0; i < numTablesToSearch; ++i) //For all tables
   {
@@ -226,29 +226,81 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
     maxNumPoints += bucketContentSize[hashInd]; //count bucket contents
   }
 
-  //Allocate space for query's potential neighbors
-  arma::uvec refPointsConsideredSmall;
-  refPointsConsideredSmall.zeros(maxNumPoints);
 
-  //Retrieve candidates
-  size_t start = 0;
-  for (size_t i = 0; i < numTablesToSearch; ++i) //For all tables
+  // There are two ways to proceed here:
+  // 1. Allocate a maxNumPoints-size vector, place all candidates in it, and
+  //    run unique() on the vector to discard duplicates.
+  // 2. Allocate a zero vector of referenceSet->n_cols elements (one per
+  //    reference point), and mark found indices as nonzero.
+  // Option 1 runs faster for small maxNumPoints but slower for larger values,
+  // so we choose between the two based on a heuristic.
+  
+  const float cutoff = 0.1;
+  const float selectivity =
+    static_cast<float>(maxNumPoints) / static_cast<float>(referenceSet->n_cols);
+
+  if ( selectivity > cutoff )
   {
-    size_t hashInd = (size_t) hashVec[i]; //find query's bucket
+    // Heuristic: larger maxNumPoints, use find()
+
+    // Reference points hashed in the same bucket as the query are set to >0
+    arma::Col<size_t> refPointsConsidered;
+    refPointsConsidered.zeros(referenceSet->n_cols);
+
+    for (size_t i = 0; i < hashVec.n_elem; ++i)
+    {
+      size_t hashInd = (size_t) hashVec[i];
+
+      if (bucketContentSize[hashInd] > 0)
+      {
+        // Pick the indices in the bucket corresponding to hashInd
+        size_t tableRow = bucketRowInHashTable[hashInd];
+        assert(tableRow < secondHashSize);
+        assert(tableRow < secondHashTable.n_rows);
 
-    //tableRow hash indices corresponding to query
-    size_t tableRow = bucketRowInHashTable[hashInd];
-    assert(tableRow < secondHashSize);
-    assert(tableRow < secondHashTable.n_rows);
+        for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
+          refPointsConsidered[secondHashTable(tableRow, j)]++;
+      }
+    }
 
-    //this for-loop could be replaced with a vector slice (TODO)
-    //store all secondHashTable points in the candidates set
-    for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
-      refPointsConsideredSmall(start++) = secondHashTable(tableRow, j);
+    // Only keep reference points found in some bucket
+    referenceIndices = arma::find(refPointsConsidered > 0);
+    return;
+  }
+
+  else
+  
+  {
+    // Heuristic: smaller maxNumPoints, use unique()
+    // Allocate space for query's potential neighbors
+    arma::uvec refPointsConsideredSmall;
+    refPointsConsideredSmall.zeros(maxNumPoints);
+
+    // Retrieve candidates
+    size_t start = 0;
+    for (size_t i = 0; i < numTablesToSearch; ++i) // For all tables
+    {
+      size_t hashInd = (size_t) hashVec[i]; // find query's bucket
+
+      if (bucketContentSize[hashInd] > 0)
+      {
+        // tableRow hash indices corresponding to query
+        size_t tableRow = bucketRowInHashTable[hashInd];
+        assert(tableRow < secondHashSize);
+        assert(tableRow < secondHashTable.n_rows);
+
+        // this for-loop could be replaced with a vector slice (TODO)
+        // store all secondHashTable points in the candidates set
+        for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
+          refPointsConsideredSmall(start++) = secondHashTable(tableRow, j);
+      }
+    }
+
+    // Only keep unique candidates
+    referenceIndices = arma::unique(refPointsConsideredSmall);
+    return;
   }
 
-  //Only keep unique candidates
-  referenceIndices = arma::unique(refPointsConsideredSmall);
 }
 
 // Search for nearest neighbors in a given query set.




More information about the mlpack-git mailing list