[mlpack-git] master: Implements hybrid find/unique solution. Code still slower than original (8462632)
gitdub at mlpack.org
gitdub at mlpack.org
Sun Jun 5 15:02:48 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/56371c80f3165978e61f010fc0bd852d8d629266...80685304929965115306ba609504840a9f665066
>---------------------------------------------------------------
commit 8462632b70fcadb3334a1ac19788e1cecc3fbae1
Author: Yannis Mentekidis <mentekid at gmail.com>
Date: Mon May 9 13:49:07 2016 +0300
Implements hybrid find/unique solution. Code still slower than original
>---------------------------------------------------------------
8462632b70fcadb3334a1ac19788e1cecc3fbae1
src/mlpack/methods/lsh/lsh_search_impl.hpp | 88 ++++++++++++++++++++++++------
1 file changed, 70 insertions(+), 18 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index ea57564..109001a 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -218,7 +218,7 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
Log::Assert(hashVec.n_elem == numTablesToSearch);
- //Count number of points hashed in the same bucket as the query
+ // Count number of points hashed in the same bucket as the query
size_t maxNumPoints = 0;
for (size_t i = 0; i < numTablesToSearch; ++i) //For all tables
{
@@ -226,29 +226,81 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
maxNumPoints += bucketContentSize[hashInd]; //count bucket contents
}
- //Allocate space for query's potential neighbors
- arma::uvec refPointsConsideredSmall;
- refPointsConsideredSmall.zeros(maxNumPoints);
- //Retrieve candidates
- size_t start = 0;
- for (size_t i = 0; i < numTablesToSearch; ++i) //For all tables
+ // There are two ways to proceed here:
+ // Either allocate a maxNumPoints-size vector, place all candidates, and run
+ // unique on the vector to discard duplicates.
+ // Or allocate a referenceSet->n_cols size vector (i.e. number of reference
+ // points) of zeros, and mark found indices as 1.
+ // Option 1 runs faster for small maxNumPoints but worse for larger values, so
+ // we choose based on a heuristic.
+
+ const float cutoff = 0.1;
+ const float selectivity =
+ static_cast<float>(maxNumPoints) / static_cast<float>(referenceSet->n_cols);
+
+ if ( selectivity > cutoff )
{
- size_t hashInd = (size_t) hashVec[i]; //find query's bucket
+ // Heuristic: larger maxNumPoints, use find()
+
+ // Reference points hashed in the same bucket as the query are set to >0
+ arma::Col<size_t> refPointsConsidered;
+ refPointsConsidered.zeros(referenceSet->n_cols);
+
+ for (size_t i = 0; i < hashVec.n_elem; ++i)
+ {
+ size_t hashInd = (size_t) hashVec[i];
+
+ if (bucketContentSize[hashInd] > 0)
+ {
+ // Pick the indices in the bucket corresponding to hashInd
+ size_t tableRow = bucketRowInHashTable[hashInd];
+ assert(tableRow < secondHashSize);
+ assert(tableRow < secondHashTable.n_rows);
- //tableRow hash indices corresponding to query
- size_t tableRow = bucketRowInHashTable[hashInd];
- assert(tableRow < secondHashSize);
- assert(tableRow < secondHashTable.n_rows);
+ for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
+ refPointsConsidered[secondHashTable(tableRow, j)]++;
+ }
+ }
- //this for-loop could be replaced with a vector slice (TODO)
- //store all secondHashTable points in the candidates set
- for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
- refPointsConsideredSmall(start++) = secondHashTable(tableRow, j);
+ //only keep reference points found in some bucket
+ referenceIndices = arma::find(refPointsConsidered > 0);
+ return;
+ }
+
+ else
+
+ {
+ // Heuristic: smaller maxNumPoints, use unique()
+ // Allocate space for query's potential neighbors
+ arma::uvec refPointsConsideredSmall;
+ refPointsConsideredSmall.zeros(maxNumPoints);
+
+ // Retrieve candidates
+ size_t start = 0;
+ for (size_t i = 0; i < numTablesToSearch; ++i) // For all tables
+ {
+ size_t hashInd = (size_t) hashVec[i]; // find query's bucket
+
+ if (bucketContentSize[hashInd] > 0)
+ {
+ // tableRow hash indices corresponding to query
+ size_t tableRow = bucketRowInHashTable[hashInd];
+ assert(tableRow < secondHashSize);
+ assert(tableRow < secondHashTable.n_rows);
+
+ // this for-loop could be replaced with a vector slice (TODO)
+ // store all secondHashTable points in the candidates set
+ for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
+ refPointsConsideredSmall(start++) = secondHashTable(tableRow, j);
+ }
+ }
+
+ // Only keep unique candidates
+ referenceIndices = arma::unique(refPointsConsideredSmall);
+ return;
}
- //Only keep unique candidates
- referenceIndices = arma::unique(refPointsConsideredSmall);
}
// Search for nearest neighbors in a given query set.
More information about the mlpack-git mailing list