[mlpack-git] master: Uses arma::Row<char> instead of std::vector for perturbation sets (75dead3)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Jun 23 07:23:57 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eaa7182ebed8cce3fd6191dc1f8170546ea297da...812048c7c6bee0b6c8d936677f23bbb5930c6cfc
>---------------------------------------------------------------
commit 75dead3f4f6c20af63eb28e800abdfaf22250484
Author: Yannis Mentekidis <mentekid at gmail.com>
Date: Thu Jun 23 12:23:57 2016 +0100
Uses arma::Row<char> instead of std::vector for perturbation sets
>---------------------------------------------------------------
75dead3f4f6c20af63eb28e800abdfaf22250484
src/mlpack/methods/lsh/lsh_search_impl.hpp | 159 +++++++++++++++++++++++++++++
src/mlpack/tests/lsh_test.cpp | 4 +
2 files changed, 163 insertions(+)
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index b1cc112..6db4e1e 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -333,6 +333,7 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
referenceIndex, distance);
}
+/*
//Returns the score of a perturbation vector generated by perturbation set A.
//The score of a perturbation set (vector) is the sum of scores of the
//participating actions.
@@ -399,7 +400,97 @@ inline bool perturbationValid(const std::vector<size_t>& A,
}
return true;
}
+*/
+// Computes the score of the perturbation vector described by perturbation
+// set A: the sum of scores(i) over every position i that is marked (non-zero)
+// in A. Unmarked positions contribute nothing to the result.
+inline double perturbationScore(const arma::Row<char>& A,
+                                const arma::vec& scores)
+{
+  double total = 0.0;
+  for (size_t pos = 0; pos < A.n_elem; ++pos)
+  {
+    if (A(pos))
+      total += scores(pos); // Only marked positions participate.
+  }
+  return total;
+}
+
+// Helper for GetAdditionalProbingBins implementing the "shift" operation on a
+// perturbation set: the highest marked position p of A is unmarked and
+// position p + 1 is marked instead. Position 0 is never inspected, so when no
+// position >= 1 is marked the shift is applied to position 0. Returns false,
+// leaving A untouched, when p + 1 would fall outside of A.
+inline bool perturbationShift(arma::Row<char>& A)
+{
+  // Scan from the back for the last position (excluding 0) marked with 1.
+  size_t last = 0;
+  for (size_t i = A.n_elem; i > 1; )
+  {
+    --i;
+    if (A(i) == 1)
+    {
+      last = i;
+      break;
+    }
+  }
+
+  // No room to move one position to the right: the shifted set is invalid.
+  if (last + 1 >= A.n_elem)
+    return false;
+
+  A(last) = 0;
+  A(last + 1) = 1;
+  return true;
+}
+
+// Helper for GetAdditionalProbingBins implementing the "expand" operation on
+// a perturbation set: with p the highest marked position of A, position p + 1
+// is marked as well (p itself stays marked). Position 0 is never inspected,
+// so when no position >= 1 is marked, p is taken to be 0. Returns false,
+// leaving A untouched, when p + 1 would fall outside of A.
+inline bool perturbationExpand(arma::Row<char>& A)
+{
+  // Walk backwards until the first position (excluding 0) marked with 1.
+  size_t highest = 0;
+  size_t idx = A.n_elem;
+  while (idx > 1)
+  {
+    --idx;
+    if (A(idx) == 1)
+    {
+      highest = idx;
+      break;
+    }
+  }
+
+  const size_t next = highest + 1;
+  if (next >= A.n_elem)
+    return false; // Nothing to the right of the highest position: invalid.
+
+  A(next) = 1;
+  return true;
+}
+
+// Returns true if perturbation set A is valid for numProj projections. A
+// holds one 0/1 flag per (score, action, dimension) position, and position i
+// is treated as acting on dimension (i % numProj). A set is invalid when it
+// has more than 2 * numProj positions, or when two marked positions act on
+// the same dimension.
+inline bool perturbationValid(const arma::Row<char>& A,
+                              const size_t numProj)
+{
+  // A set can hold at most two actions per projection. This should never
+  // trigger; it also guarantees numProj != 0 whenever the loop below runs.
+  if (A.n_elem > 2 * numProj)
+    return false;
+
+  // Stack allocation and initialization to 0 (bool check[numProj] = {0}) made
+  // some compilers complain, and std::vector might even be compressed (depends
+  // on implementation) so this saves some space.
+  std::vector<bool> check(numProj);
+
+  // Check that we only see each dimension once. If not, the set is not valid.
+  for (size_t i = 0; i < A.n_elem; ++i)
+  {
+    // Only check positions that were included.
+    if (!A(i))
+      continue;
+
+    // The dimension comes from the position i, never from the flag value
+    // A(i): A(i) is only a 0/1 marker and must not be used as an index.
+    const size_t dim = i % numProj;
+    if (check[dim])
+      return false; // Dimension was seen before: set is not valid.
+
+    check[dim] = true; // First time we see this dimension.
+  }
+
+  // If we didn't fail, the set is valid.
+  return true;
+}
// Compute additional probing bins for a query
template<typename SortPolicy>
@@ -527,6 +618,7 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
// Transform perturbation set to perturbation vector by setting the
// dimensions specified by the set to queryCode+action (action is {-1, 1}).
+ /*
std::vector<size_t> Ao;
Ao.push_back(0); // initial perturbation holds smallest score (0 if sorted)
@@ -591,6 +683,73 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
additionalProbingBins(positions(Ai[i]), pvec) += actions(Ai[i]);
}
+ */
+
+ // Perturbation sets (A) mark with 1 the (score, action, dimension) positions
+ // included in a given perturbation vector. Other spaces are 0.
+ arma::Row<char> Ao(2 * numProj, arma::fill::zeros);
+ Ao(0) = 1; // Smallest vector includes only smallest score.
+
+ std::vector< arma::Row<char> > perturbationSets;
+ perturbationSets.push_back(Ao); // storage of perturbation sets
+
+ std::priority_queue<
+ std::pair<double, size_t>, // contents: pairs of (score, index)
+ std::vector< // container: vector of pairs
+ std::pair<double, size_t>
+ >,
+ std::greater< std::pair<double, size_t> > // comparator of pairs(compare scores)
+ > minHeap; // our minheap
+
+ // Start by adding the lowest scoring set to the minheap
+ // std::pair<double, size_t> pair0( perturbationScore(Ao, scores), 0 );
+ // minHeap.push(pair0);
+ minHeap.push( std::make_pair(perturbationScore(Ao, scores), 0) );
+
+ // Loop invariant: after pvec iterations, additionalProbingBins contains pvec
+ // valid codes of the lowest-scoring bins (bins most likely to contain
+ // neighbors of the query).
+ for (size_t pvec = 0; pvec < T; ++pvec)
+ {
+ arma::Row<char> Ai;
+ do
+ {
+ // get the perturbation set corresponding to the minimum score
+ Ai = perturbationSets[ minHeap.top().second ];
+ minHeap.pop(); // .top() returns, .pop() removes
+
+ // modify Ai (shift)
+ arma::Row<char> As = Ai;
+ if ( perturbationShift(As) && perturbationValid(As, numProj) ) // Don't add invalid sets.
+ {
+ perturbationSets.push_back(As); // add shifted set to sets
+ minHeap.push(
+ std::make_pair(
+ perturbationScore(As, scores),
+ perturbationSets.size() - 1)
+ );
+ }
+
+ // modify Ai (expand)
+ arma::Row<char> Ae = Ai;
+ if ( perturbationExpand(Ae) && perturbationValid(Ae, numProj) ) // Don't add invalid sets.
+ {
+ perturbationSets.push_back(Ae); // add expanded set to sets
+ minHeap.push(
+ std::make_pair(
+ perturbationScore(Ae, scores),
+ perturbationSets.size() - 1)
+ );
+ }
+
+ }while (! perturbationValid(Ai, numProj) );//Discard invalid perturbations
+
+ // add perturbation vector to probing sequence if valid
+ for (size_t i = 0; i < Ai.n_elem; ++i)
+ additionalProbingBins(positions(i), pvec)
+ += Ai(i) ? actions(i) : 0; // if A(i) marked, add action to probing vector
+
+ }
}
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index d576387..63ff2c6 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -601,6 +601,10 @@ BOOST_AUTO_TEST_CASE(MultiprobeDeterministicTest)
// Searching with 3 additional probing bins should find neighbors
lshTest.Search(q1, k, neighbors, distances, 0, 3);
BOOST_REQUIRE( arma::all(neighbors.col(0) == N || neighbors.col(0) < 10) );
+
+ // Demand that we do find at least 1 neighbor (not just Ns, but actual
+ // values)
+ // BOOST_REQUIRE_GE( arma::find(neighbors.col(0) < 10).n_elem, 1);
}
BOOST_AUTO_TEST_CASE(LSHTrainTest)
More information about the mlpack-git
mailing list