[mlpack-git] master: Switch secondHashTable to vector<Col<size_t>>. (bd65c87)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Jun 8 14:33:12 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/8d7e5db0bed8fc236407bdc5dee00d716d72a5ab...ae6c9e63b56c1ed1faa9aef9352854bbeb826a2f
>---------------------------------------------------------------
commit bd65c877e6b62bdcbc1f1fb46100587e54e74cec
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Jun 8 14:33:12 2016 -0400
Switch secondHashTable to vector<Col<size_t>>.
This should provide a good amount of speedup, and also save RAM.
>---------------------------------------------------------------
bd65c877e6b62bdcbc1f1fb46100587e54e74cec
src/mlpack/methods/lsh/lsh_search.hpp | 8 +-
src/mlpack/methods/lsh/lsh_search_impl.hpp | 144 ++++++++++++++++++++---------
src/mlpack/tests/serialization_test.cpp | 12 ++-
3 files changed, 116 insertions(+), 48 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index a755a99..7cbe1e6 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -197,7 +197,8 @@ class LSHSearch
size_t BucketSize() const { return bucketSize; }
//! Get the second hash table.
- const arma::Mat<size_t>& SecondHashTable() const { return secondHashTable; }
+ const std::vector<arma::Col<size_t>>& SecondHashTable() const
+ { return secondHashTable; }
//! Get the projection tables.
const arma::cube& Projections() { return projections; }
@@ -314,8 +315,9 @@ class LSHSearch
//! The bucket size of the second hash.
size_t bucketSize;
- //! The final hash table; should be (< secondHashSize) x bucketSize.
- arma::Mat<size_t> secondHashTable;
+ //! The final hash table; should be (< secondHashSize) vectors each with
+ //! (<= bucketSize) elements.
+ std::vector<arma::Col<size_t>> secondHashTable;
//! The number of elements present in each hash bucket; should be
//! secondHashSize.
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index a141aa2..bd022c8 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -131,21 +131,6 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
secondHashWeights = arma::floor(arma::randu(numProj) *
(double) secondHashSize);
- // The 'secondHashTable' is initially an empty matrix of size
- // ('secondHashSize' x 'bucketSize'). But by only filling the buckets as
- // points land in them allows us to shrink the size of the 'secondHashTable'
- // at the end of the hashing.
-
- // Fill the second hash table n = referenceSet.n_cols. This is because no
- // point has index 'n' so the presence of this in the bucket denotes that
- // there are no more points in this bucket.
- secondHashTable.set_size(secondHashSize, bucketSize);
- secondHashTable.fill(referenceSet.n_cols);
-
- // Keep track of the size of each bucket in the hash. At the end of hashing
- // most buckets will be empty.
- bucketContentSize.zeros(secondHashSize);
-
// Instead of putting the points in the row corresponding to the bucket, we
// chose the next empty row and keep track of the row in which the bucket
// lies. This allows us to stack together and slice out the empty buckets at
@@ -219,10 +204,13 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
for (size_t i = 0; i < secondHashVectors.n_elem; ++i)
secondHashBinCounts[secondHashVectors[i]]++;
+ // Enforce the maximum bucket size.
+ secondHashBinCounts.transform([bucketSize](size_t val)
+ { return std::min(val, bucketSize); });
+
const size_t numRowsInTable = arma::accu(secondHashBinCounts > 0);
- const size_t maxBucketSize = std::min(arma::max(secondHashBinCounts),
- bucketSize);
- secondHashTable.resize(numRowsInTable, maxBucketSize);
+ bucketContentSize.zeros(numRowsInTable);
+ secondHashTable.resize(numRowsInTable);
// Next we must assign each point in each table to the right second hash
// table.
@@ -239,27 +227,26 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
// If this is currently an empty bucket, start a new row keep track of
// which row corresponds to the bucket.
- if (bucketContentSize[hashInd] == 0)
+ const size_t maxSize = secondHashBinCounts[hashInd];
+ if (bucketRowInHashTable[hashInd] == secondHashSize)
{
- // Start a new row for hash.
bucketRowInHashTable[hashInd] = currentRow;
- bucketContentSize[hashInd] = 1;
- secondHashTable(currentRow, 0) = j;
+ secondHashTable[currentRow].set_size(maxSize);
currentRow++;
}
- else if (bucketContentSize[hashInd] < maxBucketSize)
- {
- // If bucket is already present in the 'secondHashTable', find the
- // corresponding row and insert the point ID in this row unless the
- // bucket is full (in which case we are not inside this else if).
- secondHashTable(bucketRowInHashTable[hashInd],
- bucketContentSize[hashInd]++) = j;
- }
+
+ // If this vector in the hash table is not full, add the point.
+ const size_t index = bucketRowInHashTable[hashInd];
+ if (bucketContentSize[index] < maxSize)
+ secondHashTable[index](bucketContentSize[index]++) = j;
+
} // Loop over all points in the reference set.
} // Loop over tables.
- Log::Info << "Final hash table size: " << numRowsInTable << " x "
- << maxBucketSize << "." << std::endl;
+ Log::Info << "Final hash table size: " << numRowsInTable << " rows, with a "
+ << "maximum length of " << arma::max(secondHashBinCounts) << ", "
+ << "totaling " << arma::accu(secondHashBinCounts) << " elements."
+ << std::endl;
}
template<typename SortPolicy>
@@ -388,17 +375,14 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
for (size_t i = 0; i < hashVec.n_elem; i++) // For all tables.
{
- size_t hashInd = (size_t) hashVec[i];
+ const size_t hashInd = (size_t) hashVec[i];
+ const size_t tableRow = bucketRowInHashTable[hashInd];
- if (bucketContentSize[hashInd] > 0)
+ if ((tableRow != secondHashSize) && (bucketContentSize[tableRow] > 0))
{
// Pick the indices in the bucket corresponding to 'hashInd'.
- size_t tableRow = bucketRowInHashTable[hashInd];
- assert(tableRow < secondHashSize);
- assert(tableRow < secondHashTable.n_rows);
-
- for (size_t j = 0; j < bucketContentSize[hashInd]; j++)
- refPointsConsidered[secondHashTable(tableRow, j)]++;
+ for (size_t j = 0; j < bucketContentSize[tableRow]; j++)
+ refPointsConsidered[secondHashTable[tableRow](j)]++;
}
}
@@ -540,7 +524,7 @@ void LSHSearch<SortPolicy>::Serialize(Archive& ar,
if (Archive::is_loading::value)
projections.reset();
- // Backward compatibility: older version of LSHSearch stored the projection
+ // Backward compatibility: older versions of LSHSearch stored the projection
// tables in a std::vector<arma::mat>.
if (version == 0)
{
@@ -561,8 +545,82 @@ void LSHSearch<SortPolicy>::Serialize(Archive& ar,
ar & CreateNVP(secondHashSize, "secondHashSize");
ar & CreateNVP(secondHashWeights, "secondHashWeights");
ar & CreateNVP(bucketSize, "bucketSize");
- ar & CreateNVP(secondHashTable, "secondHashTable");
- ar & CreateNVP(bucketContentSize, "bucketContentSize");
+ // needs specific handling for new version
+
+ // Backward compatibility: in older versions of LSHSearch, the secondHashTable
+ // was stored as an arma::Mat<size_t>. So we need to properly load that, then
+ // prune it down to size.
+ if (version == 0)
+ {
+ arma::Mat<size_t> tmpSecondHashTable;
+ ar & CreateNVP(tmpSecondHashTable, "secondHashTable");
+
+ secondHashTable.resize(tmpSecondHashTable.n_cols);
+ for (size_t i = 0; i < tmpSecondHashTable.n_cols; ++i)
+ {
+ // Find length of each column. We know we are at the end of the list when
+ // the value referenceSet->n_cols is seen.
+ size_t len = 0;
+ for ( ; len < tmpSecondHashTable.n_rows; ++len)
+ if (tmpSecondHashTable(len, i) == referenceSet->n_cols)
+ break;
+
+ // Set the size of the new column correctly.
+ secondHashTable[i].set_size(len);
+ for (size_t j = 0; j < len; ++j)
+ secondHashTable[i](j) = tmpSecondHashTable(j, i);
+ }
+ }
+ else
+ {
+ size_t tables;
+ if (Archive::is_saving::value)
+ tables = secondHashTable.size();
+ ar & CreateNVP(tables, "numSecondHashTables");
+
+ // Set size of second hash table if needed.
+ if (Archive::is_loading::value)
+ {
+ secondHashTable.clear();
+ secondHashTable.resize(tables);
+ }
+
+ for (size_t i = 0; i < secondHashTable.size(); ++i)
+ {
+ std::ostringstream oss;
+ oss << "secondHashTable" << i;
+ ar & CreateNVP(secondHashTable[i], oss.str());
+ }
+ }
+
+ // Backward compatibility: old versions of LSHSearch held bucketContentSize
+ // for all possible buckets (of size secondHashSize), but now we hold a
+ // compressed representation.
+ if (version == 0)
+ {
+ // The vector was stored in the old uncompressed form. So we need to shrink
+ // it.
+ arma::Col<size_t> tmpBucketContentSize;
+ ar & CreateNVP(tmpBucketContentSize, "bucketContentSize");
+
+ // Compress into a smaller vector by just dropping all of the zeros.
+ bucketContentSize.set_size(secondHashTable.size());
+ size_t loc = 0;
+ for (size_t i = 0; i < tmpBucketContentSize.n_elem; ++i)
+ {
+ if (tmpBucketContentSize[i] > 0)
+ bucketContentSize[loc++] = tmpBucketContentSize[i];
+
+ // Terminate early, if we can.
+ if (loc == bucketContentSize.n_elem)
+ break;
+ }
+ }
+ else
+ {
+ ar & CreateNVP(bucketContentSize, "bucketContentSize");
+ }
+
ar & CreateNVP(bucketRowInHashTable, "bucketRowInHashTable");
ar & CreateNVP(distanceEvaluations, "distanceEvaluations");
}
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 7b6beec..5dbb9aa 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -1225,8 +1225,16 @@ BOOST_AUTO_TEST_CASE(LSHTest)
BOOST_REQUIRE_EQUAL(lsh.BucketSize(), textLsh.BucketSize());
BOOST_REQUIRE_EQUAL(lsh.BucketSize(), binaryLsh.BucketSize());
- CheckMatrices(lsh.SecondHashTable(), xmlLsh.SecondHashTable(),
- textLsh.SecondHashTable(), binaryLsh.SecondHashTable());
+ BOOST_REQUIRE_EQUAL(lsh.SecondHashTable().size(),
+ xmlLsh.SecondHashTable().size());
+ BOOST_REQUIRE_EQUAL(lsh.SecondHashTable().size(),
+ textLsh.SecondHashTable().size());
+ BOOST_REQUIRE_EQUAL(lsh.SecondHashTable().size(),
+ binaryLsh.SecondHashTable().size());
+
+ for (size_t i = 0; i < lsh.SecondHashTable().size(); ++i)
+ CheckMatrices(lsh.SecondHashTable()[i], xmlLsh.SecondHashTable()[i],
+ textLsh.SecondHashTable()[i], binaryLsh.SecondHashTable()[i]);
}
// Make sure serialization works for the decision stump.
More information about the mlpack-git mailing list