[mlpack-git] master: Refactor for faster assembly of secondHashTable. (940a2b5)
gitdub at mlpack.org
gitdub at mlpack.org
Sun Jun 5 01:31:50 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1f562a1aba7ae55475afcc95659511c2b7f694e5...5b8fdce471328f722fcd8c0f22a6d995ce22c98b
>---------------------------------------------------------------
commit 940a2b5766bc613f0db2e55ed4d8498cf287a62c
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Jun 3 20:25:54 2016 -0400
Refactor for faster assembly of secondHashTable.
>---------------------------------------------------------------
940a2b5766bc613f0db2e55ed4d8498cf287a62c
src/mlpack/methods/lsh/lsh_search.hpp | 2 +-
src/mlpack/methods/lsh/lsh_search_impl.hpp | 72 ++++++++++++++++--------------
2 files changed, 39 insertions(+), 35 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index b42bb7a..a755a99 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -322,7 +322,7 @@ class LSHSearch
arma::Col<size_t> bucketContentSize;
//! For a particular hash value, points to the row in secondHashTable
- //! corresponding to this value. Should be secondHashSize.
+ //! corresponding to this value. Length secondHashSize.
arma::Col<size_t> bucketRowInHashTable;
//! The number of distance evaluations.
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index 9ab2067..a141aa2 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -60,7 +60,7 @@ LSHSearch(const arma::mat& referenceSet,
// Empty constructor.
template<typename SortPolicy>
LSHSearch<SortPolicy>::LSHSearch() :
- referenceSet(new arma::mat()), // empty dataset
+ referenceSet(new arma::mat()), // Use an empty dataset.
ownsSet(true),
numProj(0),
numTables(0),
@@ -153,9 +153,6 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
bucketRowInHashTable.set_size(secondHashSize);
bucketRowInHashTable.fill(secondHashSize);
- // Keep track of number of non-empty rows in the 'secondHashTable'.
- size_t numRowsInTable = 0;
-
// Step II: The offsets for all projections in all tables.
// Since the 'offsets' are in [0, hashWidth], we obtain the 'offsets'
// as randu(numProj, numTables) * hashWidth.
@@ -183,6 +180,10 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
"tables provided must be equal to numProj");
}
+ // We will store the second hash vectors in this matrix; the second hash
+ // vector for table i will be held in row i.
+ arma::Mat<size_t> secondHashVectors(numTables, referenceSet.n_cols);
+
for (size_t i = 0; i < numTables; i++)
{
// Step IV: create the 'numProj'-dimensional key for each point in each
@@ -204,20 +205,36 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
// Step V: Putting the points in the 'secondHashTable' by hashing the key.
// Now we hash every key, point ID to its corresponding bucket.
- arma::rowvec secondHashVec = secondHashWeights.t() * arma::floor(hashMat);
+ secondHashVectors.row(i) = arma::conv_to<arma::Row<size_t>>::from(
+ secondHashWeights.t() * arma::floor(hashMat));
+ }
- // This gives us the bucket for the corresponding point ID.
- for (size_t j = 0; j < secondHashVec.n_elem; j++)
- secondHashVec[j] = (double) ((size_t) secondHashVec[j] % secondHashSize);
+ // Normalize hashes (take modulus with secondHashSize).
+ secondHashVectors.transform([secondHashSize](size_t val)
+ { return val % secondHashSize; });
- Log::Assert(secondHashVec.n_elem == referenceSet.n_cols);
+ // Now, using the hash vectors for each table, count the number of rows we
+ // have in the second hash table.
+ arma::Row<size_t> secondHashBinCounts(secondHashSize, arma::fill::zeros);
+ for (size_t i = 0; i < secondHashVectors.n_elem; ++i)
+ secondHashBinCounts[secondHashVectors[i]]++;
+
+ const size_t numRowsInTable = arma::accu(secondHashBinCounts > 0);
+ const size_t maxBucketSize = std::min(arma::max(secondHashBinCounts),
+ bucketSize);
+ secondHashTable.resize(numRowsInTable, maxBucketSize);
+ // Next we must assign each point in each table to the right second hash
+ // table.
+ size_t currentRow = 0;
+ for (size_t i = 0; i < numTables; ++i)
+ {
// Insert the point in the corresponding row to its bucket in the
// 'secondHashTable'.
- for (size_t j = 0; j < secondHashVec.n_elem; j++)
+ for (size_t j = 0; j < secondHashVectors.n_cols; j++)
{
// This is the bucket number.
- size_t hashInd = (size_t) secondHashVec[j];
+ size_t hashInd = (size_t) secondHashVectors(i, j);
// The point ID is 'j'.
// If this is currently an empty bucket, start a new row keep track of
@@ -225,37 +242,24 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
if (bucketContentSize[hashInd] == 0)
{
// Start a new row for hash.
- bucketRowInHashTable[hashInd] = numRowsInTable;
- secondHashTable(numRowsInTable, 0) = j;
-
- numRowsInTable++;
+ bucketRowInHashTable[hashInd] = currentRow;
+ bucketContentSize[hashInd] = 1;
+ secondHashTable(currentRow, 0) = j;
+ currentRow++;
}
-
- else
+ else if (bucketContentSize[hashInd] < maxBucketSize)
{
// If bucket is already present in the 'secondHashTable', find the
// corresponding row and insert the point ID in this row unless the
- // bucket is full, in which case, do nothing.
- if (bucketContentSize[hashInd] < bucketSize)
- secondHashTable(bucketRowInHashTable[hashInd],
- bucketContentSize[hashInd]) = j;
+ // bucket is full (in which case we are not inside this else if).
+ secondHashTable(bucketRowInHashTable[hashInd],
+ bucketContentSize[hashInd]++) = j;
}
-
- // Increment the count of the points in this bucket.
- if (bucketContentSize[hashInd] < bucketSize)
- bucketContentSize[hashInd]++;
} // Loop over all points in the reference set.
} // Loop over tables.
- // Step VI: Condensing the 'secondHashTable'.
- size_t maxBucketSize = 0;
- for (size_t i = 0; i < bucketContentSize.n_elem; i++)
- if (bucketContentSize[i] > maxBucketSize)
- maxBucketSize = bucketContentSize[i];
-
- Log::Info << "Final hash table size: (" << numRowsInTable << " x "
- << maxBucketSize << ")" << std::endl;
- secondHashTable.resize(numRowsInTable, maxBucketSize);
+ Log::Info << "Final hash table size: " << numRowsInTable << " x "
+ << maxBucketSize << "." << std::endl;
}
template<typename SortPolicy>
More information about the mlpack-git mailing list