[mlpack-git] master: Fixes bug in perturbationValid (89b3c7b)

gitdub at mlpack.org gitdub at mlpack.org
Thu Jun 23 10:35:57 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eaa7182ebed8cce3fd6191dc1f8170546ea297da...812048c7c6bee0b6c8d936677f23bbb5930c6cfc

>---------------------------------------------------------------

commit 89b3c7ba8291a42846e3d3e09981caee98335e78
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Thu Jun 23 15:35:57 2016 +0100

    Fixes bug in perturbationValid


>---------------------------------------------------------------

89b3c7ba8291a42846e3d3e09981caee98335e78
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 36 +++++++++++++-----------------
 src/mlpack/tests/lsh_test.cpp              |  4 ----
 2 files changed, 15 insertions(+), 25 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index 6db4e1e..12ccc6d 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -418,15 +418,15 @@ inline double perturbationScore(const arma::Row<char>& A,
 // replaces the largest element of a vector A with (largest element) + 1.
 inline bool perturbationShift(arma::Row<char>& A)
 {
-  size_t max_pos = 0;
-  for (size_t i = 1; i < A.n_elem; ++i)
+  size_t maxPos = 0;
+  for (size_t i = 0; i < A.n_elem; ++i)
     if (A(i) == 1) // marked true
-      max_pos=i;
+      maxPos=i;
   
-  if ( max_pos + 1 < A.n_elem) // otherwise, this is an invalid vector 
+  if ( maxPos + 1 < A.n_elem) // otherwise, this is an invalid vector 
   {
-    A(max_pos) = 0;
-    A(max_pos+1) = 1;
+    A(maxPos) = 0;
+    A(maxPos+1) = 1;
     return true; // valid
   }
   return false; // invalid
@@ -437,14 +437,15 @@ inline bool perturbationShift(arma::Row<char>& A)
 // largest_element is the largest element of A.
 inline bool perturbationExpand(arma::Row<char>& A)
 {
-  size_t max_pos = 0;
-  for (size_t i = 1; i < A.n_elem; ++i)
-    if (A(i) == 1) //marked true
-      max_pos=i;
+  // Find the last '1' in A
+  size_t maxPos = 0;
+  for (size_t i = 0; i < A.n_elem; ++i)
+    if (A(i)) //marked true
+      maxPos = i;
 
-  if ( max_pos + 1 < A.n_elem) // otherwise, this is an invalid vector
+  if ( maxPos + 1 < A.n_elem) // otherwise, this is an invalid vector
   {
-    A(max_pos+1) = 1;
+    A(maxPos+1) = 1;
     return true;
   }
   return false;
@@ -464,7 +465,7 @@ inline bool perturbationValid(const arma::Row<char>& A,
 
   if (A.n_elem > 2 * numProj)
   {
-    // Log::Assert(1 == 2);
+    Log::Assert(1 == 2);
     return false; // This should never happen
   }
 
@@ -480,12 +481,6 @@ inline bool perturbationValid(const arma::Row<char>& A,
       check[i % numProj] = true;
     else
       return false; // If dimension was seen before, set is not valid.
-
-
-    if (check[A[i] % numProj ] == 0)
-      check[A[i] % numProj ] = 1;
-    else
-      return false;
   }
 
   // If we didn't fail, set is valid.
@@ -748,7 +743,6 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
     for (size_t i = 0; i < Ai.n_elem; ++i)
       additionalProbingBins(positions(i), pvec) 
           += Ai(i) ? actions(i) : 0; // if A(i) marked, add action to probing vector
-
   }
 }
 
@@ -902,7 +896,7 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
         if (tableRow < secondHashSize)
          // Store all secondHashTable points in the candidates set.
          for (size_t j = 0; j < bucketContentSize[tableRow]; ++j)
-           refPointsConsideredSmall(start++) = secondHashTable[tableRow][j];
+           refPointsConsideredSmall(start++) = secondHashTable[tableRow](j);
       }
     }
 
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index 63ff2c6..d576387 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -601,10 +601,6 @@ BOOST_AUTO_TEST_CASE(MultiprobeDeterministicTest)
   // Searching with 3 additional probing bins should find neighbors
   lshTest.Search(q1, k, neighbors, distances, 0, 3);
   BOOST_REQUIRE( arma::all(neighbors.col(0) == N || neighbors.col(0) < 10) );
-
-  // Demand that we do find at least 1 neighbor (not just Ns, but actuall
-  // values)
-  // BOOST_REQUIRE_GE( arma::find(neighbors.col(0) < 10).n_elem, 1);
 }
 
 BOOST_AUTO_TEST_CASE(LSHTrainTest)




More information about the mlpack-git mailing list