[mlpack-git] master: Revert "Refactor for faster assembly of secondHashTable." (5b8fdce)

gitdub at mlpack.org gitdub at mlpack.org
Sun Jun 5 01:33:17 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/06cae1319ab1c22b53793c2ceff45605ad0ecc64...5b8fdce471328f722fcd8c0f22a6d995ce22c98b

>---------------------------------------------------------------

commit 5b8fdce471328f722fcd8c0f22a6d995ce22c98b
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Jun 5 01:33:17 2016 -0400

    Revert "Refactor for faster assembly of secondHashTable."
    
    This reverts commit 940a2b5766bc613f0db2e55ed4d8498cf287a62c.  I did not mean to
    commit that.


>---------------------------------------------------------------

5b8fdce471328f722fcd8c0f22a6d995ce22c98b
 src/mlpack/methods/lsh/lsh_search.hpp      |  2 +-
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 72 ++++++++++++++----------------
 2 files changed, 35 insertions(+), 39 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index a755a99..b42bb7a 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -322,7 +322,7 @@ class LSHSearch
   arma::Col<size_t> bucketContentSize;
 
   //! For a particular hash value, points to the row in secondHashTable
-  //! corresponding to this value. Length secondHashSize.
+  //! corresponding to this value.  Should be secondHashSize.
   arma::Col<size_t> bucketRowInHashTable;
 
   //! The number of distance evaluations.
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index a141aa2..9ab2067 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -60,7 +60,7 @@ LSHSearch(const arma::mat& referenceSet,
 // Empty constructor.
 template<typename SortPolicy>
 LSHSearch<SortPolicy>::LSHSearch() :
-    referenceSet(new arma::mat()), // Use an empty dataset.
+    referenceSet(new arma::mat()), // empty dataset
     ownsSet(true),
     numProj(0),
     numTables(0),
@@ -153,6 +153,9 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
   bucketRowInHashTable.set_size(secondHashSize);
   bucketRowInHashTable.fill(secondHashSize);
 
+  // Keep track of number of non-empty rows in the 'secondHashTable'.
+  size_t numRowsInTable = 0;
+
   // Step II: The offsets for all projections in all tables.
   // Since the 'offsets' are in [0, hashWidth], we obtain the 'offsets'
   // as randu(numProj, numTables) * hashWidth.
@@ -180,10 +183,6 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
         "tables provided must be equal to numProj");
   }
 
-  // We will store the second hash vectors in this matrix; the second hash
-  // vector for table i will be held in row i.
-  arma::Mat<size_t> secondHashVectors(numTables, referenceSet.n_cols);
-
   for (size_t i = 0; i < numTables; i++)
   {
     // Step IV: create the 'numProj'-dimensional key for each point in each
@@ -205,36 +204,20 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
 
     // Step V: Putting the points in the 'secondHashTable' by hashing the key.
     // Now we hash every key, point ID to its corresponding bucket.
-    secondHashVectors.row(i) = arma::conv_to<arma::Row<size_t>>::from(
-        secondHashWeights.t() * arma::floor(hashMat));
-  }
+    arma::rowvec secondHashVec = secondHashWeights.t() * arma::floor(hashMat);
 
-  // Normalize hashes (take modulus with secondHashSize).
-  secondHashVectors.transform([secondHashSize](size_t val)
-      { return val % secondHashSize; });
+    // This gives us the bucket for the corresponding point ID.
+    for (size_t j = 0; j < secondHashVec.n_elem; j++)
+      secondHashVec[j] = (double) ((size_t) secondHashVec[j] % secondHashSize);
 
-  // Now, using the hash vectors for each table, count the number of rows we
-  // have in the second hash table.
-  arma::Row<size_t> secondHashBinCounts(secondHashSize, arma::fill::zeros);
-  for (size_t i = 0; i < secondHashVectors.n_elem; ++i)
-    secondHashBinCounts[secondHashVectors[i]]++;
-
-  const size_t numRowsInTable = arma::accu(secondHashBinCounts > 0);
-  const size_t maxBucketSize = std::min(arma::max(secondHashBinCounts),
-      bucketSize);
-  secondHashTable.resize(numRowsInTable, maxBucketSize);
+    Log::Assert(secondHashVec.n_elem == referenceSet.n_cols);
 
-  // Next we must assign each point in each table to the right second hash
-  // table.
-  size_t currentRow = 0;
-  for (size_t i = 0; i < numTables; ++i)
-  {
     // Insert the point in the corresponding row to its bucket in the
     // 'secondHashTable'.
-    for (size_t j = 0; j < secondHashVectors.n_cols; j++)
+    for (size_t j = 0; j < secondHashVec.n_elem; j++)
     {
       // This is the bucket number.
-      size_t hashInd = (size_t) secondHashVectors(i, j);
+      size_t hashInd = (size_t) secondHashVec[j];
       // The point ID is 'j'.
 
       // If this is currently an empty bucket, start a new row keep track of
@@ -242,24 +225,37 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
       if (bucketContentSize[hashInd] == 0)
       {
         // Start a new row for hash.
-        bucketRowInHashTable[hashInd] = currentRow;
-        bucketContentSize[hashInd] = 1;
-        secondHashTable(currentRow, 0) = j;
-        currentRow++;
+        bucketRowInHashTable[hashInd] = numRowsInTable;
+        secondHashTable(numRowsInTable, 0) = j;
+
+        numRowsInTable++;
       }
-      else if (bucketContentSize[hashInd] < maxBucketSize)
+
+      else
       {
         // If bucket is already present in the 'secondHashTable', find the
         // corresponding row and insert the point ID in this row unless the
-        // bucket is full (in which case we are not inside this else if).
-        secondHashTable(bucketRowInHashTable[hashInd],
-                        bucketContentSize[hashInd]++) = j;
+        // bucket is full, in which case, do nothing.
+        if (bucketContentSize[hashInd] < bucketSize)
+          secondHashTable(bucketRowInHashTable[hashInd],
+                          bucketContentSize[hashInd]) = j;
       }
+
+      // Increment the count of the points in this bucket.
+      if (bucketContentSize[hashInd] < bucketSize)
+        bucketContentSize[hashInd]++;
     } // Loop over all points in the reference set.
   } // Loop over tables.
 
-  Log::Info << "Final hash table size: " << numRowsInTable << " x "
-            << maxBucketSize << "." << std::endl;
+  // Step VI: Condensing the 'secondHashTable'.
+  size_t maxBucketSize = 0;
+  for (size_t i = 0; i < bucketContentSize.n_elem; i++)
+    if (bucketContentSize[i] > maxBucketSize)
+      maxBucketSize = bucketContentSize[i];
+
+  Log::Info << "Final hash table size: (" << numRowsInTable << " x "
+            << maxBucketSize << ")" << std::endl;
+  secondHashTable.resize(numRowsInTable, maxBucketSize);
 }
 
 template<typename SortPolicy>




More information about the mlpack-git mailing list