[mlpack-git] master: Merge branch 'lshopt' of https://github.com/rcurtin/mlpack into rcurtin-lshopt (9f162b6)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Jun 14 19:19:49 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/8d7e5db0bed8fc236407bdc5dee00d716d72a5ab...ae6c9e63b56c1ed1faa9aef9352854bbeb826a2f
>---------------------------------------------------------------
commit 9f162b600834323d684ad7cd59a670d1209573ac
Merge: 8d7e5db 9c28c08
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Jun 14 19:19:49 2016 -0400
Merge branch 'lshopt' of https://github.com/rcurtin/mlpack into rcurtin-lshopt
>---------------------------------------------------------------
9f162b600834323d684ad7cd59a670d1209573ac
src/mlpack/methods/lsh/lsh_main.cpp | 3 +-
src/mlpack/methods/lsh/lsh_search.hpp | 59 +++++---
src/mlpack/methods/lsh/lsh_search_impl.hpp | 231 ++++++++++++++++++-----------
src/mlpack/tests/serialization_test.cpp | 12 +-
4 files changed, 196 insertions(+), 109 deletions(-)
diff --cc src/mlpack/methods/lsh/lsh_search_impl.hpp
index ad698e1,98acad1..b34d22c
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@@ -377,86 -369,25 +369,75 @@@ void LSHSearch<SortPolicy>::ReturnIndic
Log::Assert(hashVec.n_elem == numTablesToSearch);
- // Count number of points hashed in the same bucket as the query
- // For all the buckets that the query is hashed into, sequentially
- // collect the indices in those buckets.
- arma::Col<size_t> refPointsConsidered;
- refPointsConsidered.zeros(referenceSet->n_cols);
-
- for (size_t i = 0; i < hashVec.n_elem; i++) // For all tables.
++ // Count number of points hashed in the same bucket as the query.
+ size_t maxNumPoints = 0;
- for (size_t i = 0; i < numTablesToSearch; ++i) //For all tables
++ for (size_t i = 0; i < numTablesToSearch; ++i)
{
- size_t hashInd = (size_t) hashVec[i]; //find query's bucket
- maxNumPoints += bucketContentSize[hashInd]; //count bucket contents
+ const size_t hashInd = (size_t) hashVec[i];
+ const size_t tableRow = bucketRowInHashTable[hashInd];
++ if (tableRow != secondHashSize)
++ maxNumPoints += bucketContentSize[tableRow];
+ }
+
-
+ // There are two ways to proceed here:
+ // Either allocate a maxNumPoints-size vector, place all candidates, and run
+ // unique on the vector to discard duplicates.
+ // Or allocate a referenceSet->n_cols size vector (i.e. number of reference
+ // points) of zeros, and mark found indices as 1.
+ // Option 1 runs faster for small maxNumPoints but worse for larger values, so
+ // we choose based on a heuristic.
+ const float cutoff = 0.1;
+ const float selectivity = static_cast<float>(maxNumPoints) /
+ static_cast<float>(referenceSet->n_cols);
+
+ if (selectivity > cutoff)
+ {
+ // Heuristic: larger maxNumPoints means we should use find() because it
+ // should be faster.
+ // Reference points hashed in the same bucket as the query are set to >0.
+ arma::Col<size_t> refPointsConsidered;
+ refPointsConsidered.zeros(referenceSet->n_cols);
- if ((tableRow != secondHashSize) && (bucketContentSize[tableRow] > 0))
+ for (size_t i = 0; i < hashVec.n_elem; ++i)
{
- size_t hashInd = (size_t) hashVec[i];
++ const size_t hashInd = (size_t) hashVec[i];
++ const size_t tableRow = bucketRowInHashTable[hashInd];
+
- if (bucketContentSize[hashInd] > 0)
- {
- // Pick the indices in the bucket corresponding to hashInd.
- size_t tableRow = bucketRowInHashTable[hashInd];
- assert(tableRow < secondHashSize);
- assert(tableRow < secondHashTable.n_rows);
-
- for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
- refPointsConsidered[secondHashTable(tableRow, j)]++;
- }
+ // Pick the indices in the bucket corresponding to 'hashInd'.
- for (size_t j = 0; j < bucketContentSize[tableRow]; j++)
- refPointsConsidered[secondHashTable[tableRow](j)]++;
++ if (tableRow != secondHashSize)
++ for (size_t j = 0; j < bucketContentSize[tableRow]; j++)
++ refPointsConsidered[secondHashTable[tableRow](j)]++;
}
+
+ // Only keep reference points found in at least one bucket.
+ referenceIndices = arma::find(refPointsConsidered > 0);
+ return;
}
+ else
+ {
+ // Heuristic: smaller maxNumPoints means we should use unique() because it
+ // should be faster.
+ // Allocate space for the query's potential neighbors.
+ arma::uvec refPointsConsideredSmall;
+ refPointsConsideredSmall.zeros(maxNumPoints);
+
+ // Retrieve candidates.
+ size_t start = 0;
+ for (size_t i = 0; i < numTablesToSearch; ++i) // For all tables
+ {
- size_t hashInd = (size_t) hashVec[i]; // Find the query's bucket.
++ const size_t hashInd = (size_t) hashVec[i]; // Find the query's bucket.
++ const size_t tableRow = bucketRowInHashTable[hashInd];
+
- if (bucketContentSize[hashInd] > 0)
- {
- // tableRow hash indices corresponding to query.
- size_t tableRow = bucketRowInHashTable[hashInd];
- assert(tableRow < secondHashSize);
- assert(tableRow < secondHashTable.n_rows);
-
- // This for-loop could be replaced with a vector slice (TODO).
- // Store all secondHashTable points in the candidates set.
- for (size_t j = 0; j < bucketContentSize[hashInd]; ++j)
- refPointsConsideredSmall(start++) = secondHashTable(tableRow, j);
- }
++ // Store all secondHashTable points in the candidates set.
++ if (tableRow != secondHashSize)
++ for (size_t j = 0; j < bucketContentSize[tableRow]; ++j)
++ refPointsConsideredSmall(start++) = secondHashTable[tableRow][j];
+ }
- referenceIndices = arma::find(refPointsConsidered > 0);
+ // Only keep unique candidates.
+ referenceIndices = arma::unique(refPointsConsideredSmall);
+ return;
+ }
}
// Search for nearest neighbors in a given query set.
More information about the mlpack-git mailing list