[mlpack-git] master: Fixes style issues, optimizes code a bit (ae81ee5)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Jun 22 11:28:58 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eaa7182ebed8cce3fd6191dc1f8170546ea297da...812048c7c6bee0b6c8d936677f23bbb5930c6cfc
>---------------------------------------------------------------
commit ae81ee58ac0db373cc7d0f5f95c246f1f258e7e9
Author: Yannis Mentekidis <mentekid at gmail.com>
Date: Wed Jun 22 16:28:58 2016 +0100
Fixes style issues, optimizes code a bit
>---------------------------------------------------------------
ae81ee58ac0db373cc7d0f5f95c246f1f258e7e9
src/mlpack/methods/lsh/lsh_search.hpp | 4 +-
src/mlpack/methods/lsh/lsh_search_impl.hpp | 122 ++++++++++++-----------------
2 files changed, 51 insertions(+), 75 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 364ef19..b21b0d1 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -154,13 +154,15 @@ class LSHSearch
* available without having to build hashing for every table size.
* By default, this is set to zero in which case all tables are
* considered.
+ * @param T The number of additional probing bins to examine with multiprobe
+ * LSH. If T = 0, classic single-probe LSH is run (default).
*/
void Search(const arma::mat& querySet,
const size_t k,
arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances,
const size_t numTablesToSearch = 0,
- size_t T = 0);
+ const size_t T = 0);
/**
* Compute the nearest neighbors and store the output in the given matrices.
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index 36e4afa..a89ea36 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -340,25 +340,11 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
referenceIndex, distance);
}
-
-// Compare class for <double, size_t> pair, used in GetAdditionalProbingBins.
-class CompareGreater
-{
- public:
- bool operator()(
- std::pair<double, size_t> p1,
- std::pair<double, size_t> p2){
- //only compare the double values
- return p1.first > p2.first;
- }
-};
-
//Returns the score of a perturbation vector generated by perturbation set A.
//The score of a pertubation set (vector) is the sum of scores of the
//participating actions.
-inline double perturbationScore(
- const std::vector<size_t> &A,
- const arma::vec &scores)
+inline double perturbationScore(const std::vector<size_t>& A,
+ const arma::vec& scores)
{
double score = 0.0;
for (size_t i = 0; i < A.size(); ++i)
@@ -368,7 +354,7 @@ inline double perturbationScore(
// Inline function used by GetAdditionalProbingBins. The vector shift operation
// replaces the largest element of a vector A with (largest element) + 1.
-inline void perturbationShift(std::vector<size_t> &A)
+inline void perturbationShift(std::vector<size_t>& A)
{
size_t max_pos = 0;
size_t max = A[0];
@@ -386,45 +372,38 @@ inline void perturbationShift(std::vector<size_t> &A)
// Inline function used by GetAdditionalProbingBins. The vector expansion
// operation adds the element [1 + (largest_element)] to a vector A, where
// largest_element is the largest element of A.
-inline void perturbationExpand(std::vector<size_t> &A)
+inline void perturbationExpand(std::vector<size_t>& A)
{
size_t max = A[0];
for (size_t i = 1; i < A.size(); ++i)
if (A[i] > max)
max = A[i];
- A.push_back(max+1);
+ A.push_back(max + 1);
}
// Return true if perturbation set A is valid. A perturbation set is invalid if
// it contains two (or more) actions for the same dimension or dimensions that
// are larger than the queryCode's dimensions.
-inline bool perturbationValid(
- const std::vector<size_t> &A,
- const size_t numProj)
+inline bool perturbationValid(const std::vector<size_t>& A,
+ const size_t numProj)
{
// Stack allocation and initialization to 0 (bool check[numProj] = {0}) made
- // some compilers complain. We use new just to be safe.
- bool *check = new bool[numProj]();
+ // some compilers complain, and std::vector might even be compressed (depends
+ // on implementation) so this saves some space.
+ std::vector<bool> check(numProj);
for (size_t i = 0; i < A.size(); ++i)
{
// Check that we only use valid dimensions. If not, vector is not valid.
- if ( A[i] >= 2*numProj)
- {
- delete []check;
+ if (A[i] >= 2 * numProj)
return false;
- }
// Check that we only see each dimension once. If not, vector is not valid.
if (check[A[i] % numProj ] == 0)
check[A[i] % numProj ] = 1;
else
- {
- delete []check;
return false;
- }
}
- delete []check;
return true;
}
@@ -432,10 +411,10 @@ inline bool perturbationValid(
// Compute additional probing bins for a query
template<typename SortPolicy>
void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
- const arma::vec &queryCode,
- const arma::vec &queryCodeNotFloored,
+ const arma::vec& queryCode,
+ const arma::vec& queryCodeNotFloored,
const size_t T,
- arma::mat &additionalProbingBins) const
+ arma::mat& additionalProbingBins) const
{
// No additional bins requested. Our work is done.
@@ -461,7 +440,7 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
// calculate scores = distances^2
arma::vec scores(2 * numProj);
scores.rows(0, numProj - 1) = arma::pow(limLow, 2);
- scores.rows(numProj, 2 * numProj - 1) = arma::pow(limHigh, 2);
+ scores.rows(numProj, (2 * numProj) - 1) = arma::pow(limHigh, 2);
// actions vector shows what transformation to apply to a coordinate
arma::Col<short int> actions(2 * numProj); // will be [-1 ... 1 ...]
@@ -469,7 +448,7 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
actions.rows(0, numProj - 1) = // first numProj rows
-1 * arma::ones< arma::Col<short int> > (numProj); // -1s
- actions.rows(numProj, 2 * numProj - 1) = // last numProj rows
+ actions.rows(numProj, (2 * numProj) - 1) = // last numProj rows
arma::ones< arma::Col<short int> > (numProj); // 1s
@@ -491,7 +470,7 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
// find location and value of smallest element of scores vector
double minscore = scores[0];
size_t minloc = 0;
- for (size_t s = 1; s < 2 * numProj; ++s)
+ for (size_t s = 1; s < (2 * numProj); ++s)
{
if (minscore > scores[s])
{
@@ -518,7 +497,7 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
double minscore2 = scores[0];
size_t minloc2 = 0;
- for (size_t s = 0; s < 2 * numProj; ++s) // here we can't start from 1
+ for (size_t s = 0; s < (2 * numProj); ++s) // here we can't start from 1
{
if ( minscore2 > scores[s] && s != minloc) //second smallest
{
@@ -530,8 +509,8 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
// add or subtract 1 to create second-lowest scoring vector
additionalProbingBins(positions[minloc2], 1) += actions[minloc2];
return;
-
}
+ // General case: more than 2 perturbation vectors require use of minheap.
// sort everything in increasing order
arma::Col<long long unsigned int> sortidx = arma::sort_index(scores);
@@ -561,23 +540,22 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
std::vector< std::vector<size_t> > perturbationSets;
perturbationSets.push_back(Ao); // storage of perturbation sets
-
- // define a priority queue with CompareGreater as a minheap
std::priority_queue<
std::pair<double, size_t>, // contents: pairs of (score, index)
std::vector< // container: vector of pairs
std::pair<double, size_t>
>,
- mlpack::neighbor::CompareGreater // comparator of pairs(compare scores)
+ std::greater< std::pair<double, size_t> > // comparator of pairs(compare scores)
> minHeap; // our minheap
// Start by adding the lowest scoring set to the minheap
- std::pair<double, size_t> pair0( perturbationScore(Ao, scores), 0 );
- minHeap.push(pair0);
+ // std::pair<double, size_t> pair0( perturbationScore(Ao, scores), 0 );
+ // minHeap.push(pair0);
+ minHeap.push( std::make_pair(perturbationScore(Ao, scores), 0) );
- double prevScore = 0; // store score of smallest inserted vector (for assert)
// loop invariable: after pvec iterations, additionalProbingBins contains pvec
- // valid codes of the highest-scoring bins
+ // valid codes of the lowest-scoring bins (bins most likely to contain
+ // neighbors of the query).
for (size_t pvec = 0; pvec < T; ++pvec)
{
std::vector<size_t> Ai;
@@ -587,17 +565,17 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
Ai = perturbationSets[ minHeap.top().second ];
minHeap.pop(); // .top() returns, .pop() removes
-
// modify Ai (shift)
std::vector<size_t> As = Ai;
perturbationShift(As);
if ( perturbationValid(As, numProj) )
{
perturbationSets.push_back(As); // add shifted set to sets
- std::pair<double, size_t> shifted(
- perturbationScore(As, scores),
- perturbationSets.size() - 1); // (score, position) pair for shift
- minHeap.push(shifted);
+ minHeap.push(
+ std::make_pair(
+ perturbationScore(As, scores),
+ perturbationSets.size() - 1)
+ );
}
// modify Ai (expand)
@@ -606,20 +584,15 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
if ( perturbationValid(Ae, numProj) )
{
perturbationSets.push_back(Ae); // add expanded set to sets
- std::pair<double, size_t> expanded(
- perturbationScore(Ae, scores),
- perturbationSets.size() - 1); // (score, position) pair for expand
- minHeap.push(expanded);
+ minHeap.push(
+ std::make_pair(
+ perturbationScore(Ae, scores),
+ perturbationSets.size() - 1)
+ );
}
-
}while (! perturbationValid(Ai, numProj) );//Discard invalid perturbations
- // a valid perturbation must have higher score than previous valid ones,
- // meaning the bin it corresponds to is less likely to hold neighbors
- assert ( perturbationScore(Ai, scores) >= prevScore );
- prevScore = perturbationScore(Ai, scores);
-
// add perturbation vector to probing sequence if valid
for (size_t i = 0; i < Ai.size(); ++i)
additionalProbingBins(positions(Ai[i]), pvec) += actions(Ai[i]);
@@ -658,18 +631,20 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
queryCodesNotFloored += offsets.cols(0, numTablesToSearch - 1);
allProjInTables = arma::floor(queryCodesNotFloored/hashWidth);
+
// Compute the hash value of each key of the query into a bucket of the
// 'secondHashTable' using the 'secondHashWeights'.
- arma::rowvec hashVec = secondHashWeights.t() * allProjInTables;
-
- // mod and floor hashVec to compute 2nd-level codes
+ arma::Row<size_t> hashVec =
+ arma::conv_to< arma::Row<size_t> >::
+ from( secondHashWeights.t() * allProjInTables ); // typecast to floor
+ // mod to compute 2nd-level codes
for (size_t i = 0; i < hashVec.n_elem; i++)
- hashVec[i] = (double) ((size_t) hashVec[i] % secondHashSize);
+ hashVec[i] = (hashVec[i] % secondHashSize);
Log::Assert(hashVec.n_elem == numTablesToSearch);
// Compute hashVectors of additional probing bins
- arma::mat hashMat;
+ arma::Mat<size_t> hashMat;
if (T > 0)
{
hashMat.zeros(T, numTablesToSearch);
@@ -678,18 +653,17 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
{
// construct this table's probing sequence of length T
arma::mat additionalProbingBins;
- //arma::vec dummyBins;
- //dummyBins.zeros(T, 1);
GetAdditionalProbingBins(allProjInTables.unsafe_col(i),
queryCodesNotFloored.unsafe_col(i),
T,
additionalProbingBins);
// map the probing bin to second hash table bins
- hashMat.col(i) = additionalProbingBins.t() * secondHashWeights;
- //hashMat.col(i) = dummyBins;
+ hashMat.col(i) =
+ arma::conv_to< arma::Col<size_t> >::
+ from(additionalProbingBins.t() * secondHashWeights); // typecast floor
for (size_t p = 0; p < T; ++p)
- hashMat(p, i) = (double) ((size_t) hashMat(p, i) % secondHashSize);
+ hashMat(p, i) = (hashMat(p, i) % secondHashSize);
}
// top row of hashMat is primary bins for each table
@@ -772,7 +746,7 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
{
for (size_t p = 0; p < T + 1; ++p)
{
- size_t hashInd = (size_t) hashMat(p, i); // Find the query's bucket.
+ size_t hashInd = hashMat(p, i); // Find the query's bucket.
if (bucketContentSize[hashInd] > 0)
{
@@ -801,7 +775,7 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances,
const size_t numTablesToSearch,
- size_t T)
+ const size_t T)
{
// Ensure the dimensionality of the query set is correct.
if (querySet.n_rows != referenceSet->n_rows)
More information about the mlpack-git
mailing list