[mlpack-git] master: Uses arma::Row<char> instead of std::vector for perturbation sets (75dead3)

gitdub at mlpack.org gitdub at mlpack.org
Thu Jun 23 07:23:57 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eaa7182ebed8cce3fd6191dc1f8170546ea297da...812048c7c6bee0b6c8d936677f23bbb5930c6cfc

>---------------------------------------------------------------

commit 75dead3f4f6c20af63eb28e800abdfaf22250484
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Thu Jun 23 12:23:57 2016 +0100

    Uses arma::Row<char> instead of std::vector for perturbation sets


>---------------------------------------------------------------

75dead3f4f6c20af63eb28e800abdfaf22250484
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 159 +++++++++++++++++++++++++++++
 src/mlpack/tests/lsh_test.cpp              |   4 +
 2 files changed, 163 insertions(+)

diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index b1cc112..6db4e1e 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -333,6 +333,7 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
         referenceIndex, distance);
 }
 
+/*
 //Returns the score of a perturbation vector generated by perturbation set A.
 //The score of a pertubation set (vector) is the sum of scores of the
 //participating actions.
@@ -399,7 +400,97 @@ inline bool perturbationValid(const std::vector<size_t>& A,
   }
   return true;
 }
+*/
 
+//Returns the score of a perturbation vector generated by perturbation set A.
+//The score of a perturbation set (vector) is the sum of scores of the
+//participating actions.
+inline double perturbationScore(const arma::Row<char>& A,
+                                const arma::vec& scores)
+{
+  double score = 0.0;
+  for (size_t i = 0; i < A.n_elem; ++i)
+    score += A(i) ? scores(i) : 0; //add scores of non-zero indices
+  return score;
+}
+
+// Inline function used by GetAdditionalProbingBins. The vector shift operation
+// replaces the largest element of a vector A with (largest element) + 1.
+inline bool perturbationShift(arma::Row<char>& A)
+{
+  size_t max_pos = 0;
+  for (size_t i = 1; i < A.n_elem; ++i)
+    if (A(i) == 1) // marked true
+      max_pos=i;
+  
+  if ( max_pos + 1 < A.n_elem) // otherwise, this is an invalid vector 
+  {
+    A(max_pos) = 0;
+    A(max_pos+1) = 1;
+    return true; // valid
+  }
+  return false; // invalid
+}
+
+// Inline function used by GetAdditionalProbingBins. The vector expansion
+// operation adds the element [1 + (largest_element)] to a vector A, where
+// largest_element is the largest element of A.
+inline bool perturbationExpand(arma::Row<char>& A)
+{
+  size_t max_pos = 0;
+  for (size_t i = 1; i < A.n_elem; ++i)
+    if (A(i) == 1) //marked true
+      max_pos=i;
+
+  if ( max_pos + 1 < A.n_elem) // otherwise, this is an invalid vector
+  {
+    A(max_pos+1) = 1;
+    return true;
+  }
+  return false;
+
+}
+
+// Return true if perturbation set A is valid. A perturbation set is invalid if
+// it contains two (or more) actions for the same dimension or dimensions that
+// are larger than the queryCode's dimensions.
+inline bool perturbationValid(const arma::Row<char>& A,
+                              const size_t numProj)
+{
+  // Stack allocation and initialization to 0 (bool check[numProj] = {0}) made
+  // some compilers complain, and std::vector might even be compressed (depends
+  // on implementation) so this saves some space.
+  std::vector<bool> check(numProj);
+
+  if (A.n_elem > 2 * numProj)
+  {
+    // Log::Assert(1 == 2);
+    return false; // This should never happen
+  }
+
+  // Check that we only see each dimension once. If not, vector is not valid.
+  for (size_t i = 0; i < A.n_elem; ++i)
+  {
+    // Only check dimensions that were included.
+    if (!A(i))
+      continue;
+
+    // If dimension is unseen thus far, mark it as seen.
+    if ( check[i % numProj] == false )
+      check[i % numProj] = true;
+    else
+      return false; // If dimension was seen before, set is not valid.
+
+
+    if (check[A[i] % numProj ] == 0)
+      check[A[i] % numProj ] = 1;
+    else
+      return false;
+  }
+
+  // If we didn't fail, set is valid.
+  return true;
+}
 
 // Compute additional probing bins for a query
 template<typename SortPolicy>
@@ -527,6 +618,7 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
   // Transform perturbation set to perturbation vector by setting the
   // dimensions specified by the set to queryCode+action (action is {-1, 1}).
 
+  /*
   std::vector<size_t> Ao;
   Ao.push_back(0); // initial perturbation holds smallest score (0 if sorted)
 
@@ -591,6 +683,73 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
       additionalProbingBins(positions(Ai[i]), pvec) += actions(Ai[i]);
 
   }
+  */
+
+  // Perturbation sets (A) mark with 1 the (score, action, dimension) positions
+  // included in a given perturbation vector. Other spaces are 0.
+  arma::Row<char> Ao(2 * numProj, arma::fill::zeros);
+  Ao(0) = 1; // Smallest vector includes only smallest score.
+
+  std::vector< arma::Row<char> > perturbationSets;
+  perturbationSets.push_back(Ao); // storage of perturbation sets
+
+  std::priority_queue<
+    std::pair<double, size_t>,        // contents: pairs of (score, index)
+    std::vector<                      // container: vector of pairs
+      std::pair<double, size_t>
+      >,
+    std::greater< std::pair<double, size_t> > // comparator of pairs(compare scores)
+  > minHeap; // our minheap
+
+  // Start by adding the lowest scoring set to the minheap
+  // std::pair<double, size_t> pair0( perturbationScore(Ao, scores), 0 );
+  // minHeap.push(pair0);
+  minHeap.push( std::make_pair(perturbationScore(Ao, scores), 0) );
+
+  // Loop invariant: after pvec iterations, additionalProbingBins contains pvec
+  // valid codes of the lowest-scoring bins (bins most likely to contain
+  // neighbors of the query).
+  for (size_t pvec = 0; pvec < T; ++pvec)
+  {
+    arma::Row<char> Ai;
+    do
+    {
+      // get the perturbation set corresponding to the minimum score
+      Ai = perturbationSets[ minHeap.top().second ];
+      minHeap.pop(); // .top() returns, .pop() removes
+
+      // modify Ai (shift)
+      arma::Row<char> As = Ai;
+      if ( perturbationShift(As) && perturbationValid(As, numProj) ) // Don't add invalid sets.
+      {
+        perturbationSets.push_back(As); // add shifted set to sets
+        minHeap.push( 
+            std::make_pair(
+              perturbationScore(As, scores), 
+              perturbationSets.size() - 1)
+            );
+      }
+
+      // modify Ai (expand)
+      arma::Row<char> Ae = Ai;
+      if ( perturbationExpand(Ae) && perturbationValid(Ae, numProj) ) // Don't add invalid sets.
+      {
+        perturbationSets.push_back(Ae); // add expanded set to sets
+        minHeap.push(
+            std::make_pair(
+              perturbationScore(Ae, scores),
+              perturbationSets.size() - 1)
+            );
+      }
+
+    }while (! perturbationValid(Ai, numProj)  );//Discard invalid perturbations
+
+    // add perturbation vector to probing sequence if valid
+    for (size_t i = 0; i < Ai.n_elem; ++i)
+      additionalProbingBins(positions(i), pvec) 
+          += Ai(i) ? actions(i) : 0; // if A(i) marked, add action to probing vector
+
+  }
 }
 
 
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index d576387..63ff2c6 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -601,6 +601,10 @@ BOOST_AUTO_TEST_CASE(MultiprobeDeterministicTest)
   // Searching with 3 additional probing bins should find neighbors
   lshTest.Search(q1, k, neighbors, distances, 0, 3);
   BOOST_REQUIRE( arma::all(neighbors.col(0) == N || neighbors.col(0) < 10) );
+
+  // Demand that we do find at least 1 neighbor (not just Ns, but actual
+  // values)
+  // BOOST_REQUIRE_GE( arma::find(neighbors.col(0) < 10).n_elem, 1);
 }
 
 BOOST_AUTO_TEST_CASE(LSHTrainTest)




More information about the mlpack-git mailing list