[mlpack-svn] r14033 - mlpack/trunk/src/mlpack/methods/lsh
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Dec 20 19:14:38 EST 2012
Author: rcurtin
Date: 2012-12-20 19:14:38 -0500 (Thu, 20 Dec 2012)
New Revision: 14033
Modified:
mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
Log:
Minor formatting improvements and warning fix.
Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp 2012-12-20 23:45:12 UTC (rev 14032)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp 2012-12-21 00:14:38 UTC (rev 14033)
@@ -3,8 +3,8 @@
* @author Parikshit Ram
*
* Defines the LSHSearch class, which performs an approximate
- * nearest neighbor search for a queries in a query set
- * over a given dataset using Locality-sensitive hashing
+ * nearest neighbor search for a queries in a query set
+ * over a given dataset using Locality-sensitive hashing
* with 2-stable distributions.
*
* The details of this method can be found in the following paper:
@@ -12,14 +12,14 @@
* @inproceedings{datar2004locality,
* title={Locality-sensitive hashing scheme based on p-stable distributions},
* author={Datar, M. and Immorlica, N. and Indyk, P. and Mirrokni, V.S.},
- * booktitle={Proceedings of the 12th Annual Symposium on Computational Geometry},
- * pages={253--262},A
+ * booktitle=
+ * {Proceedings of the 12th Annual Symposium on Computational Geometry},
+ * pages={253--262},
* year={2004},
* organization={ACM}
* }
*
*/
-
#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
#define __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
@@ -31,39 +31,35 @@
#include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp>
namespace mlpack {
-namespace neighbor /** Neighbor-search routines. These include
- * all-nearest-neighbors and all-furthest-neighbors
- * searches. */ {
+namespace neighbor {
/**
* The LSHSearch class -- TBD
- *
+ *
* @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
* @tparam MetricType The metric to use for computation.
*/
template<typename SortPolicy = NearestNeighborSort,
typename MetricType = mlpack::metric::SquaredEuclideanDistance>
-
class LSHSearch
{
public:
-
/**
* Intialize -- TBD
*
* @param referenceSet Set of reference points.
* @param querySet Set of query points.
- * @param numProj Number of projections in each hash table (anything between
+ * @param numProj Number of projections in each hash table (anything between
* 10-50 might be a decent choice).
- * @param numTables Total number of hash tables (anything between 10-20 should
+ * @param numTables Total number of hash tables (anything between 10-20 should
* should suffice).
- * @param hashWidth The width of hash for every table (currently automatically
- * chosen from the main function). This should be a reasonable upper bound
+ * @param hashWidth The width of hash for every table (currently automatically
+ * chosen from the main function). This should be a reasonable upper bound
* on the nearest-neighbor distance in general.
- * @param secondHashSize The size of the second hash table. This should be a
+ * @param secondHashSize The size of the second hash table. This should be a
* large prime number.
- * @param bucketSize The size of the bucket in the second hash table. This is
- * the maximum number of points that can be hashed into single bucket.
+ * @param bucketSize The size of the bucket in the second hash table. This is
+ * the maximum number of points that can be hashed into single bucket.
* Default values are already provided here.
* @param metric An optional instance of the MetricType class.
*/
@@ -80,17 +76,17 @@
* Intialize -- TBD
*
* @param referenceSet Set of reference points and the set of queries.
- * @param numProj Number of projections in each hash table (anything between
+ * @param numProj Number of projections in each hash table (anything between
* 10-50 might be a decent choice).
- * @param numTables Total number of hash tables (anything between 10-20 should
+ * @param numTables Total number of hash tables (anything between 10-20 should
* should suffice).
- * @param hashWidth The width of hash for every table (currently automatically
- * chosen from the main function). This should be a reasonable upper bound
+ * @param hashWidth The width of hash for every table (currently automatically
+ * chosen from the main function). This should be a reasonable upper bound
* on the nearest-neighbor distance in general.
- * @param secondHashSize The size of the second hash table. This should be a
+ * @param secondHashSize The size of the second hash table. This should be a
* large prime number.
- * @param bucketSize The size of the bucket in the second hash table. This is
- * the maximum number of points that can be hashed into single bucket.
+ * @param bucketSize The size of the bucket in the second hash table. This is
+ * the maximum number of points that can be hashed into single bucket.
* Default values are already provided here.
* @param metric An optional instance of the MetricType class.
*/
@@ -120,43 +116,40 @@
* point.
*/
void Search(const size_t k,
- arma::Mat<size_t>& resultingNeighbors,
+ arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances);
private:
-
/**
- * This function builds a hash table with two levels of hashing
- * as presented in the paper. This function first hashes the points
- * with 'numProj' random projections to a single hash table creating
- * (key, point ID) pairs where the key is a 'numProj'-dimensional
- * integer vector.
- *
- * Then each key in this hash table is hashed into a second hash table
- * using a standard hash.
+ * This function builds a hash table with two levels of hashing as presented
+ * in the paper. This function first hashes the points with 'numProj' random
+ * projections to a single hash table creating (key, point ID) pairs where the
+ * key is a 'numProj'-dimensional integer vector.
*
- * This function does not have any parameters and relies on parameters
- * which are private members of this class, intialized during the
- * class intialization.
+ * Then each key in this hash table is hashed into a second hash table using a
+ * standard hash.
+ *
+ * This function does not have any parameters and relies on parameters which
+ * are private members of this class, intialized during the class
+ * intialization.
*/
void BuildHash();
-
/**
- * This function takes a query and hashes it into each of the hash tables
- * to get keys for the query and then the key is hashed to a bucket of the
- * second hash table and all the points (if any) in those buckets
- * are collected as the potential neighbor candidates.
+ * This function takes a query and hashes it into each of the hash tables to
+ * get keys for the query and then the key is hashed to a bucket of the second
+ * hash table and all the points (if any) in those buckets are collected as
+ * the potential neighbor candidates.
*
* @param queryIndex The index of the query currently being processed.
- * @param referenceIndices The list of neighbor candidates obtained from
- * hashing the query into all the hash tables and eventually into
+ * @param referenceIndices The list of neighbor candidates obtained from
+ * hashing the query into all the hash tables and eventually into
* multiple buckets of the second hash table.
*/
void ReturnIndicesFromTable(const size_t queryIndex,
arma::uvec& referenceIndices);
/**
- * This is a helper function that computes the distance of the query to the
+ * This is a helper function that computes the distance of the query to the
* neighbor candidates and appropriately stores the best 'k' candidates
*
* @param queryIndex The index of the query in question
@@ -165,16 +158,16 @@
double BaseCase(const size_t queryIndex, const size_t referenceIndex);
/**
- * This is a helper function that efficiently inserts better neighbor
- * candidates into an existing set of neighbor candidates. This function
- * is only called by the 'BaseCase' function.
+ * This is a helper function that efficiently inserts better neighbor
+ * candidates into an existing set of neighbor candidates. This function is
+ * only called by the 'BaseCase' function.
*
* @param queryIndex This is the index of the query being processed currently
- * @param pos The position of the neighbor candidate in the current list of
+ * @param pos The position of the neighbor candidate in the current list of
* neighbor candidates.
- * @param neighbor The neighbor candidate that is being inserted into the list
+ * @param neighbor The neighbor candidate that is being inserted into the list
* of the best 'k' candidates for the query in question.
- * @param distance The distance of the query to the neighbor candidate.
+ * @param distance The distance of the query to the neighbor candidate.
*/
void InsertNeighbor(const size_t queryIndex, const size_t pos,
const size_t neighbor, const double distance);
@@ -186,9 +179,6 @@
//! Query dataset (may not be given).
const arma::mat& querySet;
- //! Instantiation of the metric.
- MetricType metric;
-
//! The number of projections
const size_t numProj;
@@ -213,6 +203,9 @@
//! The bucket size of the second hash
const size_t bucketSize;
+ //! Instantiation of the metric.
+ MetricType metric;
+
//! The final hash table
arma::Mat<size_t> secondHashTable; // should be (< secondHashSize) x bucketSize
Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp 2012-12-20 23:45:12 UTC (rev 14032)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp 2012-12-21 00:14:38 UTC (rev 14033)
@@ -44,7 +44,7 @@
const double hashWidth,
const size_t secondHashSize,
const size_t bucketSize,
- const MetricType metric) :
+ const MetricType metric) :
referenceSet(referenceSet),
querySet(referenceSet),
numProj(numProj),
@@ -91,7 +91,7 @@
template<typename SortPolicy, typename MetricType>
-inline //force_inline
+inline //force_inline
double LSHSearch<SortPolicy, MetricType>::
BaseCase(const size_t queryIndex, const size_t referenceIndex)
{
@@ -121,7 +121,7 @@
ReturnIndicesFromTable(const size_t queryIndex,
arma::uvec& referenceIndices)
{
- // Hash the query in each of the 'numTables' hash tables using the
+ // Hash the query in each of the 'numTables' hash tables using the
// 'numProj' projections for each table.
// This gives us 'numTables' keys for the query where each key
// is a 'numProj' dimensional integer vector
@@ -129,13 +129,13 @@
// compute the projection of the query in each table
arma::mat allProjInTables(numProj, numTables);
for (size_t i = 0; i < numTables; i++)
- allProjInTables.unsafe_col(i)
+ allProjInTables.unsafe_col(i)
= projections[i].t() * querySet.unsafe_col(queryIndex);
allProjInTables += offsets;
allProjInTables /= hashWidth;
- // compute the hash value of each key of the query into a bucket of the
- // 'secondHashTable' using the 'secondHashWeights'.
+ // compute the hash value of each key of the query into a bucket of the
+ // 'secondHashTable' using the 'secondHashWeights'.
arma::rowvec hashVec = secondHashWeights.t() * arma::floor(allProjInTables);
for (size_t i = 0; i < hashVec.n_elem; i++)
@@ -143,8 +143,8 @@
assert(hashVec.n_elem == numTables);
- // For all the buckets that the query is hashed into, sequentially
- // collect the indices in those buckets.
+ // For all the buckets that the query is hashed into, sequentially
+ // collect the indices in those buckets.
arma::Col<size_t> refPointsConsidered;
refPointsConsidered.zeros(referenceSet.n_cols);
@@ -191,7 +191,7 @@
// go through every query point sequentially
for (size_t i = 0; i < querySet.n_cols; i++)
{
- // For hash every query into every hash tables and eventually
+ // For hash every query into every hash tables and eventually
// into the 'secondHashTable' to obtain the neighbor candidates
arma::uvec refIndices;
ReturnIndicesFromTable(i, refIndices);
@@ -222,48 +222,48 @@
// The first level hash for a single table outputs a 'numProj'-dimensional
// integer key for each point in the set -- (key, pointID)
// The key creation details are presented below
- //
- // The second level hash is performed by hashing the key to
+ //
+ // The second level hash is performed by hashing the key to
// an integer in the range [0, 'secondHashSize').
//
- // This is done by creating a weight vector 'secondHashWeights' of
+ // This is done by creating a weight vector 'secondHashWeights' of
// length 'numProj' with each entry an integer randomly chosen
- // between [0, 'secondHashSize').
+ // between [0, 'secondHashSize').
//
- // Then the bucket for any key and its corresponding point is
+ // Then the bucket for any key and its corresponding point is
// given by <key, 'secondHashWeights'> % 'secondHashSize'
- // and the corresponding point ID is put into that bucket.
+ // and the corresponding point ID is put into that bucket.
//////////////////////////////////////////
// Step I: Preparing the second level hash
///////////////////////////////////////////
// obtain the weights for the second hash
- secondHashWeights = arma::floor(arma::randu(numProj)
+ secondHashWeights = arma::floor(arma::randu(numProj)
* (double) secondHashSize);
- // The 'secondHashTable' is initially an empty matrix of size
- // ('secondHashSize' x 'bucketSize'). But by only filling the buckets
- // as points land in them allows us to shrink the size of the
+ // The 'secondHashTable' is initially an empty matrix of size
+ // ('secondHashSize' x 'bucketSize'). But by only filling the buckets
+ // as points land in them allows us to shrink the size of the
// 'secondHashTable' at the end of the hashing.
// Start filling up the second hash table
secondHashTable.set_size(secondHashSize, bucketSize);
// Fill the second hash table n = referenceSet.n_cols
- // This is because no point has index 'n' so the presence of
- // this in the bucket denotes that there are no more points
- // in this bucket.
+ // This is because no point has index 'n' so the presence of
+ // this in the bucket denotes that there are no more points
+ // in this bucket.
secondHashTable.fill(referenceSet.n_cols);
// Keeping track of the size of each bucket in the hash.
// At the end of hashing most buckets will be empty.
bucketContentSize.zeros(secondHashSize);
- // Instead of putting the points in the row corresponding to
- // the bucket, we chose the next empty row and keep track of
- // the row in which the bucket lies. This allows us to
- // stack together and slice out the empty buckets at the
+ // Instead of putting the points in the row corresponding to
+ // the bucket, we chose the next empty row and keep track of
+ // the row in which the bucket lies. This allows us to
+ // stack together and slice out the empty buckets at the
// end of the hashing.
bucketRowInHashTable.set_size(secondHashSize);
bucketRowInHashTable.fill(secondHashSize);
@@ -293,7 +293,7 @@
// Step IV: Obtaining the 'numProj' projections for each table
//////////////////////////////////////////////////////////////
//
- // For L2 metric, 2-stable distributions are used, and
+ // For L2 metric, 2-stable distributions are used, and
// the normal Z ~ N(0, 1) is a 2-stable distribution.
arma::mat projMat;
projMat.randn(referenceSet.n_rows, numProj);
@@ -302,16 +302,16 @@
projections.push_back(projMat);
///////////////////////////////////////////////////////////////
- // Step V: create the 'numProj'-dimensional key for each point
+ // Step V: create the 'numProj'-dimensional key for each point
// in each table.
//////////////////////////////////////////////////////////////
- // The following set of lines performs the task of
- // hashing each point to a 'numProj'-dimensional integer key.
+ // The following set of lines performs the task of
+ // hashing each point to a 'numProj'-dimensional integer key.
// Hence you get a ('numProj' x 'referenceSet.n_cols') key matrix
//
- // For a single table, let the 'numProj' projections be denoted
- // by 'proj_i' and the corresponding offset be 'offset_i'.
+ // For a single table, let the 'numProj' projections be denoted
+ // by 'proj_i' and the corresponding offset be 'offset_i'.
// Then the key of a single point is obtained as:
// key = { floor( (<proj_i, point> + offset_i) / 'hashWidth' ) forall i }
arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i),
@@ -321,7 +321,7 @@
hashMat /= hashWidth;
////////////////////////////////////////////////////////////
- // Step VI: Putting the points in the 'secondHashTable' by
+ // Step VI: Putting the points in the 'secondHashTable' by
// hashing the key.
///////////////////////////////////////////////////////////
@@ -335,7 +335,7 @@
assert(secondHashVec.n_elem == referenceSet.n_cols);
- // Inserting the point in the corresponding row to its bucket
+ // Inserting the point in the corresponding row to its bucket
// in the 'secondHashTable'.
for (size_t j = 0; j < secondHashVec.n_elem; j++)
{
@@ -343,7 +343,7 @@
size_t hashInd = (size_t) secondHashVec[j];
// The point ID is 'j'
- // If this is currently an empty bucket, start a new row
+ // If this is currently an empty bucket, start a new row
// keep track of which row corresponds to the bucket.
if (bucketContentSize[hashInd] == 0)
{
@@ -353,14 +353,14 @@
numRowsInTable++;
}
- // If bucket already present in the 'secondHashTable', find
- // the corresponding row and insert the point ID in this row
+ // If bucket already present in the 'secondHashTable', find
+ // the corresponding row and insert the point ID in this row
// unless the bucket is full, in which case, do nothing.
- else
+ else
{
- // if bucket not full, insert point here
+ // if bucket not full, insert point here
if (bucketContentSize[hashInd] < bucketSize)
- secondHashTable(bucketRowInHashTable[hashInd],
+ secondHashTable(bucketRowInHashTable[hashInd],
bucketContentSize[hashInd]) = j;
// else just ignore as suggested
}
@@ -376,12 +376,12 @@
// Step VII: Condensing the 'secondHashTable'
/////////////////////////////////////////////////
- size_t maxBucketSize = 0;
+ size_t maxBucketSize = 0;
for (size_t i = 0; i < bucketContentSize.n_elem; i++)
if (bucketContentSize[i] > maxBucketSize)
maxBucketSize = bucketContentSize[i];
- Log::Info << "Final hash table size: (" << numRowsInTable << " x "
+ Log::Info << "Final hash table size: (" << numRowsInTable << " x "
<< maxBucketSize << ")" << std::endl;
secondHashTable.resize(numRowsInTable, maxBucketSize);
More information about the mlpack-svn mailing list