[mlpack-git] master: Switch secondHashTable to vector<Col<size_t>>. (bd65c87)

gitdub at mlpack.org
Wed Jun 8 14:33:12 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/8d7e5db0bed8fc236407bdc5dee00d716d72a5ab...ae6c9e63b56c1ed1faa9aef9352854bbeb826a2f

>---------------------------------------------------------------

commit bd65c877e6b62bdcbc1f1fb46100587e54e74cec
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Jun 8 14:33:12 2016 -0400

    Switch secondHashTable to vector<Col<size_t>>.
    
    This should provide a significant speedup and also save RAM.


>---------------------------------------------------------------

bd65c877e6b62bdcbc1f1fb46100587e54e74cec
 src/mlpack/methods/lsh/lsh_search.hpp      |   8 +-
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 144 ++++++++++++++++++++---------
 src/mlpack/tests/serialization_test.cpp    |  12 ++-
 3 files changed, 116 insertions(+), 48 deletions(-)
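
The core of the change: the second-level hash table used to be a single
dense arma::Mat<size_t>, padded out to the longest bucket, and is now one
exactly-sized arma::Col<size_t> per non-empty bucket.  A minimal standalone
sketch of the difference, assuming only Armadillo (the bucket contents and
sentinel value below are illustrative, not mlpack's):

    #include <armadillo>
    #include <vector>

    int main()
    {
      // Suppose three non-empty buckets hold 1, 4, and 2 point indices.
      // Old layout: a dense 3 x 4 matrix padded to the longest bucket,
      // with a sentinel (here 100, standing in for the reference set
      // size) filling the unused slots.
      arma::Mat<size_t> dense(3, 4);
      dense.fill(100);
      dense(0, 0) = 9;
      dense(1, 0) = 2; dense(1, 1) = 5; dense(1, 2) = 7; dense(1, 3) = 8;
      dense(2, 0) = 1; dense(2, 1) = 3;

      // New layout: one exactly-sized column per bucket.  No padding and
      // no sentinel are needed, since each column knows its own length.
      std::vector<arma::Col<size_t>> ragged(3);
      ragged[0] = arma::Col<size_t>({9});
      ragged[1] = arma::Col<size_t>({2, 5, 7, 8});
      ragged[2] = arma::Col<size_t>({1, 3});

      return 0;
    }

When bucket occupancies are skewed, as they often are in practice,
dropping the per-row padding is where the RAM saving comes from.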

diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index a755a99..7cbe1e6 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -197,7 +197,8 @@ class LSHSearch
   size_t BucketSize() const { return bucketSize; }
 
   //! Get the second hash table.
-  const arma::Mat<size_t>& SecondHashTable() const { return secondHashTable; }
+  const std::vector<arma::Col<size_t>>& SecondHashTable() const
+      { return secondHashTable; }
 
   //! Get the projection tables.
   const arma::cube& Projections() { return projections; }
@@ -314,8 +315,9 @@ class LSHSearch
   //! The bucket size of the second hash.
   size_t bucketSize;
 
-  //! The final hash table; should be (< secondHashSize) x bucketSize.
-  arma::Mat<size_t> secondHashTable;
+  //! The final hash table; should be (< secondHashSize) vectors each with
+  //! (<= bucketSize) elements.
+  std::vector<arma::Col<size_t>> secondHashTable;
 
   //! The number of elements present in each hash bucket; should be
   //! secondHashSize.
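
Note that this changes the return type of the public SecondHashTable()
accessor, so any external callers need updating.  A hypothetical
caller-side sketch (the PrintBucketSizes helper is an illustration, not
part of mlpack):

    #include <iostream>
    #include <mlpack/methods/lsh/lsh_search.hpp>

    // Print the occupancy of each bucket of an already-trained model.
    template<typename SortPolicy>
    void PrintBucketSizes(const mlpack::neighbor::LSHSearch<SortPolicy>& lsh)
    {
      // The accessor now returns a ragged table instead of a dense matrix.
      const std::vector<arma::Col<size_t>>& table = lsh.SecondHashTable();
      for (size_t i = 0; i < table.size(); ++i)
        std::cout << "bucket " << i << " holds " << table[i].n_elem
                  << " point indices." << std::endl;
    }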
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index a141aa2..bd022c8 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -131,21 +131,6 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
   secondHashWeights = arma::floor(arma::randu(numProj) *
                                   (double) secondHashSize);
 
-  // The 'secondHashTable' is initially an empty matrix of size
-  // ('secondHashSize' x 'bucketSize'). But by only filling the buckets as
-  // points land in them allows us to shrink the size of the 'secondHashTable'
-  // at the end of the hashing.
-
-  // Fill the second hash table n = referenceSet.n_cols.  This is because no
-  // point has index 'n' so the presence of this in the bucket denotes that
-  // there are no more points in this bucket.
-  secondHashTable.set_size(secondHashSize, bucketSize);
-  secondHashTable.fill(referenceSet.n_cols);
-
-  // Keep track of the size of each bucket in the hash.  At the end of hashing
-  // most buckets will be empty.
-  bucketContentSize.zeros(secondHashSize);
-
   // Instead of putting the points in the row corresponding to the bucket, we
   // choose the next empty row and keep track of the row in which the bucket
   // lies. This allows us to stack together and slice out the empty buckets at
@@ -219,10 +204,13 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
   for (size_t i = 0; i < secondHashVectors.n_elem; ++i)
     secondHashBinCounts[secondHashVectors[i]]++;
 
+  // Enforce the maximum bucket size.
+  secondHashBinCounts.transform([bucketSize](size_t val)
+      { return std::min(val, bucketSize); });
+
   const size_t numRowsInTable = arma::accu(secondHashBinCounts > 0);
-  const size_t maxBucketSize = std::min(arma::max(secondHashBinCounts),
-      bucketSize);
-  secondHashTable.resize(numRowsInTable, maxBucketSize);
+  bucketContentSize.zeros(numRowsInTable);
+  secondHashTable.resize(numRowsInTable);
 
   // Next we must assign each point in each table to the right second hash
   // table.
@@ -239,27 +227,26 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
 
       // If this is currently an empty bucket, start a new row and keep
       // track of which row corresponds to the bucket.
-      if (bucketContentSize[hashInd] == 0)
+      const size_t maxSize = secondHashBinCounts[hashInd];
+      if (bucketRowInHashTable[hashInd] == secondHashSize)
       {
-        // Start a new row for hash.
         bucketRowInHashTable[hashInd] = currentRow;
-        bucketContentSize[hashInd] = 1;
-        secondHashTable(currentRow, 0) = j;
+        secondHashTable[currentRow].set_size(maxSize);
         currentRow++;
       }
-      else if (bucketContentSize[hashInd] < maxBucketSize)
-      {
-        // If bucket is already present in the 'secondHashTable', find the
-        // corresponding row and insert the point ID in this row unless the
-        // bucket is full (in which case we are not inside this else if).
-        secondHashTable(bucketRowInHashTable[hashInd],
-                        bucketContentSize[hashInd]++) = j;
-      }
+
+      // If this vector in the hash table is not full, add the point.
+      const size_t index = bucketRowInHashTable[hashInd];
+      if (bucketContentSize[index] < maxSize)
+        secondHashTable[index](bucketContentSize[index]++) = j;
+
     } // Loop over all points in the reference set.
   } // Loop over tables.
 
-  Log::Info << "Final hash table size: " << numRowsInTable << " x "
-            << maxBucketSize << "." << std::endl;
+  Log::Info << "Final hash table size: " << numRowsInTable << " rows, with a "
+            << "maximum length of " << arma::max(secondHashBinCounts) << ", "
+            << "totaling " << arma::accu(secondHashBinCounts) << " elements."
+            << std::endl;
 }
 
 template<typename SortPolicy>
@@ -388,17 +375,14 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
 
   for (size_t i = 0; i < hashVec.n_elem; i++) // For all tables.
   {
-    size_t hashInd = (size_t) hashVec[i];
+    const size_t hashInd = (size_t) hashVec[i];
+    const size_t tableRow = bucketRowInHashTable[hashInd];
 
-    if (bucketContentSize[hashInd] > 0)
+    if ((tableRow != secondHashSize) && (bucketContentSize[tableRow] > 0))
     {
       // Pick the indices in the bucket corresponding to 'hashInd'.
-      size_t tableRow = bucketRowInHashTable[hashInd];
-      assert(tableRow < secondHashSize);
-      assert(tableRow < secondHashTable.n_rows);
-
-      for (size_t j = 0; j < bucketContentSize[hashInd]; j++)
-        refPointsConsidered[secondHashTable(tableRow, j)]++;
+      for (size_t j = 0; j < bucketContentSize[tableRow]; j++)
+        refPointsConsidered[secondHashTable[tableRow](j)]++;
     }
   }
 
@@ -540,7 +524,7 @@ void LSHSearch<SortPolicy>::Serialize(Archive& ar,
   if (Archive::is_loading::value)
     projections.reset();
 
-  // Backward compatibility: older version of LSHSearch stored the projection
+  // Backward compatibility: older versions of LSHSearch stored the projection
   // tables in a std::vector<arma::mat>.
   if (version == 0)
   {
@@ -561,8 +545,82 @@ void LSHSearch<SortPolicy>::Serialize(Archive& ar,
   ar & CreateNVP(secondHashSize, "secondHashSize");
   ar & CreateNVP(secondHashWeights, "secondHashWeights");
   ar & CreateNVP(bucketSize, "bucketSize");
-  ar & CreateNVP(secondHashTable, "secondHashTable");
-  ar & CreateNVP(bucketContentSize, "bucketContentSize");
+  // The second hash table needs version-specific handling; see below.
+
+  // Backward compatibility: in older versions of LSHSearch, the secondHashTable
+  // was stored as an arma::Mat<size_t>.  So we need to properly load that, then
+  // prune it down to size.
+  if (version == 0)
+  {
+    arma::Mat<size_t> tmpSecondHashTable;
+    ar & CreateNVP(tmpSecondHashTable, "secondHashTable");
+
+    secondHashTable.resize(tmpSecondHashTable.n_cols);
+    for (size_t i = 0; i < tmpSecondHashTable.n_cols; ++i)
+    {
+      // Find length of each column.  We know we are at the end of the list when
+      // the value referenceSet->n_cols is seen.
+      size_t len = 0;
+      for ( ; len < tmpSecondHashTable.n_rows; ++len)
+        if (tmpSecondHashTable(len, i) == referenceSet->n_cols)
+          break;
+
+      // Set the size of the new column correctly.
+      secondHashTable[i].set_size(len);
+      for (size_t j = 0; j < len; ++j)
+        secondHashTable[i](j) = tmpSecondHashTable(j, i);
+    }
+  }
+  else
+  {
+    size_t tables;
+    if (Archive::is_saving::value)
+      tables = secondHashTable.size();
+    ar & CreateNVP(tables, "numSecondHashTables");
+
+    // Set size of second hash table if needed.
+    if (Archive::is_loading::value)
+    {
+      secondHashTable.clear();
+      secondHashTable.resize(tables);
+    }
+
+    for (size_t i = 0; i < secondHashTable.size(); ++i)
+    {
+      std::ostringstream oss;
+      oss << "secondHashTable" << i;
+      ar & CreateNVP(secondHashTable[i], oss.str());
+    }
+  }
+
+  // Backward compatibility: old versions of LSHSearch held bucketContentSize
+  // for all possible buckets (of size secondHashSize), but now we hold a
+  // compressed representation.
+  if (version == 0)
+  {
+    // The vector was stored in the old uncompressed form.  So we need to shrink
+    // it.
+    arma::Col<size_t> tmpBucketContentSize;
+    ar & CreateNVP(tmpBucketContentSize, "bucketContentSize");
+
+    // Compress into a smaller vector by just dropping all of the zeros.
+    bucketContentSize.set_size(secondHashTable.size());
+    size_t loc = 0;
+    for (size_t i = 0; i < tmpBucketContentSize.n_elem; ++i)
+    {
+      if (tmpBucketContentSize[i] > 0)
+        bucketContentSize[loc++] = tmpBucketContentSize[i];
+
+      // Terminate early, if we can.
+      if (loc == bucketContentSize.n_elem)
+        break;
+    }
+  }
+  else
+  {
+    ar & CreateNVP(bucketContentSize, "bucketContentSize");
+  }
+
   ar & CreateNVP(bucketRowInHashTable, "bucketRowInHashTable");
   ar & CreateNVP(distanceEvaluations, "distanceEvaluations");
 }
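
The version-0 branch above is easiest to read as a standalone
dense-to-ragged conversion.  Equivalent logic as a free function (the name
CompressTable and the free-standing signature are illustrative):

    #include <armadillo>
    #include <vector>

    // Convert a sentinel-padded dense table to the ragged representation.
    // `sentinel` plays the role of referenceSet->n_cols above: its first
    // occurrence in a column marks the end of that bucket's contents.
    std::vector<arma::Col<size_t>> CompressTable(
        const arma::Mat<size_t>& dense,
        const size_t sentinel)
    {
      std::vector<arma::Col<size_t>> ragged(dense.n_cols);
      for (size_t i = 0; i < dense.n_cols; ++i)
      {
        // Find the length of this bucket.
        size_t len = 0;
        for ( ; len < dense.n_rows; ++len)
          if (dense(len, i) == sentinel)
            break;

        // Copy the valid prefix into an exactly-sized column.
        ragged[i].set_size(len);
        for (size_t j = 0; j < len; ++j)
          ragged[i](j) = dense(j, i);
      }
      return ragged;
    }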
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 7b6beec..5dbb9aa 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -1225,8 +1225,16 @@ BOOST_AUTO_TEST_CASE(LSHTest)
   BOOST_REQUIRE_EQUAL(lsh.BucketSize(), textLsh.BucketSize());
   BOOST_REQUIRE_EQUAL(lsh.BucketSize(), binaryLsh.BucketSize());
 
-  CheckMatrices(lsh.SecondHashTable(), xmlLsh.SecondHashTable(),
-      textLsh.SecondHashTable(), binaryLsh.SecondHashTable());
+  BOOST_REQUIRE_EQUAL(lsh.SecondHashTable().size(),
+      xmlLsh.SecondHashTable().size());
+  BOOST_REQUIRE_EQUAL(lsh.SecondHashTable().size(),
+      textLsh.SecondHashTable().size());
+  BOOST_REQUIRE_EQUAL(lsh.SecondHashTable().size(),
+      binaryLsh.SecondHashTable().size());
+
+  for (size_t i = 0; i < lsh.SecondHashTable().size(); ++i)
+    CheckMatrices(lsh.SecondHashTable()[i], xmlLsh.SecondHashTable()[i],
+        textLsh.SecondHashTable()[i], binaryLsh.SecondHashTable()[i]);
 }
 
 // Make sure serialization works for the decision stump.
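
One detail of the Train() changes worth calling out: the bin counts are
capped at bucketSize before any column is allocated, so no bucket can grow
past the configured maximum.  A minimal sketch of that capping step, with
illustrative values:

    #include <armadillo>
    #include <algorithm>

    int main()
    {
      const size_t bucketSize = 4;
      arma::Col<size_t> counts = {1, 7, 0, 12, 3};

      // Cap every bin count at bucketSize, exactly as Train() does
      // before sizing the ragged table's columns.
      counts.transform([bucketSize](size_t val)
          { return std::min<size_t>(val, bucketSize); });

      counts.print("capped counts:"); // 1, 4, 0, 4, 3 (one per line)
      return 0;
    }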



