[mlpack-git] master: Revert "Refactor for faster assembly of secondHashTable." (5b8fdce)
gitdub at mlpack.org
gitdub at mlpack.org
Sun Jun 5 01:33:17 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/06cae1319ab1c22b53793c2ceff45605ad0ecc64...5b8fdce471328f722fcd8c0f22a6d995ce22c98b
>---------------------------------------------------------------
commit 5b8fdce471328f722fcd8c0f22a6d995ce22c98b
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Jun 5 01:33:17 2016 -0400
Revert "Refactor for faster assembly of secondHashTable."
This reverts commit 940a2b5766bc613f0db2e55ed4d8498cf287a62c. I did not mean to
commit that.
>---------------------------------------------------------------
5b8fdce471328f722fcd8c0f22a6d995ce22c98b
src/mlpack/methods/lsh/lsh_search.hpp | 2 +-
src/mlpack/methods/lsh/lsh_search_impl.hpp | 72 ++++++++++++++----------------
2 files changed, 35 insertions(+), 39 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index a755a99..b42bb7a 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -322,7 +322,7 @@ class LSHSearch
arma::Col<size_t> bucketContentSize;
//! For a particular hash value, points to the row in secondHashTable
- //! corresponding to this value. Length secondHashSize.
+ //! corresponding to this value. Should be secondHashSize.
arma::Col<size_t> bucketRowInHashTable;
//! The number of distance evaluations.
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index a141aa2..9ab2067 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -60,7 +60,7 @@ LSHSearch(const arma::mat& referenceSet,
// Empty constructor.
template<typename SortPolicy>
LSHSearch<SortPolicy>::LSHSearch() :
- referenceSet(new arma::mat()), // Use an empty dataset.
+ referenceSet(new arma::mat()), // empty dataset
ownsSet(true),
numProj(0),
numTables(0),
@@ -153,6 +153,9 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
bucketRowInHashTable.set_size(secondHashSize);
bucketRowInHashTable.fill(secondHashSize);
+ // Keep track of number of non-empty rows in the 'secondHashTable'.
+ size_t numRowsInTable = 0;
+
// Step II: The offsets for all projections in all tables.
// Since the 'offsets' are in [0, hashWidth], we obtain the 'offsets'
// as randu(numProj, numTables) * hashWidth.
@@ -180,10 +183,6 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
"tables provided must be equal to numProj");
}
- // We will store the second hash vectors in this matrix; the second hash
- // vector for table i will be held in row i.
- arma::Mat<size_t> secondHashVectors(numTables, referenceSet.n_cols);
-
for (size_t i = 0; i < numTables; i++)
{
// Step IV: create the 'numProj'-dimensional key for each point in each
@@ -205,36 +204,20 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
// Step V: Putting the points in the 'secondHashTable' by hashing the key.
// Now we hash every key, point ID to its corresponding bucket.
- secondHashVectors.row(i) = arma::conv_to<arma::Row<size_t>>::from(
- secondHashWeights.t() * arma::floor(hashMat));
- }
+ arma::rowvec secondHashVec = secondHashWeights.t() * arma::floor(hashMat);
- // Normalize hashes (take modulus with secondHashSize).
- secondHashVectors.transform([secondHashSize](size_t val)
- { return val % secondHashSize; });
+ // This gives us the bucket for the corresponding point ID.
+ for (size_t j = 0; j < secondHashVec.n_elem; j++)
+ secondHashVec[j] = (double) ((size_t) secondHashVec[j] % secondHashSize);
- // Now, using the hash vectors for each table, count the number of rows we
- // have in the second hash table.
- arma::Row<size_t> secondHashBinCounts(secondHashSize, arma::fill::zeros);
- for (size_t i = 0; i < secondHashVectors.n_elem; ++i)
- secondHashBinCounts[secondHashVectors[i]]++;
-
- const size_t numRowsInTable = arma::accu(secondHashBinCounts > 0);
- const size_t maxBucketSize = std::min(arma::max(secondHashBinCounts),
- bucketSize);
- secondHashTable.resize(numRowsInTable, maxBucketSize);
+ Log::Assert(secondHashVec.n_elem == referenceSet.n_cols);
- // Next we must assign each point in each table to the right second hash
- // table.
- size_t currentRow = 0;
- for (size_t i = 0; i < numTables; ++i)
- {
// Insert the point in the corresponding row to its bucket in the
// 'secondHashTable'.
- for (size_t j = 0; j < secondHashVectors.n_cols; j++)
+ for (size_t j = 0; j < secondHashVec.n_elem; j++)
{
// This is the bucket number.
- size_t hashInd = (size_t) secondHashVectors(i, j);
+ size_t hashInd = (size_t) secondHashVec[j];
// The point ID is 'j'.
// If this is currently an empty bucket, start a new row keep track of
@@ -242,24 +225,37 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
if (bucketContentSize[hashInd] == 0)
{
// Start a new row for hash.
- bucketRowInHashTable[hashInd] = currentRow;
- bucketContentSize[hashInd] = 1;
- secondHashTable(currentRow, 0) = j;
- currentRow++;
+ bucketRowInHashTable[hashInd] = numRowsInTable;
+ secondHashTable(numRowsInTable, 0) = j;
+
+ numRowsInTable++;
}
- else if (bucketContentSize[hashInd] < maxBucketSize)
+
+ else
{
// If bucket is already present in the 'secondHashTable', find the
// corresponding row and insert the point ID in this row unless the
- // bucket is full (in which case we are not inside this else if).
- secondHashTable(bucketRowInHashTable[hashInd],
- bucketContentSize[hashInd]++) = j;
+ // bucket is full, in which case, do nothing.
+ if (bucketContentSize[hashInd] < bucketSize)
+ secondHashTable(bucketRowInHashTable[hashInd],
+ bucketContentSize[hashInd]) = j;
}
+
+ // Increment the count of the points in this bucket.
+ if (bucketContentSize[hashInd] < bucketSize)
+ bucketContentSize[hashInd]++;
} // Loop over all points in the reference set.
} // Loop over tables.
- Log::Info << "Final hash table size: " << numRowsInTable << " x "
- << maxBucketSize << "." << std::endl;
+ // Step VI: Condensing the 'secondHashTable'.
+ size_t maxBucketSize = 0;
+ for (size_t i = 0; i < bucketContentSize.n_elem; i++)
+ if (bucketContentSize[i] > maxBucketSize)
+ maxBucketSize = bucketContentSize[i];
+
+ Log::Info << "Final hash table size: (" << numRowsInTable << " x "
+ << maxBucketSize << ")" << std::endl;
+ secondHashTable.resize(numRowsInTable, maxBucketSize);
}
template<typename SortPolicy>
More information about the mlpack-git mailing list