[mlpack-git] master: Fixes style issues, optimizes code a bit (ae81ee5)

gitdub at mlpack.org gitdub at mlpack.org
Wed Jun 22 11:28:58 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eaa7182ebed8cce3fd6191dc1f8170546ea297da...812048c7c6bee0b6c8d936677f23bbb5930c6cfc

>---------------------------------------------------------------

commit ae81ee58ac0db373cc7d0f5f95c246f1f258e7e9
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Wed Jun 22 16:28:58 2016 +0100

    Fixes style issues, optimizes code a bit


>---------------------------------------------------------------

ae81ee58ac0db373cc7d0f5f95c246f1f258e7e9
 src/mlpack/methods/lsh/lsh_search.hpp      |   4 +-
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 122 ++++++++++++-----------------
 2 files changed, 51 insertions(+), 75 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 364ef19..b21b0d1 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -154,13 +154,15 @@ class LSHSearch
    *     available without having to build hashing for every table size.
    *     By default, this is set to zero in which case all tables are
    *     considered.
+   * @param T The number of additional probing bins to examine with multiprobe
+   *     LSH. If T = 0, classic single-probe LSH is run (default).
    */
   void Search(const arma::mat& querySet,
               const size_t k,
               arma::Mat<size_t>& resultingNeighbors,
               arma::mat& distances,
               const size_t numTablesToSearch = 0,
-              size_t T = 0);
+              const size_t T = 0);
 
   /**
    * Compute the nearest neighbors and store the output in the given matrices.
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index 36e4afa..a89ea36 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -340,25 +340,11 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
         referenceIndex, distance);
 }
 
-
-// Compare class for <double, size_t> pair, used in GetAdditionalProbingBins.
-class CompareGreater
-{
-  public:
-    bool operator()(
-        std::pair<double, size_t> p1,
-        std::pair<double, size_t> p2){
-      //only compare the double values
-      return p1.first > p2.first;
-    }
-};
-
 //Returns the score of a perturbation vector generated by perturbation set A.
 //The score of a pertubation set (vector) is the sum of scores of the
 //participating actions.
-inline double perturbationScore(
-              const std::vector<size_t> &A,
-              const arma::vec &scores)
+inline double perturbationScore(const std::vector<size_t>& A,
+                                const arma::vec& scores)
 {
   double score = 0.0;
   for (size_t i = 0; i < A.size(); ++i)
@@ -368,7 +354,7 @@ inline double perturbationScore(
 
 // Inline function used by GetAdditionalProbingBins. The vector shift operation
 // replaces the largest element of a vector A with (largest element) + 1.
-inline void perturbationShift(std::vector<size_t> &A)
+inline void perturbationShift(std::vector<size_t>& A)
 {
   size_t max_pos = 0;
   size_t max = A[0];
@@ -386,45 +372,38 @@ inline void perturbationShift(std::vector<size_t> &A)
 // Inline function used by GetAdditionalProbingBins. The vector expansion
 // operation adds the element [1 + (largest_element)] to a vector A, where
 // largest_element is the largest element of A.
-inline void perturbationExpand(std::vector<size_t> &A)
+inline void perturbationExpand(std::vector<size_t>& A)
 {
   size_t max = A[0];
   for (size_t i = 1; i < A.size(); ++i)
     if (A[i] > max)
       max = A[i];
-  A.push_back(max+1);
+  A.push_back(max + 1);
 }
 
 // Return true if perturbation set A is valid. A perturbation set is invalid if
 // it contains two (or more) actions for the same dimension or dimensions that
 // are larger than the queryCode's dimensions.
-inline bool perturbationValid(
-    const std::vector<size_t> &A,
-    const size_t numProj)
+inline bool perturbationValid(const std::vector<size_t>& A,
+                              const size_t numProj)
 {
   // Stack allocation and initialization to 0 (bool check[numProj] = {0}) made
-  // some compilers complain. We use new just to be safe.
-  bool *check = new bool[numProj]();
+  // some compilers complain, and std::vector might even be compressed (depends
+  // on implementation) so this saves some space.
+  std::vector<bool> check(numProj);
 
   for (size_t i = 0; i < A.size(); ++i)
   {
     // Check that we only use valid dimensions. If not, vector is not valid.
-    if ( A[i] >= 2*numProj)
-    {
-      delete []check;
+    if (A[i] >= 2 * numProj)
       return false;
-    }
 
     // Check that we only see each dimension once. If not, vector is not valid.
     if (check[A[i] % numProj ] == 0)
       check[A[i] % numProj ] = 1;
     else
-    {
-      delete []check;
       return false;
-    }
   }
-  delete []check;
   return true;
 }
 
@@ -432,10 +411,10 @@ inline bool perturbationValid(
 // Compute additional probing bins for a query
 template<typename SortPolicy>
 void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
-                            const arma::vec &queryCode,
-                            const arma::vec &queryCodeNotFloored,
+                            const arma::vec& queryCode,
+                            const arma::vec& queryCodeNotFloored,
                             const size_t T,
-                            arma::mat &additionalProbingBins) const
+                            arma::mat& additionalProbingBins) const
 {
 
   // No additional bins requested. Our work is done.
@@ -461,7 +440,7 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
   // calculate scores = distances^2
   arma::vec scores(2 * numProj);
   scores.rows(0, numProj - 1) = arma::pow(limLow, 2);
-  scores.rows(numProj, 2 * numProj - 1) = arma::pow(limHigh, 2);
+  scores.rows(numProj, (2 * numProj) - 1) = arma::pow(limHigh, 2);
 
   // actions vector shows what transformation to apply to a coordinate
   arma::Col<short int> actions(2 * numProj); // will be [-1 ... 1 ...]
@@ -469,7 +448,7 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
   actions.rows(0, numProj - 1) = // first numProj rows
     -1 * arma::ones< arma::Col<short int> > (numProj); // -1s
 
-  actions.rows(numProj, 2 * numProj - 1) = // last numProj rows
+  actions.rows(numProj, (2 * numProj) - 1) = // last numProj rows
     arma::ones< arma::Col<short int> > (numProj); // 1s
 
 
@@ -491,7 +470,7 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
     // find location and value of smallest element of scores vector
     double minscore = scores[0];
     size_t minloc = 0;
-    for (size_t s = 1; s < 2 * numProj; ++s)
+    for (size_t s = 1; s < (2 * numProj); ++s)
     {
       if (minscore > scores[s])
       {
@@ -518,7 +497,7 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
     
     double minscore2 = scores[0];
     size_t minloc2 = 0;
-    for (size_t s = 0; s < 2 * numProj; ++s) // here we can't start from 1
+    for (size_t s = 0; s < (2 * numProj); ++s) // here we can't start from 1
     {
       if ( minscore2 > scores[s] && s != minloc) //second smallest
       {
@@ -530,8 +509,8 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
     // add or subtract 1 to create second-lowest scoring vector
     additionalProbingBins(positions[minloc2], 1) += actions[minloc2];
     return;
-    
   }
+  // General case: more than 2 perturbation vectors require use of minheap.
 
   // sort everything in increasing order
   arma::Col<long long unsigned int> sortidx = arma::sort_index(scores);
@@ -561,23 +540,22 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
   std::vector< std::vector<size_t> > perturbationSets;
   perturbationSets.push_back(Ao); // storage of perturbation sets
 
-
-  // define a priority queue with CompareGreater as a minheap
   std::priority_queue<
     std::pair<double, size_t>,        // contents: pairs of (score, index)
     std::vector<                      // container: vector of pairs
       std::pair<double, size_t>
       >,
-    mlpack::neighbor::CompareGreater // comparator of pairs(compare scores)
+    std::greater< std::pair<double, size_t> > // comparator of pairs(compare scores)
   > minHeap; // our minheap
 
   // Start by adding the lowest scoring set to the minheap
-  std::pair<double, size_t> pair0( perturbationScore(Ao, scores), 0 );
-  minHeap.push(pair0);
+  // std::pair<double, size_t> pair0( perturbationScore(Ao, scores), 0 );
+  // minHeap.push(pair0);
+  minHeap.push( std::make_pair(perturbationScore(Ao, scores), 0) );
 
-  double prevScore = 0; // store score of smallest inserted vector (for assert)
   // loop invariable: after pvec iterations, additionalProbingBins contains pvec
-  // valid codes of the highest-scoring bins
+  // valid codes of the lowest-scoring bins (bins most likely to contain
+  // neighbors of the query).
   for (size_t pvec = 0; pvec < T; ++pvec)
   {
     std::vector<size_t> Ai;
@@ -587,17 +565,17 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
       Ai = perturbationSets[ minHeap.top().second ];
       minHeap.pop(); // .top() returns, .pop() removes
 
-
       // modify Ai (shift)
       std::vector<size_t> As = Ai;
       perturbationShift(As);
       if ( perturbationValid(As, numProj) )
       {
         perturbationSets.push_back(As); // add shifted set to sets
-        std::pair<double, size_t> shifted(
-            perturbationScore(As, scores),
-            perturbationSets.size() - 1); // (score, position) pair for shift
-        minHeap.push(shifted);
+        minHeap.push( 
+            std::make_pair(
+              perturbationScore(As, scores), 
+              perturbationSets.size() - 1)
+            );
       }
 
       // modify Ai (expand)
@@ -606,20 +584,15 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
       if ( perturbationValid(Ae, numProj) )
       {
         perturbationSets.push_back(Ae); // add expanded set to sets
-        std::pair<double, size_t> expanded(
-            perturbationScore(Ae, scores),
-            perturbationSets.size() - 1); // (score, position) pair for expand
-        minHeap.push(expanded);
+        minHeap.push(
+            std::make_pair(
+              perturbationScore(Ae, scores),
+              perturbationSets.size() - 1)
+            );
       }
 
-
     }while (! perturbationValid(Ai, numProj)  );//Discard invalid perturbations
 
-    // a valid perturbation must have higher score than previous valid ones,
-    // meaning the bin it corresponds to is less likely to hold neighbors
-    assert ( perturbationScore(Ai, scores) >= prevScore );
-    prevScore = perturbationScore(Ai, scores);
-
     // add perturbation vector to probing sequence if valid
     for (size_t i = 0; i < Ai.size(); ++i)
       additionalProbingBins(positions(Ai[i]), pvec) += actions(Ai[i]);
@@ -658,18 +631,20 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
   queryCodesNotFloored += offsets.cols(0, numTablesToSearch - 1);
   allProjInTables = arma::floor(queryCodesNotFloored/hashWidth);
 
+
   // Compute the hash value of each key of the query into a bucket of the
   // 'secondHashTable' using the 'secondHashWeights'.
-  arma::rowvec hashVec = secondHashWeights.t() * allProjInTables;
-
-  // mod and floor hashVec to compute 2nd-level codes
+  arma::Row<size_t> hashVec = 
+    arma::conv_to< arma::Row<size_t> >::
+    from( secondHashWeights.t() * allProjInTables ); // typecast to floor
+  // mod to compute 2nd-level codes
   for (size_t i = 0; i < hashVec.n_elem; i++)
-    hashVec[i] = (double) ((size_t) hashVec[i] % secondHashSize);
+    hashVec[i] = (hashVec[i] % secondHashSize);
 
   Log::Assert(hashVec.n_elem == numTablesToSearch);
 
   // Compute hashVectors of additional probing bins
-  arma::mat hashMat;
+  arma::Mat<size_t> hashMat;
   if (T > 0)
   {
     hashMat.zeros(T, numTablesToSearch);
@@ -678,18 +653,17 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
     {
       // construct this table's probing sequence of length T
       arma::mat additionalProbingBins;
-      //arma::vec dummyBins;
-      //dummyBins.zeros(T, 1);
       GetAdditionalProbingBins(allProjInTables.unsafe_col(i),
                                 queryCodesNotFloored.unsafe_col(i),
                                 T,
                                 additionalProbingBins);
 
       // map the probing bin to second hash table bins
-      hashMat.col(i) = additionalProbingBins.t() * secondHashWeights;
-      //hashMat.col(i) = dummyBins;
+      hashMat.col(i) = 
+        arma::conv_to< arma::Col<size_t> >::
+        from(additionalProbingBins.t() * secondHashWeights); // typecast floor
       for (size_t p = 0; p < T; ++p)
-        hashMat(p, i) = (double) ((size_t) hashMat(p, i) % secondHashSize);
+        hashMat(p, i) = (hashMat(p, i) % secondHashSize);
     }
 
     // top row of hashMat is primary bins for each table
@@ -772,7 +746,7 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
     {
       for (size_t p = 0; p < T + 1; ++p)
       {
-        size_t hashInd = (size_t) hashMat(p, i); // Find the query's bucket.
+        size_t hashInd =  hashMat(p, i); // Find the query's bucket.
 
         if (bucketContentSize[hashInd] > 0)
         {
@@ -801,7 +775,7 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
                                    arma::Mat<size_t>& resultingNeighbors,
                                    arma::mat& distances,
                                    const size_t numTablesToSearch,
-                                   size_t T)
+                                   const size_t T)
 {
   // Ensure the dimensionality of the query set is correct.
   if (querySet.n_rows != referenceSet->n_rows)




More information about the mlpack-git mailing list