[mlpack-git] master: Refactor for faster assembly of secondHashTable. (940a2b5)

gitdub at mlpack.org gitdub at mlpack.org
Sun Jun 5 01:31:50 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1f562a1aba7ae55475afcc95659511c2b7f694e5...5b8fdce471328f722fcd8c0f22a6d995ce22c98b

>---------------------------------------------------------------

commit 940a2b5766bc613f0db2e55ed4d8498cf287a62c
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Jun 3 20:25:54 2016 -0400

    Refactor for faster assembly of secondHashTable.


>---------------------------------------------------------------

940a2b5766bc613f0db2e55ed4d8498cf287a62c
 src/mlpack/methods/lsh/lsh_search.hpp      |  2 +-
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 72 ++++++++++++++++--------------
 2 files changed, 39 insertions(+), 35 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index b42bb7a..a755a99 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -322,7 +322,7 @@ class LSHSearch
   arma::Col<size_t> bucketContentSize;
 
   //! For a particular hash value, points to the row in secondHashTable
-  //! corresponding to this value.  Should be secondHashSize.
+  //! corresponding to this value. Length secondHashSize.
   arma::Col<size_t> bucketRowInHashTable;
 
   //! The number of distance evaluations.
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index 9ab2067..a141aa2 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -60,7 +60,7 @@ LSHSearch(const arma::mat& referenceSet,
 // Empty constructor.
 template<typename SortPolicy>
 LSHSearch<SortPolicy>::LSHSearch() :
-    referenceSet(new arma::mat()), // empty dataset
+    referenceSet(new arma::mat()), // Use an empty dataset.
     ownsSet(true),
     numProj(0),
     numTables(0),
@@ -153,9 +153,6 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
   bucketRowInHashTable.set_size(secondHashSize);
   bucketRowInHashTable.fill(secondHashSize);
 
-  // Keep track of number of non-empty rows in the 'secondHashTable'.
-  size_t numRowsInTable = 0;
-
   // Step II: The offsets for all projections in all tables.
   // Since the 'offsets' are in [0, hashWidth], we obtain the 'offsets'
   // as randu(numProj, numTables) * hashWidth.
@@ -183,6 +180,10 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
         "tables provided must be equal to numProj");
   }
 
+  // We will store the second hash vectors in this matrix; the second hash
+  // vector for table i will be held in row i.
+  arma::Mat<size_t> secondHashVectors(numTables, referenceSet.n_cols);
+
   for (size_t i = 0; i < numTables; i++)
   {
     // Step IV: create the 'numProj'-dimensional key for each point in each
@@ -204,20 +205,36 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
 
     // Step V: Putting the points in the 'secondHashTable' by hashing the key.
     // Now we hash every key, point ID to its corresponding bucket.
-    arma::rowvec secondHashVec = secondHashWeights.t() * arma::floor(hashMat);
+    secondHashVectors.row(i) = arma::conv_to<arma::Row<size_t>>::from(
+        secondHashWeights.t() * arma::floor(hashMat));
+  }
 
-    // This gives us the bucket for the corresponding point ID.
-    for (size_t j = 0; j < secondHashVec.n_elem; j++)
-      secondHashVec[j] = (double) ((size_t) secondHashVec[j] % secondHashSize);
+  // Normalize hashes (take modulus with secondHashSize).
+  secondHashVectors.transform([secondHashSize](size_t val)
+      { return val % secondHashSize; });
 
-    Log::Assert(secondHashVec.n_elem == referenceSet.n_cols);
+  // Now, using the hash vectors for each table, count the number of rows we
+  // have in the second hash table.
+  arma::Row<size_t> secondHashBinCounts(secondHashSize, arma::fill::zeros);
+  for (size_t i = 0; i < secondHashVectors.n_elem; ++i)
+    secondHashBinCounts[secondHashVectors[i]]++;
+
+  const size_t numRowsInTable = arma::accu(secondHashBinCounts > 0);
+  const size_t maxBucketSize = std::min(arma::max(secondHashBinCounts),
+      bucketSize);
+  secondHashTable.resize(numRowsInTable, maxBucketSize);
 
+  // Next we must assign each point in each table to the right second hash
+  // table.
+  size_t currentRow = 0;
+  for (size_t i = 0; i < numTables; ++i)
+  {
     // Insert the point in the corresponding row to its bucket in the
     // 'secondHashTable'.
-    for (size_t j = 0; j < secondHashVec.n_elem; j++)
+    for (size_t j = 0; j < secondHashVectors.n_cols; j++)
     {
       // This is the bucket number.
-      size_t hashInd = (size_t) secondHashVec[j];
+      size_t hashInd = (size_t) secondHashVectors(i, j);
       // The point ID is 'j'.
 
       // If this is currently an empty bucket, start a new row keep track of
@@ -225,37 +242,24 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
       if (bucketContentSize[hashInd] == 0)
       {
         // Start a new row for hash.
-        bucketRowInHashTable[hashInd] = numRowsInTable;
-        secondHashTable(numRowsInTable, 0) = j;
-
-        numRowsInTable++;
+        bucketRowInHashTable[hashInd] = currentRow;
+        bucketContentSize[hashInd] = 1;
+        secondHashTable(currentRow, 0) = j;
+        currentRow++;
       }
-
-      else
+      else if (bucketContentSize[hashInd] < maxBucketSize)
       {
         // If bucket is already present in the 'secondHashTable', find the
         // corresponding row and insert the point ID in this row unless the
-        // bucket is full, in which case, do nothing.
-        if (bucketContentSize[hashInd] < bucketSize)
-          secondHashTable(bucketRowInHashTable[hashInd],
-                          bucketContentSize[hashInd]) = j;
+        // bucket is full (in which case we are not inside this else if).
+        secondHashTable(bucketRowInHashTable[hashInd],
+                        bucketContentSize[hashInd]++) = j;
       }
-
-      // Increment the count of the points in this bucket.
-      if (bucketContentSize[hashInd] < bucketSize)
-        bucketContentSize[hashInd]++;
     } // Loop over all points in the reference set.
   } // Loop over tables.
 
-  // Step VI: Condensing the 'secondHashTable'.
-  size_t maxBucketSize = 0;
-  for (size_t i = 0; i < bucketContentSize.n_elem; i++)
-    if (bucketContentSize[i] > maxBucketSize)
-      maxBucketSize = bucketContentSize[i];
-
-  Log::Info << "Final hash table size: (" << numRowsInTable << " x "
-            << maxBucketSize << ")" << std::endl;
-  secondHashTable.resize(numRowsInTable, maxBucketSize);
+  Log::Info << "Final hash table size: " << numRowsInTable << " x "
+            << maxBucketSize << "." << std::endl;
 }
 
 template<typename SortPolicy>




More information about the mlpack-git mailing list