[mlpack-git] master: Merge branch 'LSHTableAccess' of https://github.com/mentekid/mlpack into mentekid-LSHTableAccess (e0b6ce7)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Jun 2 09:43:25 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1f562a1aba7ae55475afcc95659511c2b7f694e5...5b8fdce471328f722fcd8c0f22a6d995ce22c98b
>---------------------------------------------------------------
commit e0b6ce7cbd9ea2c6e9411ad4063115eeedc63b8e
Merge: eba4f99 06fdfa8
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Jun 2 09:43:25 2016 -0400
Merge branch 'LSHTableAccess' of https://github.com/mentekid/mlpack into mentekid-LSHTableAccess
>---------------------------------------------------------------
e0b6ce7cbd9ea2c6e9411ad4063115eeedc63b8e
HISTORY.md | 10 +
.../core/data/serialization_template_version.hpp | 38 +++
src/mlpack/methods/lsh/lsh_search.hpp | 69 ++++-
src/mlpack/methods/lsh/lsh_search_impl.hpp | 333 ++++++++++++---------
src/mlpack/prereqs.hpp | 1 +
src/mlpack/tests/serialization_test.cpp | 4 +-
6 files changed, 297 insertions(+), 158 deletions(-)
diff --cc src/mlpack/methods/lsh/lsh_search_impl.hpp
index 02d0021,c0fc57e..c3734aa
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@@ -97,7 -120,152 +120,147 @@@ void LSHSearch<SortPolicy>::Train(cons
Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
- BuildHash();
- // Hash Building Procedure
++ // Hash building procedure:
+ // The first level hash for a single table outputs a 'numProj'-dimensional
- // integer key for each point in the set -- (key, pointID)
- // The key creation details are presented below
- //
++ // integer key for each point in the set -- (key, pointID). The key creation
++ // details are presented below.
+
+ // Step I: Prepare the second level hash.
+
+ // Obtain the weights for the second hash.
+ secondHashWeights = arma::floor(arma::randu(numProj) *
+ (double) secondHashSize);
+
+ // The 'secondHashTable' is initially an empty matrix of size
- // ('secondHashSize' x 'bucketSize'). But by only filling the buckets
- // as points land in them allows us to shrink the size of the
- // 'secondHashTable' at the end of the hashing.
++ // ('secondHashSize' x 'bucketSize'). But filling the buckets only as
++ // points land in them allows us to shrink the size of the 'secondHashTable'
++ // at the end of the hashing.
+
+ // Fill the second hash table with n = referenceSet.n_cols. This is because
+ // no point has index 'n', so the presence of this value in a bucket denotes
+ // that there are no more points in this bucket.
+ secondHashTable.set_size(secondHashSize, bucketSize);
+ secondHashTable.fill(referenceSet.n_cols);
+
+ // Keep track of the size of each bucket in the hash. At the end of hashing
+ // most buckets will be empty.
+ bucketContentSize.zeros(secondHashSize);
+
+ // Instead of putting the points in the row corresponding to the bucket, we
+ // chose the next empty row and keep track of the row in which the bucket
+ // lies. This allows us to stack together and slice out the empty buckets at
+ // the end of the hashing.
+ bucketRowInHashTable.set_size(secondHashSize);
+ bucketRowInHashTable.fill(secondHashSize);
+
+ // Keep track of number of non-empty rows in the 'secondHashTable'.
+ size_t numRowsInTable = 0;
+
+ // Step II: The offsets for all projections in all tables.
+ // Since the 'offsets' are in [0, hashWidth], we obtain the 'offsets'
+ // as randu(numProj, numTables) * hashWidth.
+ offsets.randu(numProj, numTables);
+ offsets *= hashWidth;
+
-
-
-
+ // Step III: Obtain the 'numProj' projections for each table.
+ projections.clear(); // Reset projections vector.
+
+ if (projection.n_slices == 0) //random generation of tables
+ {
- // For L2 metric, 2-stable distributions are used, and
- // the normal Z ~ N(0, 1) is a 2-stable distribution.
++ // For L2 metric, 2-stable distributions are used, and the normal Z ~ N(0,
++ // 1) is a 2-stable distribution.
+
- //numTables random tables arranged in a cube
++ // numTables random tables arranged in a cube.
+ projections.randn(
+ referenceSet.n_rows,
+ numProj,
+ numTables
+ );
+ }
+ else if (projection.n_slices == numTables) //user defined tables
+ {
+ projections = projection;
+ }
+ else //invalid argument
+ {
+ throw std::invalid_argument(
+ "number of projection tables provided must be equal to numTables"
+ );
+ }
+
-
+ for (size_t i = 0; i < numTables; i++)
+ {
+ // Step IV: create the 'numProj'-dimensional key for each point in each
+ // table.
+
+ // The following code performs the task of hashing each point to a
+ // 'numProj'-dimensional integer key. Hence you get a ('numProj' x
+ // 'referenceSet.n_cols') key matrix.
+ //
+ // For a single table, let the 'numProj' projections be denoted by 'proj_i'
+ // and the corresponding offset be 'offset_i'. Then the key of a single
+ // point is obtained as:
+ // key = { floor( (<proj_i, point> + offset_i) / 'hashWidth' ) forall i }
+ arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i), 1,
+ referenceSet.n_cols);
+ arma::mat hashMat = projections.slice(i).t() * (referenceSet);
+ hashMat += offsetMat;
+ hashMat /= hashWidth;
+
+ // Step V: Putting the points in the 'secondHashTable' by hashing the key.
+ // Now we hash every key, point ID to its corresponding bucket.
+ arma::rowvec secondHashVec = secondHashWeights.t() * arma::floor(hashMat);
+
+ // This gives us the bucket for the corresponding point ID.
+ for (size_t j = 0; j < secondHashVec.n_elem; j++)
+ secondHashVec[j] = (double)((size_t) secondHashVec[j] % secondHashSize);
+
+ Log::Assert(secondHashVec.n_elem == referenceSet.n_cols);
+
+ // Insert the point in the corresponding row to its bucket in the
+ // 'secondHashTable'.
+ for (size_t j = 0; j < secondHashVec.n_elem; j++)
+ {
+ // This is the bucket number.
+ size_t hashInd = (size_t) secondHashVec[j];
+ // The point ID is 'j'.
+
+ // If this is currently an empty bucket, start a new row and keep track
+ // of which row corresponds to the bucket.
+ if (bucketContentSize[hashInd] == 0)
+ {
+ // Start a new row for hash.
+ bucketRowInHashTable[hashInd] = numRowsInTable;
+ secondHashTable(numRowsInTable, 0) = j;
+
+ numRowsInTable++;
+ }
+
+ else
+ {
+ // If bucket is already present in the 'secondHashTable', find the
+ // corresponding row and insert the point ID in this row unless the
+ // bucket is full, in which case, do nothing.
+ if (bucketContentSize[hashInd] < bucketSize)
+ secondHashTable(bucketRowInHashTable[hashInd],
+ bucketContentSize[hashInd]) = j;
+ }
+
+ // Increment the count of the points in this bucket.
+ if (bucketContentSize[hashInd] < bucketSize)
+ bucketContentSize[hashInd]++;
+ } // Loop over all points in the reference set.
+ } // Loop over tables.
+
+ // Step VI: Condensing the 'secondHashTable'.
+ size_t maxBucketSize = 0;
+ for (size_t i = 0; i < bucketContentSize.n_elem; i++)
+ if (bucketContentSize[i] > maxBucketSize)
+ maxBucketSize = bucketContentSize[i];
+
+ Log::Info << "Final hash table size: (" << numRowsInTable << " x "
+ << maxBucketSize << ")" << std::endl;
+ secondHashTable.resize(numRowsInTable, maxBucketSize);
}
template<typename SortPolicy>
More information about the mlpack-git
mailing list