[mlpack-svn] r14136 - mlpack/trunk/src/mlpack/methods/lsh
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jan 18 17:00:49 EST 2013
Author: rcurtin
Date: 2013-01-18 17:00:49 -0500 (Fri, 18 Jan 2013)
New Revision: 14136
Modified:
mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp
mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
Log:
Clean up LSH code and add random seed parameter.
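For reproducibility, the new --seed (-s) option can be passed alongside the
existing flags; a usage sketch, based on the parameter definitions in this
commit:

$ lsh -k 5 -r input.csv -d distances.csv -n neighbors.csv -s 42

A seed of 0 (the default) falls back to std::time(NULL), preserving the old
behavior.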
Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp 2013-01-18 21:46:28 UTC (rev 14135)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp 2013-01-18 22:00:49 UTC (rev 14136)
@@ -2,8 +2,8 @@
* @file lsh_main.cpp
* @author Parikshit Ram
*
- * This file computes the approximate nearest-neighbors using 2-stable
- * Locality-sensitive Hashing.
+ * This file computes the approximate nearest-neighbors using 2-stable
+ * Locality-sensitive Hashing.
*/
#include <time.h>
@@ -22,48 +22,59 @@
// Information about the program itself.
PROGRAM_INFO("All K-Approximate-Nearest-Neighbor Search with LSH",
- "This program will calculate the k approximate-nearest-neighbors "
- "of a set of points. You may specify a separate set of reference "
- "points and query points, or just a reference set which will be "
- "used as both the reference and query set. "
+ "This program will calculate the k approximate-nearest-neighbors of a set "
+ "of points using locality-sensitive hashing. You may specify a separate set"
+ " of reference points and query points, or just a reference set which will "
+ "be used as both the reference and query set. "
"\n\n"
- "For example, the following will return 5 neighbors from the "
- "data for each point in 'input.csv' "
- "and store the distances in 'distances.csv' and the neighbors in the "
- "file 'neighbors.csv':"
+ "For example, the following will return 5 neighbors from the data for each "
+ "point in 'input.csv' and store the distances in 'distances.csv' and the "
+ "neighbors in the file 'neighbors.csv':"
"\n\n"
- "$ ./lsh/lsh -k 5 -r input.csv -d distances.csv -n neighbors.csv "
+ "$ lsh -k 5 -r input.csv -d distances.csv -n neighbors.csv "
"\n\n"
"The output files are organized such that row i and column j in the "
"neighbors output file corresponds to the index of the point in the "
"reference set which is the i'th nearest neighbor from the point in the "
"query set with index j. Row i and column j in the distances output file "
- "corresponds to the distance between those two points.");
+ "corresponds to the distance between those two points."
+ "\n\n"
+ "Because this is approximate-nearest-neighbors search, results may be "
+ "different from run to run. Thus, the --seed option can be specified to "
+ "set the random seed.");
// Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.", "r");
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
+ "r");
PARAM_STRING("distances_file", "File to output distances into.", "d", "");
PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
+
PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
+
PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
-PARAM_INT("num_projections", "The number of hash functions for each table",
- "K", 10);
-PARAM_INT("num_tables", "The number of hash tables to be used.", "L", 30);
-PARAM_DOUBLE("hash_width", "The hash width for the first-level hashing "
- "in the LSH preprocessing. By default, the LSH class "
- "automatically estimates a hash width for its use.", "H", 0.0);
-PARAM_INT("second_hash_size", "The size of the second level hash table.",
- "M", 99901);
-PARAM_INT("bucket_size", "The size of a bucket in the second level hash.",
- "B", 500);
+PARAM_INT("projections", "The number of hash functions for each table", "K",
+ 10);
+PARAM_INT("tables", "The number of hash tables to be used.", "L", 30);
+PARAM_DOUBLE("hash_width", "The hash width for the first-level hashing in the "
+ "LSH preprocessing. By default, the LSH class automatically estimates a "
+ "hash width for its use.", "H", 0.0);
+PARAM_INT("second_hash_size", "The size of the second level hash table.", "M",
+ 99901);
+PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", "B",
+ 500);
+PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
int main(int argc, char *argv[])
{
// Give CLI the command line parameters the user passed in.
CLI::ParseCommandLine(argc, argv);
- math::RandomSeed(time(NULL));
+ if (CLI::GetParam<int>("seed") != 0)
+ math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ math::RandomSeed((size_t) time(NULL));
+
// Get all the parameters.
string referenceFile = CLI::GetParam<string>("reference_file");
string distancesFile = CLI::GetParam<string>("distances_file");
@@ -89,15 +100,11 @@
Log::Fatal << referenceData.n_cols << ")." << endl;
}
+ // Pick up the LSH-specific parameters.
+ const size_t numProj = CLI::GetParam<int>("projections");
+ const size_t numTables = CLI::GetParam<int>("tables");
+ const double hashWidth = CLI::GetParam<double>("hash_width");
- // Pick up the 'K' and the 'L' parameter for LSH
- size_t numProj = CLI::GetParam<int>("num_projections");
- size_t numTables = CLI::GetParam<int>("num_tables");
-
- // Compute the 'hash_width' parameter from LSH
- double hashWidth = CLI::GetParam<double>("hash_width");
-
-
arma::Mat<size_t> neighbors;
arma::mat distances;
@@ -111,11 +118,11 @@
}
if (hashWidth == 0.0)
- Log::Info << "LSH with " << numProj << " projections(K) and " <<
- numTables << " tables(L) with default hash width." << endl;
+ Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
+ numTables << " tables (L) with default hash width." << endl;
else
- Log::Info << "LSH with " << numProj << " projections(K) and " <<
- numTables << " tables(L) with hash width(r): " << hashWidth << endl;
+ Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
+ numTables << " tables (L) with hash width (r): " << hashWidth << endl;
Timer::Start("hash_building");
@@ -129,21 +136,19 @@
secondHashSize, bucketSize);
Timer::Stop("hash_building");
-
- Log::Info << "Computing " << k << " distance approx. nearest neighbors " <<
- endl;
+
+ Log::Info << "Computing " << k << " distance approximate nearest neighbors "
+ << endl;
allkann->Search(k, neighbors, distances);
Log::Info << "Neighbors computed." << endl;
// Save output.
- if (distancesFile != "")
+ if (distancesFile != "")
data::Save(distancesFile, distances);
if (neighborsFile != "")
data::Save(neighborsFile, neighbors);
delete allkann;
-
- return 0;
}
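As a worked illustration of the output layout documented in PROGRAM_INFO
above (hypothetical values, not part of this commit): with -k 3, the
neighbors file has 3 rows and one column per query point. If row 0, column 4
of neighbors.csv holds 17 and the same cell of distances.csv holds 0.25,
then reference point 17 is the approximate nearest neighbor of query point 4
at distance 0.25; row 1 holds each query's second-nearest neighbor, and so
on.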
Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp 2013-01-18 21:46:28 UTC (rev 14135)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp 2013-01-18 22:00:49 UTC (rev 14136)
@@ -34,8 +34,8 @@
namespace neighbor {
/**
- * The LSHSearch class -- This class builds a hash on the reference set
- * and uses this hash to compute the distance-approximate nearest-neighbors
+ * The LSHSearch class -- This class builds a hash on the reference set
+ * and uses this hash to compute the distance-approximate nearest-neighbors
* of the given queries.
*
* @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
@@ -47,21 +47,20 @@
{
public:
/**
- * This function initializes the LSH class. It builds the hash on the
- * reference set with 2-stable distributions. See the individual functions
+ * This function initializes the LSH class. It builds the hash on the
+ * reference set with 2-stable distributions. See the individual functions
* performing the hashing for details on how the hashing is done.
*
* @param referenceSet Set of reference points.
* @param querySet Set of query points.
* @param numProj Number of projections in each hash table (anything between
* 10-50 might be a decent choice).
- * @param numTables Total number of hash tables (anything between 10-20
+ * @param numTables Total number of hash tables (anything between 10-20
* should suffice).
- * @param hashWidth The width of hash for every table. If the user does not
- * provide a value then the class automatically obtains a hash width
- * by computing the average pairwise distance of 25 pairs. This should
- * be a reasonable upper bound on the nearest-neighbor distance
- * in general.
+ * @param hashWidth The width of hash for every table. If 0 (the default) is
+ * provided, then the hash width is automatically obtained by computing
+ * the average pairwise distance of 25 pairs. This should be a reasonable
+ * upper bound on the nearest-neighbor distance in general.
* @param secondHashSize The size of the second hash table. This should be a
* large prime number.
* @param bucketSize The size of the bucket in the second hash table. This is
@@ -79,20 +78,19 @@
const MetricType metric = MetricType());
/**
- * This function initializes the LSH class. It builds the hash on the
- * reference set with 2-stable distributions. See the individual functions
+ * This function initializes the LSH class. It builds the hash on the
+ * reference set with 2-stable distributions. See the individual functions
* performing the hashing for details on how the hashing is done.
*
* @param referenceSet Set of reference points and the set of queries.
* @param numProj Number of projections in each hash table (anything between
* 10-50 might be a decent choice).
- * @param numTables Total number of hash tables (anything between 10-20
+ * @param numTables Total number of hash tables (anything between 10-20
* should suffice).
- * @param hashWidth The width of hash for every table. If the user does not
- * provide a value then the class automatically obtains a hash width
- * by computing the average pairwise distance of 25 pairs. This should
- * be a reasonable upper bound on the nearest-neighbor distance
- * in general.
+ * @param hashWidth The width of hash for every table. If 0 (the default) is
+ * provided, then the hash width is automatically obtained by computing
+ * the average pairwise distance of 25 pairs. This should be a reasonable
+ * upper bound on the nearest-neighbor distance in general.
* @param secondHashSize The size of the second hash table. This should be a
* large prime number.
* @param bucketSize The size of the bucket in the second hash table. This is
@@ -107,15 +105,10 @@
const size_t secondHashSize = 99901,
const size_t bucketSize = 500,
const MetricType metric = MetricType());
- /**
- * Delete the LSHSearch object. The tree is the only member we are
- * responsible for deleting. The others will take care of themselves.
- */
- ~LSHSearch();
/**
* Compute the nearest neighbors and store the output in the given matrices.
- * The matrices will be set to the size of n columns by k rows, where n is
+ * The matrices will be set to the size of n columns by k rows, where n is
* the number of points in the query dataset and k is the number of neighbors
* being searched for.
*
@@ -125,16 +118,16 @@
* @param distances Matrix storing distances of neighbors for each query
* point.
* @param numTablesToSearch This parameter allows the user to have control
- * over the number of hash tables to be searched. This allows
- * the user to pick the number of tables it can afford for the time
+ * over the number of hash tables to be searched. This allows
+ * the user to pick the number of tables they can afford for the time
available without having to build the hash for every table size.
- * By default, this is set to zero in which case all tables are
+ * By default, this is set to zero in which case all tables are
* considered.
*/
void Search(const size_t k,
arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances,
- size_t numTablesToSearch = 0);
+ const size_t numTablesToSearch = 0);
private:
/**
@@ -166,6 +159,7 @@
void ReturnIndicesFromTable(const size_t queryIndex,
arma::uvec& referenceIndices,
size_t numTablesToSearch);
+
/**
* This is a helper function that computes the distance of the query to the
* neighbor candidates and appropriately stores the best 'k' candidates
@@ -224,23 +218,22 @@
//! Instantiation of the metric.
MetricType metric;
- //! The final hash table
- arma::Mat<size_t> secondHashTable; // should be (< secondHashSize) x bucketSize
+ //! The final hash table; should be (< secondHashSize) x bucketSize.
+ arma::Mat<size_t> secondHashTable;
- //! The number of elements present in each hash bucket
- arma::Col<size_t> bucketContentSize; // should be secondHashSize
+ //! The number of elements present in each hash bucket; should be
+ //! secondHashSize.
+ arma::Col<size_t> bucketContentSize;
//! For a particular hash value, points to the row in secondHashTable
- //! corresponding to this value
- arma::Col<size_t> bucketRowInHashTable; // should be secondHashSize
+ //! corresponding to this value. Should be secondHashSize.
+ arma::Col<size_t> bucketRowInHashTable;
- //! The pointer to the nearest neighbor distance
+ //! The pointer to the nearest neighbor distances.
arma::mat* distancePtr;
- //! The pointer to the nearest neighbor indices
+ //! The pointer to the nearest neighbor indices.
arma::Mat<size_t>* neighborPtr;
-
-
}; // class LSHSearch
}; // namespace neighbor
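For context, a minimal sketch of how this class can be used directly,
mirroring lsh_main.cpp; the template arguments shown (NearestNeighborSort
and mlpack's EuclideanDistance) are assumptions suggested by the @tparam
documentation above, not spelled out in this diff:

#include <mlpack/core.hpp>
#include <mlpack/methods/lsh/lsh_search.hpp>

using namespace mlpack;
using namespace mlpack::neighbor;

int main()
{
  // One point per column, as mlpack stores datasets.
  arma::mat referenceData;
  data::Load("input.csv", referenceData, true);

  // 10 projections per table (K) and 30 tables (L); hash width, second
  // hash size, and bucket size keep the defaults declared above.
  LSHSearch<NearestNeighborSort, metric::EuclideanDistance>
      lsh(referenceData, 10, 30);

  // Find 5 approximate nearest neighbors for every reference point.
  arma::Mat<size_t> neighbors;
  arma::mat distances;
  lsh.Search(5, neighbors, distances);

  return 0;
}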
Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp 2013-01-18 21:46:28 UTC (rev 14135)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp 2013-01-18 22:00:49 UTC (rev 14136)
@@ -33,20 +33,20 @@
bucketSize(bucketSize),
metric(metric)
{
- if (hashWidth == 0.0) // the user has not provided any value
+ if (hashWidth == 0.0) // The user has not provided any value.
{
+ // Compute a heuristic hash width from the data.
for (size_t i = 0; i < 25; i++)
{
size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
- hashWidth
- += metric::EuclideanDistance::Evaluate(referenceSet.unsafe_col(p1),
- referenceSet.unsafe_col(p2));
+ hashWidth += MetricType::Evaluate(referenceSet.unsafe_col(p1),
+ referenceSet.unsafe_col(p2));
}
hashWidth /= 25;
- } // computing a heuristic hashWidth from the data
+ }
BuildHash();
}
@@ -69,32 +69,26 @@
bucketSize(bucketSize),
metric(metric)
{
- if (hashWidth == 0.0) // the user has not provided any value
+ if (hashWidth == 0.0) // The user has not provided any value.
{
+ // Compute a heuristic hash width from the data.
for (size_t i = 0; i < 25; i++)
{
size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
- hashWidth
- += metric::EuclideanDistance::Evaluate(referenceSet.unsafe_col(p1),
- referenceSet.unsafe_col(p2));
+ hashWidth += MetricType::Evaluate(referenceSet.unsafe_col(p1),
+ referenceSet.unsafe_col(p2));
}
hashWidth /= 25;
- } // computing a heuristic hashWidth from the data
+ }
BuildHash();
}
template<typename SortPolicy, typename MetricType>
-LSHSearch<SortPolicy, MetricType>::
-~LSHSearch()
-{ }
-
-
-template<typename SortPolicy, typename MetricType>
void LSHSearch<SortPolicy, MetricType>::
InsertNeighbor(const size_t queryIndex,
const size_t pos,
@@ -152,50 +146,51 @@
arma::uvec& referenceIndices,
size_t numTablesToSearch)
{
- // deciding on the number of tables to look into.
- if (numTablesToSearch == 0) // if no user input, search all
+ // Decide on the number of tables to look into.
+ if (numTablesToSearch == 0) // If no user input is given, search all.
numTablesToSearch = numTables;
-
- // sanity check to make sure that the existing number of tables is not
+ // Sanity check to make sure that the existing number of tables is not
// exceeded.
if (numTablesToSearch > numTables)
numTablesToSearch = numTables;
// Hash the query in each of the 'numTablesToSearch' hash tables using the
- // 'numProj' projections for each table.
- // This gives us 'numTablesToSearch' keys for the query where each key
- // is a 'numProj' dimensional integer vector
- //
- // compute the projection of the query in each table
+ // 'numProj' projections for each table. This gives us 'numTablesToSearch'
+ // keys for the query where each key is a 'numProj' dimensional integer
+ // vector.
+
+ // Compute the projection of the query in each table.
arma::mat allProjInTables(numProj, numTablesToSearch);
for (size_t i = 0; i < numTablesToSearch; i++)
- allProjInTables.unsafe_col(i)
- = projections[i].t() * querySet.unsafe_col(queryIndex);
+ {
+ allProjInTables.unsafe_col(i) = projections[i].t() *
+ querySet.unsafe_col(queryIndex);
+ }
allProjInTables += offsets.cols(0, numTablesToSearch - 1);
allProjInTables /= hashWidth;
- // compute the hash value of each key of the query into a bucket of the
+ // Compute the hash value of each key of the query into a bucket of the
// 'secondHashTable' using the 'secondHashWeights'.
arma::rowvec hashVec = secondHashWeights.t() * arma::floor(allProjInTables);
for (size_t i = 0; i < hashVec.n_elem; i++)
- hashVec[i] = (double)((size_t) hashVec[i] % secondHashSize);
+ hashVec[i] = (double) ((size_t) hashVec[i] % secondHashSize);
- assert(hashVec.n_elem == numTablesToSearch);
+ Log::Assert(hashVec.n_elem == numTablesToSearch);
// For all the buckets that the query is hashed into, sequentially
// collect the indices in those buckets.
arma::Col<size_t> refPointsConsidered;
refPointsConsidered.zeros(referenceSet.n_cols);
- for (size_t i = 0; i < hashVec.n_elem; i++)
+ for (size_t i = 0; i < hashVec.n_elem; i++) // For all tables.
{
size_t hashInd = (size_t) hashVec[i];
if (bucketContentSize[hashInd] > 0)
{
- // Pick the indices in the bucket corresponding to 'hashInd'
+ // Pick the indices in the bucket corresponding to 'hashInd'.
size_t tableRow = bucketRowInHashTable[hashInd];
assert(tableRow < secondHashSize);
assert(tableRow < secondHashTable.n_rows);
@@ -203,10 +198,9 @@
for (size_t j = 0; j < bucketContentSize[hashInd]; j++)
refPointsConsidered[secondHashTable(tableRow, j)]++;
}
- } // for all tables
+ }
referenceIndices = arma::find(refPointsConsidered > 0);
- return;
}
@@ -215,7 +209,7 @@
Search(const size_t k,
arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances,
- size_t numTablesToSearch)
+ const size_t numTablesToSearch)
{
neighborPtr = &resultingNeighbors;
distancePtr = &distances;
@@ -230,20 +224,20 @@
Timer::Start("computing_neighbors");
- // go through every query point sequentially
+ // Go through every query point sequentially.
for (size_t i = 0; i < querySet.n_cols; i++)
{
- // For hash every query into every hash tables and eventually
- // into the 'secondHashTable' to obtain the neighbor candidates
+ // Hash every query into every hash table and eventually into the
+ // 'secondHashTable' to obtain the neighbor candidates.
arma::uvec refIndices;
ReturnIndicesFromTable(i, refIndices, numTablesToSearch);
- // Just an informative book-keeping for the number of neighbor candidates
- // returned on average
+ // Bookkeeping to report the average number of neighbor candidates
+ // returned.
avgIndicesReturned += refIndices.n_elem;
// Sequentially go through all the candidates and save the best 'k'
- // candidates
+ // candidates.
for (size_t j = 0; j < refIndices.n_elem; j++)
BaseCase(i, (size_t) refIndices[j]);
}
@@ -252,9 +246,7 @@
avgIndicesReturned /= querySet.n_cols;
Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
- std::endl;
-
- return;
+ std::endl;
}
template<typename SortPolicy, typename MetricType>
@@ -276,148 +268,121 @@
// given by <key, 'secondHashWeights'> % 'secondHashSize'
// and the corresponding point ID is put into that bucket.
- //////////////////////////////////////////
- // Step I: Preparing the second level hash
- ///////////////////////////////////////////
+ // Step I: Prepare the second level hash.
- // obtain the weights for the second hash
- secondHashWeights = arma::floor(arma::randu(numProj)
- * (double) secondHashSize);
+ // Obtain the weights for the second hash.
+ secondHashWeights = arma::floor(arma::randu(numProj) *
+ (double) secondHashSize);
// The 'secondHashTable' is initially an empty matrix of size
// ('secondHashSize' x 'bucketSize'). But filling the buckets only as
// points land in them allows us to shrink the size of the
// 'secondHashTable' at the end of the hashing.
- // Start filling up the second hash table
+ // Fill the second hash table with the value n = referenceSet.n_cols. This
+ // is because no point has index 'n', so the presence of this value in a
+ // bucket denotes that there are no more points in that bucket.
secondHashTable.set_size(secondHashSize, bucketSize);
-
- // Fill the second hash table n = referenceSet.n_cols
- // This is because no point has index 'n' so the presence of
- // this in the bucket denotes that there are no more points
- // in this bucket.
secondHashTable.fill(referenceSet.n_cols);
- // Keeping track of the size of each bucket in the hash.
- // At the end of hashing most buckets will be empty.
+ // Keep track of the size of each bucket in the hash. At the end of hashing
+ // most buckets will be empty.
bucketContentSize.zeros(secondHashSize);
- // Instead of putting the points in the row corresponding to
- // the bucket, we chose the next empty row and keep track of
- // the row in which the bucket lies. This allows us to
- // stack together and slice out the empty buckets at the
- // end of the hashing.
+ // Instead of putting the points in the row corresponding to the bucket, we
+ // choose the next empty row and keep track of the row in which the bucket
+ // lies. This allows us to stack together and slice out the empty buckets at
+ // the end of the hashing.
bucketRowInHashTable.set_size(secondHashSize);
bucketRowInHashTable.fill(secondHashSize);
- // keeping track of number of non-empty rows in the 'secondHashTable'
+ // Keep track of the number of non-empty rows in the 'secondHashTable'.
size_t numRowsInTable = 0;
-
- /////////////////////////////////////////////////////////
- // Step II: The offsets for all projections in all tables
- /////////////////////////////////////////////////////////
-
+ // Step II: The offsets for all projections in all tables.
// Since the 'offsets' are in [0, hashWidth], we obtain the 'offsets'
- // as randu(numProj, numTables) * hashWidth
+ // as randu(numProj, numTables) * hashWidth.
offsets.randu(numProj, numTables);
offsets *= hashWidth;
- /////////////////////////////////////////////////////////////////
- // Step III: Creating each hash table in the first level hash
- // one by one and putting them directly into the 'secondHashTable'
- // for memory efficiency.
- /////////////////////////////////////////////////////////////////
-
- for(size_t i = 0; i < numTables; i++)
+ // Step III: Create each hash table in the first level hash one by one and
+ // put them directly into the 'secondHashTable' for memory efficiency.
+ for (size_t i = 0; i < numTables; i++)
{
- //////////////////////////////////////////////////////////////
- // Step IV: Obtaining the 'numProj' projections for each table
- //////////////////////////////////////////////////////////////
- //
+ // Step IV: Obtain the 'numProj' projections for each table.
+
// For the L2 metric, 2-stable distributions are used, and
// the normal Z ~ N(0, 1) is a 2-stable distribution.
arma::mat projMat;
projMat.randn(referenceSet.n_rows, numProj);
- // save the projection matrix for querying
+ // Save the projection matrix for querying.
projections.push_back(projMat);
- ///////////////////////////////////////////////////////////////
- // Step V: create the 'numProj'-dimensional key for each point
- // in each table.
- //////////////////////////////////////////////////////////////
+ // Step V: Create the 'numProj'-dimensional key for each point in each
+ // table.
- // The following set of lines performs the task of
- // hashing each point to a 'numProj'-dimensional integer key.
- // Hence you get a ('numProj' x 'referenceSet.n_cols') key matrix
+ // The following code performs the task of hashing each point to a
+ // 'numProj'-dimensional integer key. Hence you get a ('numProj' x
+ // 'referenceSet.n_cols') key matrix.
//
- // For a single table, let the 'numProj' projections be denoted
- // by 'proj_i' and the corresponding offset be 'offset_i'.
- // Then the key of a single point is obtained as:
+ // For a single table, let the 'numProj' projections be denoted by 'proj_i'
+ // and the corresponding offset be 'offset_i'. Then the key of a single
+ // point is obtained as:
// key = { floor( (<proj_i, point> + offset_i) / 'hashWidth' ) forall i }
- arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i),
- 1, referenceSet.n_cols);
+ arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i), 1,
+ referenceSet.n_cols);
arma::mat hashMat = projMat.t() * referenceSet;
hashMat += offsetMat;
hashMat /= hashWidth;
- ////////////////////////////////////////////////////////////
- // Step VI: Putting the points in the 'secondHashTable' by
- // hashing the key.
- ///////////////////////////////////////////////////////////
-
- // Now we hash every key, point ID to its corresponding bucket
+ // Step VI: Put the points in the 'secondHashTable' by hashing the key.
+ // Now we hash every key, point ID to its corresponding bucket.
arma::rowvec secondHashVec = secondHashWeights.t()
* arma::floor(hashMat);
- // This gives us the bucket for the corresponding point ID
+ // This gives us the bucket for the corresponding point ID.
for (size_t j = 0; j < secondHashVec.n_elem; j++)
secondHashVec[j] = (double)((size_t) secondHashVec[j] % secondHashSize);
- assert(secondHashVec.n_elem == referenceSet.n_cols);
+ Log::Assert(secondHashVec.n_elem == referenceSet.n_cols);
- // Inserting the point in the corresponding row to its bucket
- // in the 'secondHashTable'.
+ // Insert the point into the row corresponding to its bucket in the
+ // 'secondHashTable'.
for (size_t j = 0; j < secondHashVec.n_elem; j++)
{
- // This is the bucket number
+ // This is the bucket number.
size_t hashInd = (size_t) secondHashVec[j];
- // The point ID is 'j'
+ // The point ID is 'j'.
- // If this is currently an empty bucket, start a new row
- // keep track of which row corresponds to the bucket.
+ // If this is currently an empty bucket, start a new row and keep track of
+ // which row corresponds to the bucket.
if (bucketContentSize[hashInd] == 0)
{
- // start a new row for hash
+ // Start a new row for hash.
bucketRowInHashTable[hashInd] = numRowsInTable;
secondHashTable(numRowsInTable, 0) = j;
numRowsInTable++;
}
- // If bucket already present in the 'secondHashTable', find
- // the corresponding row and insert the point ID in this row
- // unless the bucket is full, in which case, do nothing.
+
else
{
- // if bucket not full, insert point here
+ // If the bucket is already present in the 'secondHashTable', find the
+ // corresponding row and insert the point ID in this row unless the
+ // bucket is full, in which case, do nothing.
if (bucketContentSize[hashInd] < bucketSize)
secondHashTable(bucketRowInHashTable[hashInd],
bucketContentSize[hashInd]) = j;
- // else just ignore as suggested
}
- // increment the count of the points in this bucket
+ // Increment the count of the points in this bucket.
if (bucketContentSize[hashInd] < bucketSize)
bucketContentSize[hashInd]++;
- } // loop over all points in the reference set
- } // loop over tables
+ } // Loop over all points in the reference set.
+ } // Loop over tables.
-
- /////////////////////////////////////////////////
- // Step VII: Condensing the 'secondHashTable'
- /////////////////////////////////////////////////
-
+ // Step VII: Condense the 'secondHashTable'.
size_t maxBucketSize = 0;
for (size_t i = 0; i < bucketContentSize.n_elem; i++)
if (bucketContentSize[i] > maxBucketSize)
@@ -426,9 +391,6 @@
Log::Info << "Final hash table size: (" << numRowsInTable << " x "
<< maxBucketSize << ")" << std::endl;
secondHashTable.resize(numRowsInTable, maxBucketSize);
-
- return;
}
-
#endif
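Finally, a self-contained numeric sketch (assuming only Armadillo; the
identifiers are illustrative, not the class's own) of the three
computations this file implements: the heuristic hash width averaged over
25 random pairs, the first-level keys of Step V, and the second-level
bucket hash of Step VI:

#include <armadillo>
#include <cstdlib>

int main()
{
  // d x n dataset; one point per column, as in LSHSearch.
  arma::mat points(10, 1000, arma::fill::randu);

  // Heuristic hash width: average Euclidean distance of 25 random pairs.
  double hashWidth = 0.0;
  for (size_t i = 0; i < 25; ++i)
  {
    const size_t p1 = std::rand() % points.n_cols;
    const size_t p2 = std::rand() % points.n_cols;
    hashWidth += arma::norm(points.col(p1) - points.col(p2), 2);
  }
  hashWidth /= 25;

  // Step V: key = { floor((<proj_i, point> + offset_i) / hashWidth) forall i }.
  const size_t numProj = 10;
  arma::mat projMat(points.n_rows, numProj, arma::fill::randn); // N(0, 1) is 2-stable.
  arma::vec offsets = arma::randu(numProj) * hashWidth; // Offsets in [0, hashWidth].
  arma::mat keys = arma::floor(
      (projMat.t() * points + arma::repmat(offsets, 1, points.n_cols)) /
      hashWidth);

  // Step VI: bucket = <key, secondHashWeights> % secondHashSize.
  const size_t secondHashSize = 99901;
  arma::vec secondHashWeights =
      arma::floor(arma::randu(numProj) * (double) secondHashSize);
  arma::rowvec buckets = secondHashWeights.t() * keys;
  for (size_t j = 0; j < buckets.n_elem; ++j)
    buckets[j] = (double) ((size_t) buckets[j] % secondHashSize);

  return 0;
}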