[mlpack-svn] r14032 - mlpack/trunk/src/mlpack/methods/lsh
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Dec 20 18:45:12 EST 2012
Author: rcurtin
Date: 2012-12-20 18:45:12 -0500 (Thu, 20 Dec 2012)
New Revision: 14032
Added:
mlpack/trunk/src/mlpack/methods/lsh/lsh_analysis_main.cpp
mlpack/trunk/src/mlpack/methods/lsh/lsh_test.cpp
Modified:
mlpack/trunk/src/mlpack/methods/lsh/CMakeLists.txt
mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp
mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
Log:
Update to newest version (I mistakenly branched an old version).
Modified: mlpack/trunk/src/mlpack/methods/lsh/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/CMakeLists.txt 2012-12-20 23:29:46 UTC (rev 14031)
+++ mlpack/trunk/src/mlpack/methods/lsh/CMakeLists.txt 2012-12-20 23:45:12 UTC (rev 14032)
@@ -19,12 +19,34 @@
set(MLPACK_CONTRIB_SRCS ${MLPACK_CONTRIB_SRCS} ${DIR_SRCS} PARENT_SCOPE)
-# The code to compute the rank-approximate neighbor
+# The code to compute the approximate neighbor
# for the given query and reference sets
-add_executable(allkdann
+# with p-stable LSH
+add_executable(lsh
lsh_main.cpp
)
-target_link_libraries(allkdann
+target_link_libraries(lsh
${MLPACK_LIBRARY}
+)
+
+# The code to compute the approximate neighbor
+# for the given query and reference sets
+# with p-stable LSH and analyze the time-accuracy
+# tradeoff
+add_executable(lsh_analysis
+ lsh_analysis_main.cpp
+)
+target_link_libraries(lsh_analysis
+ ${MLPACK_LIBRARY}
contrib_pram
)
+
+
+# Testing the LSHSearch class
+add_executable(lsh_test
+ lsh_test.cpp
+)
+target_link_libraries(lsh_test
+ ${MLPACK_LIBRARY}
+)
+
Added: mlpack/trunk/src/mlpack/methods/lsh/lsh_analysis_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_analysis_main.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_analysis_main.cpp 2012-12-20 23:45:12 UTC (rev 14032)
@@ -0,0 +1,367 @@
+/**
+ * @file lsh_analysis_main.cpp
+ * @author Parikshit Ram
+ *
+ * This main file computes the accuracy-time tradeoff of
+ * the 2-stable LSH class 'LSHSearch'
+ */
+#include <time.h>
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include <string>
+#include <fstream>
+#include <iostream>
+
+#include "lsh_search.hpp"
+#include "../utils/utils.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::neighbor;
+
+// Information about the program itself.
+PROGRAM_INFO("LSH accuracy-time tradeoff analysis.",
+ "This program will calculate the k approximate-nearest-neighbors "
+ "of a set of queries under different parameter settings."
+ " You may specify a separate set of reference "
+ "points and query points, or just a reference set which will be "
+ "used as both the reference and query set. "
+ "\n\n"
+ "For a given set of queries and references and 'k', this program "
+ "uses different parameters for LSH and computes the error and "
+ "the query time. For the computation of error, it requires the "
+ "files containing the ranks and the NN distances for each query. "
+ "\n\n"
+ "Sample usage: \n"
+ "./lsh/lsh_analysis -r reference.csv -q queries.csv -k 5 "
+ " -P -W -E query_ranks_file.csv -D query_nn_dist_file.csv "
+ " -F error_v_time_output.csv ");
+
+// Define our input parameters that this program will take.
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
+ "r");
+PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
+PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
+
+PARAM_STRING_REQ("rank_file", "The file containing the true ranks.", "E");
+PARAM_STRING_REQ("nn_dist_file", "The file containing the true distance"
+ " errors.", "D");
+PARAM_INT("num_projections", "The number of hash functions for each table",
+ "K", 10);
+PARAM_INT("num_tables", "The number of hash tables to be used.", "L", 30);
+PARAM_INT("second_hash_size", "The size of the second level hash table.",
+ "M", 8807);
+PARAM_INT("bucket_size", "The size of a bucket in the second level hash.",
+ "B", 500);
+
+PARAM_FLAG("try_diff_params", "The flag to trigger the search with "
+ "different 'K', 'L' and 'r'.", "P");
+PARAM_FLAG("try_diff_widths", "The flag to trigger the search with "
+ "different hash widths.", "W");
+
+PARAM_STRING("error_time_file", "File to output the error v. time report to.",
+ "F", "");
+
+int main(int argc, char *argv[])
+{
+ // Give CLI the command line parameters the user passed in.
+ CLI::ParseCommandLine(argc, argv);
+ math::RandomSeed(time(NULL));
+
+ // Get all the parameters.
+ string referenceFile = CLI::GetParam<string>("reference_file");
+ string distancesFile = CLI::GetParam<string>("distances_file");
+ string neighborsFile = CLI::GetParam<string>("neighbors_file");
+
+ size_t k = CLI::GetParam<int>("k");
+ size_t secondHashSize = CLI::GetParam<int>("second_hash_size");
+ size_t bucketSize = CLI::GetParam<int>("bucket_size");
+
+ bool tryDiffParams = CLI::HasParam("try_diff_params");
+ bool tryDiffWidths = CLI::HasParam("try_diff_widths");
+
+ arma::mat referenceData;
+ arma::mat queryData; // So it doesn't go out of scope.
+ data::Load(referenceFile.c_str(), referenceData, true);
+
+ Log::Info << "Loaded reference data from '" << referenceFile << "' (" <<
+ referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
+
+ // Sanity check on k value: it must not exceed the number of reference
+ // points. NOTE(review): k == 0 is not rejected by the check below.
+ if (k > referenceData.n_cols)
+ {
+ Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
+ Log::Fatal << "than or equal to the number of reference points (";
+ Log::Fatal << referenceData.n_cols << ")." << endl;
+ }
+
+
+ // Pick up the 'K' and the 'L' parameter for LSH
+ arma::Col<size_t> numProjs, numTables;
+ if (tryDiffParams)
+ {
+ numProjs.set_size(4);
+ numProjs << 10 << 25 << 40 << 55;
+ numTables.set_size(4);
+ numTables << 5 << 10 << 15 << 20;
+ }
+ else
+ {
+ numProjs.set_size(1);
+ numProjs[0] = CLI::GetParam<int>("num_projections");
+ numTables.set_size(1);
+ numTables[0] = CLI::GetParam<int>("num_tables");
+ }
+
+ // Compute the 'width' parameter from LSH
+
+ // Find the average pairwise distance of 25 random pairs
+ double avgDist = 0;
+ for (size_t i = 0; i < 25; i++)
+ {
+ size_t p1 = (size_t) math::RandInt(referenceData.n_cols);
+ size_t p2 = (size_t) math::RandInt(referenceData.n_cols);
+ avgDist += metric::EuclideanDistance::Evaluate(referenceData.unsafe_col(p1),
+ referenceData.unsafe_col(p2));
+ }
+
+ avgDist /= 25;
+
+ Log::Info << "Hash width chosen as: " << avgDist << endl;
+
+ arma::vec hashWidths;
+ if (tryDiffWidths)
+ {
+ arma::vec eps(3);
+ eps << 0.01 << 0.1 << 1.0;
+ hashWidths = avgDist * eps;
+ }
+ else
+ {
+ hashWidths.set_size(1);
+ hashWidths[0] = avgDist;
+ }
+
+ arma::vec timesTaken(numProjs.n_elem * numTables.n_elem * hashWidths.n_elem);
+ timesTaken.zeros();
+
+ arma::Mat<size_t> allNeighbors;
+ arma::mat allNeighborDistances;
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ string queryFile = CLI::GetParam<string>("query_file");
+
+ data::Load(queryFile.c_str(), queryData, true);
+ Log::Info << "Loaded query data from '" << queryFile << "' (" <<
+ queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+ allNeighbors.set_size(k * timesTaken.n_elem, queryData.n_cols);
+ allNeighborDistances.set_size(k * timesTaken.n_elem, queryData.n_cols);
+ }
+ else
+ {
+ allNeighbors.set_size(k * timesTaken.n_elem, referenceData.n_cols);
+ allNeighborDistances.set_size(k * timesTaken.n_elem, referenceData.n_cols);
+ }
+
+ size_t exptInd = 0;
+ arma::mat exptParams(timesTaken.n_elem, 3);
+
+ // Looping through all combinations of the parameters and noting the
+ // runtime and the results.
+ for (size_t widthInd = 0; widthInd < hashWidths.n_elem; widthInd++)
+ {
+ for (size_t projInd = 0; projInd < numProjs.n_elem; projInd++)
+ {
+ for (size_t tableInd = 0; tableInd < numTables.n_elem; tableInd++)
+ {
+ Log::Info << "LSH with K: " << numProjs[projInd] << ", L: " <<
+ numTables[tableInd] << ", r: " << hashWidths[widthInd] << endl;
+
+ Timer::Start("hash_building");
+
+ LSHSearch<>* allkann;
+
+ if (CLI::GetParam<string>("query_file") != "")
+ allkann = new LSHSearch<>(referenceData, queryData, numProjs[projInd],
+ numTables[tableInd], hashWidths[widthInd],
+ secondHashSize, bucketSize);
+ else
+ allkann = new LSHSearch<>(referenceData, numProjs[projInd],
+ numTables[tableInd], hashWidths[widthInd],
+ secondHashSize, bucketSize);
+
+ Timer::Stop("hash_building");
+
+ timeval start_tv = Timer::Get("computing_neighbors");
+ double startTime
+ = (double) start_tv.tv_sec + (double) start_tv.tv_usec / 1.0e6;
+
+ Log::Info << "Computing " << k << " approx. nearest neighbors" << endl;
+ allkann->Search(k, neighbors, distances);
+ Log::Info << "Neighbors computed." << endl;
+
+ timeval stop_tv = Timer::Get("computing_neighbors");
+ double stopTime
+ = (double) stop_tv.tv_sec + (double) stop_tv.tv_usec / 1.0e6;
+
+ exptParams(exptInd, 0) = (double) numProjs[projInd];
+ exptParams(exptInd, 1) = (double) numTables[tableInd];
+ exptParams(exptInd, 2) = hashWidths[widthInd];
+ timesTaken[exptInd] = stopTime - startTime;
+
+ // add results to big matrix
+ allNeighbors.rows(exptInd * k, (exptInd + 1) * k - 1) = neighbors;
+ allNeighborDistances.rows(exptInd * k, (exptInd + 1) * k - 1)
+ = distances;
+
+ exptInd++;
+
+ neighbors.reset();
+ distances.reset();
+
+ delete allkann;
+ } // diff. L
+ } // diff. K
+ } // diff. 'width'
+
+
+ // Computing the errors
+ string rankFile = CLI::GetParam<string>("rank_file");
+ Log::Warn << "Computing error..." << endl;
+
+ contrib_utils::LineReader lr(rankFile);
+
+ arma::mat allANNErrors(timesTaken.n_elem, 8);
+ // 0 - K
+ // 1 - L
+ // 2 - width
+ // 3 - Time taken
+ // 4 - Mean Rank/Recall
+ // 5 - Median Rank/Recall
+ // 6 - StdDev Rank/Recall
+ // 7 - MaxRank / MinRecall
+
+
+ // If k == 1, compute the rank and distance errors
+ if (k == 1)
+ {
+ string distFile = CLI::GetParam<string>("nn_dist_file");
+ arma::mat nnDistsMat;
+
+ if (!data::Load(distFile, nnDistsMat))
+ Log::Fatal << "Dist file " << distFile << " cannot be loaded." << endl;
+
+ arma::vec nnDists(nnDistsMat.row(0).t());
+
+ allANNErrors.resize(timesTaken.n_elem, 12);
+ // 8 - Mean DE
+ // 9 - Median DE
+ // 10 - StdDev DE
+ // 11 - Max DE
+
+ arma::mat ranks(timesTaken.n_elem, allNeighbors.n_cols);
+ arma::mat des(timesTaken.n_elem, allNeighbors.n_cols);
+
+ for (size_t i = 0; i < allNeighbors.n_cols; i++)
+ {
+ arma::Col<size_t> true_ranks(referenceData.n_cols);
+ lr.ReadLine(&true_ranks);
+
+ for (size_t j = 0; j < timesTaken.n_elem; j++)
+ {
+ if (allNeighbors(j, i) < referenceData.n_cols)
+ {
+ ranks(j, i) = (double) true_ranks[allNeighbors(j, i)];
+ des(j, i) = (allNeighborDistances(j, i) - nnDists[i]) / nnDists[i];
+ }
+ // not sure what to do in terms of distance error
+ // in case no result is returned
+ else
+ {
+ ranks(j, i) = (double) referenceData.n_cols;
+ des(j, i) = 1000;
+ }
+ }
+ }
+
+ // Saving the LSH parameters and the query times
+ allANNErrors.cols(0, 2) = exptParams;
+ allANNErrors.col(3) = timesTaken;
+
+ // Saving the mean distance error and rank over all queries.
+ allANNErrors.col(4) = arma::mean(ranks, 1);
+ allANNErrors.col(5) = arma::median(ranks, 1);
+ allANNErrors.col(6) = arma::stddev(ranks, 1, 1);
+ allANNErrors.col(7) = arma::max(ranks, 1);
+
+ allANNErrors.col(8) = arma::mean(des, 1);
+ allANNErrors.col(9) = arma::median(des, 1);
+ allANNErrors.col(10) = arma::stddev(des, 1, 1);
+ allANNErrors.col(11) = arma::max(des, 1);
+
+ } // if k == 1
+ // if k > 1, compute the recall of the k-nearest-neighbors.
+ else
+ {
+ arma::mat recalls(timesTaken.n_elem, allNeighbors.n_cols);
+ recalls.zeros();
+
+ for (size_t i = 0; i < allNeighbors.n_cols; i++)
+ {
+ arma::Col<size_t> true_ranks(referenceData.n_cols);
+ lr.ReadLine(&true_ranks);
+
+ for (size_t j = 0; j < timesTaken.n_elem; j++)
+ {
+ for (size_t ind = 0; ind < k; ind++)
+ {
+ if (allNeighbors(j * k + ind, i) < referenceData.n_cols)
+ {
+ if (true_ranks[allNeighbors(j * k + ind, i)] <= k)
+ recalls(j, i)++;
+ }
+ }
+ }
+ }
+
+ recalls /= k;
+
+ // Saving the LSH parameters and the query times
+ allANNErrors.cols(0, 2) = exptParams;
+ allANNErrors.col(3) = timesTaken;
+
+ // Saving the mean recall of the k-nearest-neighbor over all queries.
+ allANNErrors.col(4) = arma::mean(recalls, 1);
+ allANNErrors.col(5) = arma::median(recalls, 1);
+ allANNErrors.col(6) = arma::stddev(recalls, 1, 1);
+ allANNErrors.col(7) = arma::min(recalls, 1);
+
+ } // if k > 1, compute recall of k-NN
+
+ Log::Warn << allANNErrors;
+
+
+ // Saving the output in a file
+ string annErrorOutputFile = CLI::GetParam<string>("error_time_file");
+
+ if (annErrorOutputFile != "")
+ {
+ allANNErrors = allANNErrors.t();
+ data::Save(annErrorOutputFile, allANNErrors);
+ }
+ else
+ {
+ Log::Warn << "Params: " << endl << exptParams.t() << "Times Taken: " <<
+ endl << timesTaken.t();
+ }
+
+ return 0;
+}
Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp 2012-12-20 23:29:46 UTC (rev 14031)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp 2012-12-20 23:45:12 UTC (rev 14032)
@@ -2,8 +2,8 @@
* @file lsh_main.cpp
* @author Parikshit Ram
*
- * Implementation of LSH with a 2-stable distribution of
- * nearest neighbor search in Euclidean space.
+ * This file computes the approximate nearest-neighbors using 2-stable
+ * Locality-sensitive Hashing.
*/
#include <time.h>
@@ -15,15 +15,14 @@
#include <iostream>
#include "lsh_search.hpp"
-#include "../utils/utils.hpp"
using namespace std;
using namespace mlpack;
using namespace mlpack::neighbor;
// Information about the program itself.
-PROGRAM_INFO("All K-distance-Approximate-Nearest-Neighbors with LSH",
- "This program will calculate the k distance-approximate-nearest-neighbors "
+PROGRAM_INFO("All K-Approximate-Nearest-Neighbor Search with LSH",
+ "This program will calculate the k approximate-nearest-neighbors "
"of a set of points. You may specify a separate set of reference "
"points and query points, or just a reference set which will be "
"used as both the reference and query set. "
@@ -33,8 +32,7 @@
"and store the distances in 'distances.csv' and the neighbors in the "
"file 'neighbors.csv':"
"\n\n"
- "$ allkdann --k=5 --reference_file=input.csv --distances_file=distances.csv\n"
- " --neighbors_file=neighbors.csv"
+ "$ ./lsh/lsh -k 5 -r input.csv -d distances.csv -n neighbors.csv "
"\n\n"
"The output files are organized such that row i and column j in the "
"neighbors output file corresponds to the index of the point in the "
@@ -46,25 +44,17 @@
PARAM_STRING_REQ("reference_file", "File containing the reference dataset.", "r");
PARAM_STRING("distances_file", "File to output distances into.", "d", "");
PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
-
PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
-
PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
-PARAM_INT("num_projections", "The number of hash functions for each table", "K", 10);
+PARAM_INT("num_projections", "The number of hash functions for each table",
+ "K", 10);
PARAM_INT("num_tables", "The number of hash tables to be used.", "L", 30);
-PARAM_INT("second_hash_size", "The size of the second level hash table.", "M", 8807);
-PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", "B", 500);
+PARAM_INT("second_hash_size", "The size of the second level hash table.",
+ "M", 99901);
+PARAM_INT("bucket_size", "The size of a bucket in the second level hash.",
+ "B", 500);
-PARAM_FLAG("try_diff_params", "The flag to trigger the search with "
- "different 'K', 'L' and 'r'.", "P");
-PARAM_FLAG("try_diff_widths", "The flag to trigger the search with "
- "different hash widths.", "W");
-
-PARAM_STRING("rank_file", "The file containing the true ranks.", "E", "");
-PARAM_STRING("de_file", "The file containing the true distance errors.", "D", "");
-PARAM_STRING("ann_error_file", "File to output the RANN errors to.", "F", "");
-
int main(int argc, char *argv[])
{
// Give CLI the command line parameters the user passed in.
@@ -80,9 +70,6 @@
size_t secondHashSize = CLI::GetParam<int>("second_hash_size");
size_t bucketSize = CLI::GetParam<int>("bucket_size");
- bool tryDiffParams = CLI::HasParam("try_diff_params");
- bool tryDiffWidths = CLI::HasParam("try_diff_widths");
-
arma::mat referenceData;
arma::mat queryData; // So it doesn't go out of scope.
data::Load(referenceFile.c_str(), referenceData, true);
@@ -101,58 +88,27 @@
// Pick up the 'K' and the 'L' parameter for LSH
- arma::Col<size_t> numProjs, numTables;
- if (tryDiffParams)
- {
- numProjs.set_size(3);
- numProjs << 10 << 25 << 50;
- numTables.set_size(5);
- numTables << 5 << 10 << 25 << 50 << 100;
- }
- else
- {
- numProjs.set_size(1);
- numProjs[0] = CLI::GetParam<int>("num_projections");
- numTables.set_size(1);
- numTables[0] = CLI::GetParam<int>("num_tables");
- }
-
+ size_t numProj = CLI::GetParam<int>("num_projections");
+ size_t numTables = CLI::GetParam<int>("num_tables");
+
// Compute the 'width' parameter from LSH
- // Find the average pairwise distance of 25 random pairs
- double avgDist = 0;
+ // Find the average pairwise distance of 25 random pairs and use that
+ // as the hash-width
+ double hashWidth = 0;
for (size_t i = 0; i < 25; i++)
{
- size_t p1 = (size_t) math::RandInt(referenceData.n_cols),
- p2 = (size_t) math::RandInt(referenceData.n_cols);
+ size_t p1 = (size_t) math::RandInt(referenceData.n_cols);
+ size_t p2 = (size_t) math::RandInt(referenceData.n_cols);
- avgDist += metric::EuclideanDistance::Evaluate(referenceData.unsafe_col(p1),
- referenceData.unsafe_col(p2));
+ hashWidth += metric::EuclideanDistance::Evaluate(referenceData.unsafe_col(p1),
+ referenceData.unsafe_col(p2));
}
- avgDist /= 25;
+ hashWidth /= 25;
- Log::Info << "Hash width chosen as: " << avgDist << endl;
+ Log::Info << "Hash width chosen as: " << hashWidth << endl;
- arma::vec hashWidths;
- if (tryDiffWidths)
- {
- arma::vec eps(5);
- eps << 0.001 << 0.01 << 0.1 << 1.0 << 10.0;
- hashWidths = avgDist * eps;
-
- }
- else
- {
- hashWidths.set_size(1);
- hashWidths[0] = avgDist;
- }
-
- arma::vec timesTaken(numProjs.n_elem * numTables.n_elem * hashWidths.n_elem);
- timesTaken.zeros();
-
- arma::Mat<size_t> allNeighbors;
-
arma::Mat<size_t> neighbors;
arma::mat distances;
@@ -163,73 +119,29 @@
data::Load(queryFile.c_str(), queryData, true);
Log::Info << "Loaded query data from '" << queryFile << "' ("
<< queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
-
- allNeighbors.set_size(k * timesTaken.n_elem, queryData.n_cols);
}
- else
- allNeighbors.set_size(k * timesTaken.n_elem, referenceData.n_cols);
- size_t exptInd = 0;
- arma::mat exptParams(timesTaken.n_elem, 3);
+ Log::Info << "LSH with " << numProj << " projections(K) and " << numTables <<
+ " tables(L) with hash width(r): " << hashWidth << endl;
- for (size_t widthInd = 0; widthInd < hashWidths.n_elem; widthInd++)
- {
- for (size_t projInd = 0; projInd < numProjs.n_elem; projInd++)
- {
- for (size_t tableInd = 0; tableInd < numTables.n_elem; tableInd++)
- {
- Log::Info << "LSH with K: " << numProjs[projInd] << ", L: "
- << numTables[tableInd] << ", r: " << hashWidths[widthInd] << endl;
+ Timer::Start("hash_building");
- Timer::Start("hash_building");
+ LSHSearch<>* allkann;
- LSHSearch<>* allkdann;
+ if (CLI::GetParam<string>("query_file") != "")
+ allkann = new LSHSearch<>(referenceData, queryData, numProj, numTables,
+ hashWidth, secondHashSize, bucketSize);
+ else
+ allkann = new LSHSearch<>(referenceData, numProj, numTables, hashWidth,
+ secondHashSize, bucketSize);
- if (CLI::GetParam<string>("query_file") != "")
- allkdann = new LSHSearch<>(referenceData, queryData, numProjs[projInd],
- numTables[tableInd], hashWidths[widthInd],
- secondHashSize, bucketSize);
- else
- allkdann = new LSHSearch<>(referenceData, numProjs[projInd],
- numTables[tableInd], hashWidths[widthInd],
- secondHashSize, bucketSize);
+ Timer::Stop("hash_building");
- Timer::Stop("hash_building");
+ Log::Info << "Computing " << k << " distance approx. nearest neighbors " << endl;
+ allkann->Search(k, neighbors, distances);
- timeval start_tv = Timer::Get("computing_neighbors");
- double startTime = (double) start_tv.tv_sec + (double) start_tv.tv_usec / 1.0e6;
+ Log::Info << "Neighbors computed." << endl;
-
- Log::Info << "Computing " << k << " distance approx. nearest neighbors " << endl;
- allkdann->Search(k, neighbors, distances);
-
- Log::Info << "Neighbors computed." << endl;
-
- timeval stop_tv = Timer::Get("computing_neighbors");
- double stopTime = (double) stop_tv.tv_sec + (double) stop_tv.tv_usec / 1.0e6;
-
-
- exptParams(exptInd, 0) = (double) numProjs[projInd];
- exptParams(exptInd, 1) = (double) numTables[tableInd];
- exptParams(exptInd, 2) = hashWidths[widthInd];
- timesTaken[exptInd] = stopTime - startTime;
-
- // add results to big matrix
- allNeighbors.rows(exptInd * k, (exptInd + 1) * k - 1) = neighbors;
-
- exptInd++;
-
- neighbors.reset();
- distances.reset();
-
- delete allkdann;
-
- } // diff. L
- } // diff. K
- } // diff. 'width'
-
-
- // TO FIX: Have to fix this since these matrices are getting reset.
// Save output.
if (distancesFile != "")
data::Save(distancesFile, distances);
@@ -237,138 +149,7 @@
if (neighborsFile != "")
data::Save(neighborsFile, neighbors);
+ delete allkann;
- // Compute the error if the error file is provided
- string rankFile = CLI::GetParam<string>("rank_file");
- if (rankFile != "")
- {
- Log::Warn << "Computing error..." << endl;
-
- contrib_utils::LineReader lr(rankFile);
-
- arma::mat allDANNErrors(timesTaken.n_elem, 8);
- // 0 - K
- // 1 - L
- // 2 - width
- // 3 - Time taken
- // 4 - Mean Rank/Recall
- // 5 - Median Rank/Recall
- // 6 - StdDev Rank/Recall
- // 7 - MaxRank / MinRecall
-
- if (k == 1)
- {
- string deFile = CLI::GetParam<string>("de_file");
-
- contrib_utils::LineReader *de_lr = NULL;
-
- if (deFile != "")
- {
- de_lr = new contrib_utils::LineReader(deFile);
- allDANNErrors.resize(timesTaken.n_elem, 12);
- // 8 - Mean DE
- // 9 - Median DE
- // 10 - StdDev DE
- // 11 - Max DE
- }
-
- arma::mat ranks(timesTaken.n_elem, allNeighbors.n_cols);
- arma::mat des(timesTaken.n_elem, allNeighbors.n_cols);
-
- for (size_t i = 0; i < allNeighbors.n_cols; i++)
- {
- arma::Col<size_t> true_ranks(referenceData.n_cols);
- lr.ReadLine(&true_ranks);
-
- arma::vec true_des(referenceData.n_cols);
- if (de_lr != NULL)
- de_lr->ReadLine(&true_des);
-
- for (size_t j = 0; j < timesTaken.n_elem; j++)
- {
- if (allNeighbors(j, i) < referenceData.n_cols)
- {
- ranks(j, i) = (double) true_ranks[allNeighbors(j, i)];
-
- if (de_lr != NULL)
- des(j, i) = true_des[allNeighbors(j, i)];
- }
- else
- {
- ranks(j, i) = (double) referenceData.n_cols;
-
- if (de_lr != NULL)
- des(j, i) = arma::max(true_des);
- }
- }
- }
-
- allDANNErrors.cols(0, 2) = exptParams;
- allDANNErrors.col(3) = timesTaken;
- allDANNErrors.col(4) = arma::mean(ranks, 1);
- allDANNErrors.col(5) = arma::median(ranks, 1);
- allDANNErrors.col(6) = arma::stddev(ranks, 1, 1);
- allDANNErrors.col(7) = arma::max(ranks, 1);
-
- if (de_lr != NULL)
- {
- allDANNErrors.col(8) = arma::mean(des, 1);
- allDANNErrors.col(9) = arma::median(des, 1);
- allDANNErrors.col(10) = arma::stddev(des, 1, 1);
- allDANNErrors.col(11) = arma::max(des, 1);
-
- delete de_lr;
- }
- } // if k == 1, compute rank error and distance error
- else
- {
- arma::mat recalls(timesTaken.n_elem, allNeighbors.n_cols);
- recalls.zeros();
-
- for (size_t i = 0; i < allNeighbors.n_cols; i++)
- {
- arma::Col<size_t> true_ranks(referenceData.n_cols);
- lr.ReadLine(&true_ranks);
-
- for (size_t j = 0; j < timesTaken.n_elem; j++)
- {
- for (size_t ind = 0; ind < k; ind++)
- if (allNeighbors(j * k + ind, i) < referenceData.n_cols)
- {
- if (true_ranks[allNeighbors(j * k + ind, i)] <= k)
- recalls(j, i)++;
- }
-
- }
- }
-
- recalls /= k;
-
- allDANNErrors.cols(0, 2) = exptParams;
- allDANNErrors.col(3) = timesTaken;
- allDANNErrors.col(4) = arma::mean(recalls, 1);
- allDANNErrors.col(5) = arma::median(recalls, 1);
- allDANNErrors.col(6) = arma::stddev(recalls, 1, 1);
- allDANNErrors.col(7) = arma::min(recalls, 1);
-
- } // if k > 1, compute recall of k-NN
-
- Log::Warn << allDANNErrors;
-
- string annErrorOutputFile = CLI::GetParam<string>("ann_error_file");
-
- if (annErrorOutputFile != "")
- {
- allDANNErrors = allDANNErrors.t();
- data::Save(annErrorOutputFile, allDANNErrors);
- }
- }
- else
- {
- Log::Warn << "Params: " << endl << exptParams.t()
- << "Times Taken: " << endl << timesTaken.t();
-
- }
-
return 0;
}
Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp 2012-12-20 23:29:46 UTC (rev 14031)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp 2012-12-20 23:45:12 UTC (rev 14032)
@@ -2,10 +2,24 @@
* @file lsh_search.hpp
* @author Parikshit Ram
*
- * Defines the LSHSearch class, which performs an abstract
- * distance-approximate nearest neighbor query on two datasets
- * using Locality-sensitive hashing with 2-stable distributions
+ * Defines the LSHSearch class, which performs an approximate
+ * nearest neighbor search for the queries in a query set
+ * over a given dataset using Locality-sensitive hashing
+ * with 2-stable distributions.
+ *
+ * The details of this method can be found in the following paper:
+ *
+ * @inproceedings{datar2004locality,
+ * title={Locality-sensitive hashing scheme based on p-stable distributions},
+ * author={Datar, M. and Immorlica, N. and Indyk, P. and Mirrokni, V.S.},
+ * booktitle={Proceedings of the 12th Annual Symposium on Computational Geometry},
+ * pages={253--262},
+ * year={2004},
+ * organization={ACM}
+ * }
+ *
*/
+
#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
#define __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
@@ -22,57 +36,39 @@
* searches. */ {
/**
- * The RASearch class is a template class for performing distance-based
- * neighbor searches. It takes a query dataset and a reference dataset (or just
- * a reference dataset) and, for each point in the query dataset, finds the k
- * neighbors in the reference dataset which have the 'best' distance according
- * to a given sorting policy. A constructor is given which takes only a
- * reference dataset, and if that constructor is used, the given reference
- * dataset is also used as the query dataset.
- *
- * The template parameters SortPolicy and Metric define the sort function used
- * and the metric (distance function) used. More information on those classes
- * can be found in the NearestNeighborSort class and the kernel::ExampleKernel
- * class.
- *
+ * The LSHSearch class -- TBD
+ *
* @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
* @tparam MetricType The metric to use for computation.
- * @tparam TreeType The tree type to use.
*/
template<typename SortPolicy = NearestNeighborSort,
- typename MetricType = mlpack::metric::SquaredEuclideanDistance,
- typename eT = double>
+ typename MetricType = mlpack::metric::SquaredEuclideanDistance>
class LSHSearch
{
public:
- typedef arma::Mat<eT> MatType;
- typedef arma::Col<eT> ColType;
- typedef arma::Row<eT> RowType;
-
/**
- * Initialize the RASearch object, passing both a query and reference
- * dataset. Optionally, perform the computation in naive mode or single-tree
- * mode, and set the leaf size used for tree-building. An initialized
- * distance metric can be given, for cases where the metric has internal data
- * (i.e. the distance::MahalanobisDistance class).
+ * Initialize the LSHSearch object, passing both a reference and a query set.
*
- * This method will copy the matrices to internal copies, which are rearranged
- * during tree-building. You can avoid this extra copy by pre-constructing
- * the trees and passing them using a diferent constructor.
- *
* @param referenceSet Set of reference points.
* @param querySet Set of query points.
- * @param naive If true, O(n^2) naive search will be used (as opposed to
- * dual-tree search). This overrides singleMode (if it is set to true).
- * @param singleMode If true, single-tree search will be used (as opposed to
- * dual-tree search).
- * @param leafSize Leaf size for tree construction (ignored if tree is given).
+ * @param numProj Number of projections in each hash table (anything between
+ * 10-50 might be a decent choice).
+ * @param numTables Total number of hash tables (anything between 10-20
+ * should suffice).
+ * @param hashWidth The width of hash for every table (currently automatically
+ * chosen from the main function). This should be a reasonable upper bound
+ * on the nearest-neighbor distance in general.
+ * @param secondHashSize The size of the second hash table. This should be a
+ * large prime number.
+ * @param bucketSize The size of the bucket in the second hash table. This is
+ * the maximum number of points that can be hashed into single bucket.
+ * Default values are already provided here.
* @param metric An optional instance of the MetricType class.
*/
- LSHSearch(const MatType& referenceSet,
- const MatType& querySet,
+ LSHSearch(const arma::mat& referenceSet,
+ const arma::mat& querySet,
const size_t numProj,
const size_t numTables,
const double hashWidth,
@@ -81,27 +77,24 @@
const MetricType metric = MetricType());
/**
- * Initialize the RASearch object, passing only one dataset, which is
- * used as both the query and the reference dataset. Optionally, perform the
- * computation in naive mode or single-tree mode, and set the leaf size used
- * for tree-building. An initialized distance metric can be given, for cases
- * where the metric has internal data (i.e. the distance::MahalanobisDistance
- * class).
+ * Initialize the LSHSearch object with one set used as references and queries.
*
- * If naive mode is being used and a pre-built tree is given, it may not work:
- * naive mode operates by building a one-node tree (the root node holds all
- * the points). If that condition is not satisfied with the pre-built tree,
- * then naive mode will not work.
- *
- * @param referenceSet Set of reference points.
- * @param naive If true, O(n^2) naive search will be used (as opposed to
- * dual-tree search). This overrides singleMode (if it is set to true).
- * @param singleMode If true, single-tree search will be used (as opposed to
- * dual-tree search).
- * @param leafSize Leaf size for tree construction (ignored if tree is given).
+ * @param referenceSet Set of reference points and the set of queries.
+ * @param numProj Number of projections in each hash table (anything between
+ * 10-50 might be a decent choice).
+ * @param numTables Total number of hash tables (anything between 10-20
+ * should suffice).
+ * @param hashWidth The width of hash for every table (currently automatically
+ * chosen from the main function). This should be a reasonable upper bound
+ * on the nearest-neighbor distance in general.
+ * @param secondHashSize The size of the second hash table. This should be a
+ * large prime number.
+ * @param bucketSize The size of the bucket in the second hash table. This is
+ * the maximum number of points that can be hashed into single bucket.
+ * Default values are already provided here.
* @param metric An optional instance of the MetricType class.
*/
- LSHSearch(const MatType& referenceSet,
+ LSHSearch(const arma::mat& referenceSet,
const size_t numProj,
const size_t numTables,
const double hashWidth,
@@ -109,7 +102,7 @@
const size_t bucketSize = 500,
const MetricType metric = MetricType());
/**
- * Delete the RASearch object. The tree is the only member we are
+ * Delete the LSHSearch object. The tree is the only member we are
* responsible for deleting. The others will take care of themselves.
*/
~LSHSearch();
@@ -127,47 +120,86 @@
* point.
*/
void Search(const size_t k,
- arma::Mat<size_t>& resultingNeighbors,
+ arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances);
private:
- void BuildFirstLevelHash(MatType* allKeyPointMat);
+ /**
+ * This function builds a hash table with two levels of hashing
+ * as presented in the paper. This function first hashes the points
+ * with 'numProj' random projections to a single hash table creating
+ * (key, point ID) pairs where the key is a 'numProj'-dimensional
+ * integer vector.
+ *
+ * Then each key in this hash table is hashed into a second hash table
+ * using a standard hash.
+ *
+ * This function does not have any parameters and relies on parameters
+ * which are private members of this class, initialized during the
+ * class initialization.
+ */
+ void BuildHash();
- void BuildSecondLevelHash(MatType& allKeyPointMat);
- inline void BaseCase(const size_t queryIndex,
- const size_t referenceIndex);
-
- void InsertNeighbor(const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance);
-
+ /**
+ * This function takes a query and hashes it into each of the hash tables
+ * to get keys for the query and then the key is hashed to a bucket of the
+ * second hash table and all the points (if any) in those buckets
+ * are collected as the potential neighbor candidates.
+ *
+ * @param queryIndex The index of the query currently being processed.
+ * @param referenceIndices The list of neighbor candidates obtained from
+ * hashing the query into all the hash tables and eventually into
+ * multiple buckets of the second hash table.
+ */
void ReturnIndicesFromTable(const size_t queryIndex,
arma::uvec& referenceIndices);
+ /**
+ * This is a helper function that computes the distance of the query to the
+ * neighbor candidates and appropriately stores the best 'k' candidates
+ *
+ * @param queryIndex The index of the query in question
+ * @param referenceIndex The index of the neighbor candidate in question
+ */
+ double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+ /**
+ * This is a helper function that efficiently inserts better neighbor
+ * candidates into an existing set of neighbor candidates. This function
+ * is only called by the 'BaseCase' function.
+ *
+ * @param queryIndex This is the index of the query being processed currently
+ * @param pos The position of the neighbor candidate in the current list of
+ * neighbor candidates.
+ * @param neighbor The neighbor candidate that is being inserted into the list
+ * of the best 'k' candidates for the query in question.
+ * @param distance The distance of the query to the neighbor candidate.
+ */
+ void InsertNeighbor(const size_t queryIndex, const size_t pos,
+ const size_t neighbor, const double distance);
private:
//! Reference dataset.
const arma::mat& referenceSet;
+
//! Query dataset (may not be given).
const arma::mat& querySet;
- //! Instantiation of kernel.
+ //! Instantiation of the metric.
MetricType metric;
//! The number of projections
const size_t numProj;
- //! The number of tables
+ //! The number of hash tables
const size_t numTables;
//! The std::vector containing the projection matrix of each table
- std::vector<MatType> projections; // should be [numProj x dims] x numTables
+ std::vector<arma::mat> projections; // should be [numProj x dims] x numTables
//! The list of the offset 'b' for each of the projection for each table
- MatType offsets; // should be numProj x numTables
+ arma::mat offsets; // should be numProj x numTables
//! The hash width
const double hashWidth;
@@ -176,7 +208,7 @@
const size_t secondHashSize;
//! The weights of the second hash
- ColType secondHashWeights;
+ arma::vec secondHashWeights;
//! The bucket size of the second hash
const size_t bucketSize;
@@ -206,7 +238,4 @@
// Include implementation.
#include "lsh_search_impl.hpp"
-// Include convenience typedefs.
-//#include "lsh_typedef.hpp"
-
#endif
Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp 2012-12-20 23:29:46 UTC (rev 14031)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp 2012-12-20 23:45:12 UTC (rev 14032)
@@ -2,7 +2,7 @@
* @file lsh_search_impl.hpp
* @author Parikshit Ram
*
- * Implementation of LSHSearch class.
+ * Implementation of the LSHSearch class.
*/
#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_IMPL_HPP
#define __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_IMPL_HPP
@@ -14,10 +14,10 @@
using namespace mlpack::neighbor;
// Construct the object.
-template<typename SortPolicy, typename MetricType, typename eT>
-LSHSearch<SortPolicy, MetricType, eT>::
-LSHSearch(const MatType& referenceSet,
- const MatType& querySet,
+template<typename SortPolicy, typename MetricType>
+LSHSearch<SortPolicy, MetricType>::
+LSHSearch(const arma::mat& referenceSet,
+ const arma::mat& querySet,
const size_t numProj,
const size_t numTables,
const double hashWidth,
@@ -33,17 +33,12 @@
bucketSize(bucketSize),
metric(metric)
{
- // Get a (N^K key, point index) pair for all tables and all points
- MatType allKeysPointsMat(numProj + 1, referenceSet.n_cols * numTables);
- BuildFirstLevelHash(&allKeysPointsMat);
-
- // Condense the (N^K key, point index) pairs into a single table
- BuildSecondLevelHash(allKeysPointsMat);
+ BuildHash();
}
-template<typename SortPolicy, typename MetricType, typename eT>
-LSHSearch<SortPolicy, MetricType, eT>::
-LSHSearch(const MatType& referenceSet,
+template<typename SortPolicy, typename MetricType>
+LSHSearch<SortPolicy, MetricType>::
+LSHSearch(const arma::mat& referenceSet,
const size_t numProj,
const size_t numTables,
const double hashWidth,
@@ -59,23 +54,18 @@
bucketSize(bucketSize),
metric(metric)
{
- // Get a (N^K key, point index) pair for all tables and all points
- MatType allKeysPointsMat(numProj + 1, referenceSet.n_cols * numTables);
- BuildFirstLevelHash(&allKeysPointsMat);
-
- // Condense the (N^K key, point index) pairs into a single table
- BuildSecondLevelHash(allKeysPointsMat);
+ BuildHash();
}
-template<typename SortPolicy, typename MetricType, typename eT>
-LSHSearch<SortPolicy, MetricType, eT>::
+template<typename SortPolicy, typename MetricType>
+LSHSearch<SortPolicy, MetricType>::
~LSHSearch()
{ }
-template<typename SortPolicy, typename MetricType, typename eT>
-void LSHSearch<SortPolicy, MetricType, eT>::
+template<typename SortPolicy, typename MetricType>
+void LSHSearch<SortPolicy, MetricType>::
InsertNeighbor(const size_t queryIndex,
const size_t pos,
const size_t neighbor,
@@ -100,18 +90,18 @@
-template<typename SortPolicy, typename MetricType, typename eT>
-inline void LSHSearch<SortPolicy, MetricType, eT>::
-BaseCase(const size_t queryIndex,
- const size_t referenceIndex)
+template<typename SortPolicy, typename MetricType>
+inline //force_inline
+double LSHSearch<SortPolicy, MetricType>::
+BaseCase(const size_t queryIndex, const size_t referenceIndex)
{
// If the datasets are the same, then this search is only using one dataset
// and we should not return identical points.
if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
- return;
+ return 0.0;
- double distance = metric.Evaluate(querySet.col(queryIndex),
- referenceSet.col(referenceIndex));
+ double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
+ referenceSet.unsafe_col(referenceIndex));
// If this distance is better than any of the current candidates, the
// SortDistance() function will give us the position to insert it into.
@@ -122,33 +112,39 @@
if (insertPosition != (size_t() - 1))
InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
- return;
+ return distance;
}
-template<typename SortPolicy, typename MetricType, typename eT>
-void LSHSearch<SortPolicy, MetricType, eT>::
+template<typename SortPolicy, typename MetricType>
+void LSHSearch<SortPolicy, MetricType>::
ReturnIndicesFromTable(const size_t queryIndex,
arma::uvec& referenceIndices)
{
+ // Hash the query in each of the 'numTables' hash tables using the
+ // 'numProj' projections for each table.
+ // This gives us 'numTables' keys for the query where each key
+ // is a 'numProj' dimensional integer vector
+ //
// compute the projection of the query in each table
- MatType allProjInTables(numProj, numTables);
-
+ arma::mat allProjInTables(numProj, numTables);
for (size_t i = 0; i < numTables; i++)
- allProjInTables.col(i) = projections[i].t() * querySet.col(queryIndex);
-
+ allProjInTables.unsafe_col(i)
+ = projections[i].t() * querySet.unsafe_col(queryIndex);
allProjInTables += offsets;
allProjInTables /= hashWidth;
- // compute the hash value of each projection of the query
- // in the second hash table
- RowType hashVec = secondHashWeights.t() * arma::floor(allProjInTables);
+ // compute the hash value of each key of the query into a bucket of the
+ // 'secondHashTable' using the 'secondHashWeights'.
+ arma::rowvec hashVec = secondHashWeights.t() * arma::floor(allProjInTables);
- assert(hashVec.n_elem == numTables);
-
for (size_t i = 0; i < hashVec.n_elem; i++)
hashVec[i] = (double)((size_t) hashVec[i] % secondHashSize);
+ assert(hashVec.n_elem == numTables);
+
+ // For all the buckets that the query is hashed into, sequentially
+ // collect the indices in those buckets.
arma::Col<size_t> refPointsConsidered;
refPointsConsidered.zeros(referenceSet.n_cols);
@@ -158,7 +154,7 @@
if (bucketContentSize[hashInd] > 0)
{
- // Pick the indices in that 'hashInd'
+ // Pick the indices in the bucket corresponding to 'hashInd'
size_t tableRow = bucketRowInHashTable[hashInd];
assert(tableRow < secondHashSize);
assert(tableRow < secondHashTable.n_rows);
@@ -169,14 +165,12 @@
} // for all tables
referenceIndices = arma::find(refPointsConsidered > 0);
-
return;
}
-
-template<typename SortPolicy, typename MetricType, typename eT>
-void LSHSearch<SortPolicy, MetricType, eT>::
+template<typename SortPolicy, typename MetricType>
+void LSHSearch<SortPolicy, MetricType>::
Search(const size_t k,
arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances)
@@ -188,22 +182,26 @@
neighborPtr->set_size(k, querySet.n_cols);
distancePtr->set_size(k, querySet.n_cols);
distancePtr->fill(SortPolicy::WorstDistance());
- neighborPtr->fill((size_t) -1);
+ neighborPtr->fill(referenceSet.n_cols);
-
size_t avgIndicesReturned = 0;
-
Timer::Start("computing_neighbors");
- // go through every query point
+ // go through every query point sequentially
for (size_t i = 0; i < querySet.n_cols; i++)
{
+ // Hash every query into every hash table and eventually
+ // into the 'secondHashTable' to obtain the neighbor candidates
arma::uvec refIndices;
ReturnIndicesFromTable(i, refIndices);
+ // Just an informative book-keeping for the number of neighbor candidates
+ // returned on average
avgIndicesReturned += refIndices.n_elem;
+ // Sequentially go through all the candidates and save the best 'k'
+ // candidates
for (size_t j = 0; j < refIndices.n_elem; j++)
BaseCase(i, (size_t) refIndices[j]);
}
@@ -211,106 +209,173 @@
Timer::Stop("computing_neighbors");
avgIndicesReturned /= querySet.n_cols;
- Log::Info << avgIndicesReturned << " distinct indices returned on average."
- << std::endl;
+ Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
+ std::endl;
return;
}
-
-template<typename SortPolicy, typename MetricType, typename eT>
-void LSHSearch<SortPolicy, MetricType, eT>::
-BuildFirstLevelHash(MatType* allKeysPointsMat)
+template<typename SortPolicy, typename MetricType>
+void LSHSearch<SortPolicy, MetricType>::
+BuildHash()
{
- // A row with all the indices
- RowType allIndRow(referenceSet.n_cols);
- for (size_t i = 0; i < allIndRow.n_elem; i++)
- allIndRow[i] = i;
+ // The first level hash for a single table outputs a 'numProj'-dimensional
+ // integer key for each point in the set -- (key, pointID)
+ // The key creation details are presented below
+ //
+ // The second level hash is performed by hashing the key to
+ // an integer in the range [0, 'secondHashSize').
+ //
+ // This is done by creating a weight vector 'secondHashWeights' of
+ // length 'numProj' with each entry an integer randomly chosen
+ // between [0, 'secondHashSize').
+ //
+ // Then the bucket for any key and its corresponding point is
+ // given by <key, 'secondHashWeights'> % 'secondHashSize'
+ // and the corresponding point ID is put into that bucket.
- // Obtain all the projection matrices and the offset matrix
- offsets.randu(numProj, numTables);
- offsets *= hashWidth;
+ //////////////////////////////////////////
+ // Step I: Preparing the second level hash
+ ///////////////////////////////////////////
- for(size_t i = 0; i < numTables; i++)
- {
- MatType projMat;
- projMat.randn(referenceSet.n_rows, numProj);
+ // obtain the weights for the second hash
+ secondHashWeights = arma::floor(arma::randu(numProj)
+ * (double) secondHashSize);
- MatType offsetMat = arma::repmat(offsets.col(i), 1, referenceSet.n_cols);
- MatType hashMat = projMat.t() * referenceSet;
+ // The 'secondHashTable' is initially an empty matrix of size
+ // ('secondHashSize' x 'bucketSize'). However, only filling the buckets
+ // as points land in them allows us to shrink the size of the
+ // 'secondHashTable' at the end of the hashing.
- hashMat += offsetMat;
- hashMat /= hashWidth;
+ // Start filling up the second hash table
+ secondHashTable.set_size(secondHashSize, bucketSize);
- hashMat.resize(hashMat.n_rows + 1, hashMat.n_cols);
- hashMat.row(hashMat.n_rows - 1) = allIndRow;
+ // Fill the second hash table with n = referenceSet.n_cols.
+ // This is because no point has index 'n' so the presence of
+ // this in the bucket denotes that there are no more points
+ // in this bucket.
+ secondHashTable.fill(referenceSet.n_cols);
- allKeysPointsMat->cols(i * referenceSet.n_cols, (i + 1) * referenceSet.n_cols - 1)
- = arma::floor(hashMat);
+ // Keeping track of the size of each bucket in the hash.
+ // At the end of hashing most buckets will be empty.
+ bucketContentSize.zeros(secondHashSize);
- projections.push_back(projMat);
- } // loop over tables
+ // Instead of putting the points in the row corresponding to
+ // the bucket, we choose the next empty row and keep track of
+ // the row in which the bucket lies. This allows us to
+ // stack together and slice out the empty buckets at the
+ // end of the hashing.
+ bucketRowInHashTable.set_size(secondHashSize);
+ bucketRowInHashTable.fill(secondHashSize);
- return;
-}
+ // keeping track of number of non-empty rows in the 'secondHashTable'
+ size_t numRowsInTable = 0;
-template<typename SortPolicy, typename MetricType, typename eT>
-void LSHSearch<SortPolicy, MetricType, eT>::
-BuildSecondLevelHash(MatType& allKeysPointsMat)
-{
- // obtain the hash weights for the second hash
- secondHashWeights = arma::floor(arma::randu(numProj)
- * (double) secondHashSize);
+ /////////////////////////////////////////////////////////
+ // Step II: The offsets for all projections in all tables
+ /////////////////////////////////////////////////////////
- RowType hashVec = secondHashWeights.t()
- * allKeysPointsMat.rows(0, numProj - 1);
+ // Since the 'offsets' are in [0, hashWidth], we obtain the 'offsets'
+ // as randu(numProj, numTables) * hashWidth
+ offsets.randu(numProj, numTables);
+ offsets *= hashWidth;
- for (size_t i = 0; i < hashVec.n_elem; i++)
- hashVec[i] = (double)((size_t) hashVec[i] % secondHashSize);
+ /////////////////////////////////////////////////////////////////
+ // Step III: Creating each hash table in the first level hash
+ // one by one and putting them directly into the 'secondHashTable'
+ // for memory efficiency.
+ /////////////////////////////////////////////////////////////////
- assert(hashVec.n_elem == referenceSet.n_cols * numTables);
+ for(size_t i = 0; i < numTables; i++)
+ {
+ //////////////////////////////////////////////////////////////
+ // Step IV: Obtaining the 'numProj' projections for each table
+ //////////////////////////////////////////////////////////////
+ //
+ // For L2 metric, 2-stable distributions are used, and
+ // the normal Z ~ N(0, 1) is a 2-stable distribution.
+ arma::mat projMat;
+ projMat.randn(referenceSet.n_rows, numProj);
- // start filling up the second hash table;
- secondHashTable.set_size(secondHashSize, bucketSize);
- secondHashTable.fill(referenceSet.n_cols);
- bucketContentSize.zeros(secondHashSize);
+ // save the projection matrix for querying
+ projections.push_back(projMat);
- // Initializing to nothing
- bucketRowInHashTable.set_size(secondHashSize);
- bucketRowInHashTable.fill(secondHashSize);
+ ///////////////////////////////////////////////////////////////
+ // Step V: create the 'numProj'-dimensional key for each point
+ // in each table.
+ //////////////////////////////////////////////////////////////
- size_t numRowsInTable = 0;
+ // The following set of lines performs the task of
+ // hashing each point to a 'numProj'-dimensional integer key.
+ // Hence you get a ('numProj' x 'referenceSet.n_cols') key matrix
+ //
+ // For a single table, let the 'numProj' projections be denoted
+ // by 'proj_i' and the corresponding offset be 'offset_i'.
+ // Then the key of a single point is obtained as:
+ // key = { floor( (<proj_i, point> + offset_i) / 'hashWidth' ) forall i }
+ arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i),
+ 1, referenceSet.n_cols);
+ arma::mat hashMat = projMat.t() * referenceSet;
+ hashMat += offsetMat;
+ hashMat /= hashWidth;
- for (size_t i = 0; i < hashVec.n_elem; i++)
- {
- size_t hashInd = (size_t) hashVec[i];
- size_t pointInd = (size_t) allKeysPointsMat(numProj, i);
+ ////////////////////////////////////////////////////////////
+ // Step VI: Putting the points in the 'secondHashTable' by
+ // hashing the key.
+ ///////////////////////////////////////////////////////////
- if (bucketContentSize[hashInd] == 0)
+ // Now we hash every key, point ID to its corresponding bucket
+ arma::rowvec secondHashVec = secondHashWeights.t()
+ * arma::floor(hashMat);
+
+ // This gives us the bucket for the corresponding point ID
+ for (size_t j = 0; j < secondHashVec.n_elem; j++)
+ secondHashVec[j] = (double)((size_t) secondHashVec[j] % secondHashSize);
+
+ assert(secondHashVec.n_elem == referenceSet.n_cols);
+
+ // Inserting the point in the corresponding row to its bucket
+ // in the 'secondHashTable'.
+ for (size_t j = 0; j < secondHashVec.n_elem; j++)
{
- // start a new row for hash
- bucketRowInHashTable[hashInd] = numRowsInTable;
- secondHashTable(numRowsInTable, 0) = pointInd;
+ // This is the bucket number
+ size_t hashInd = (size_t) secondHashVec[j];
+ // The point ID is 'j'
- numRowsInTable++;
- }
- else
- {
- if (bucketContentSize[hashInd] < bucketSize)
+ // If this is currently an empty bucket, start a new row
+ // and keep track of which row corresponds to the bucket.
+ if (bucketContentSize[hashInd] == 0)
{
- // continue with an existing row
- size_t tableRow = bucketRowInHashTable[hashInd];
- secondHashTable(tableRow, bucketContentSize[hashInd]) = pointInd;
+ // start a new row for hash
+ bucketRowInHashTable[hashInd] = numRowsInTable;
+ secondHashTable(numRowsInTable, 0) = j;
+
+ numRowsInTable++;
}
- // else just ignore as suggested
- }
+ // If bucket already present in the 'secondHashTable', find
+ // the corresponding row and insert the point ID in this row
+ // unless the bucket is full, in which case, do nothing.
+ else
+ {
+ // if bucket not full, insert point here
+ if (bucketContentSize[hashInd] < bucketSize)
+ secondHashTable(bucketRowInHashTable[hashInd],
+ bucketContentSize[hashInd]) = j;
+ // else just ignore as suggested
+ }
- if (bucketContentSize[hashInd] < bucketSize)
- bucketContentSize[hashInd]++;
- }
+ // increment the count of the points in this bucket
+ if (bucketContentSize[hashInd] < bucketSize)
+ bucketContentSize[hashInd]++;
+ } // loop over all points in the reference set
+ } // loop over tables
- // condense the second hash table
+
+ /////////////////////////////////////////////////
+ // Step VII: Condensing the 'secondHashTable'
+ /////////////////////////////////////////////////
+
size_t maxBucketSize = 0;
for (size_t i = 0; i < bucketContentSize.n_elem; i++)
if (bucketContentSize[i] > maxBucketSize)
@@ -323,4 +388,5 @@
return;
}
+
#endif
Added: mlpack/trunk/src/mlpack/methods/lsh/lsh_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_test.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_test.cpp 2012-12-20 23:45:12 UTC (rev 14032)
@@ -0,0 +1,131 @@
+/**
+ * @file lsh_test.cpp
+ *
+ * Unit tests for the 'LSHSearch' class.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include "lsh_search.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::neighbor;
+
+PROGRAM_INFO("LSH test", " ");
+
+
+int main (int argc, char *argv[])
+{
+ CLI::ParseCommandLine(argc, argv);
+ math::RandomSeed(0);
+
+ arma::mat rdata(2, 10);
+ rdata << 3 << 2 << 4 << 3 << 5 << 6 << 0 << 8 << 3 << 1 << arma::endr <<
+ 0 << 3 << 4 << 7 << 8 << 4 << 1 << 0 << 4 << 3 << arma::endr;
+
+
+ // Randomness present here -- seed = 0
+ // Computing the hashwidth here.
+ // CORRECT ANSWER: 'hashWidth' = 4.24777
+ double hashWidth = 0;
+ for (size_t i = 0; i < 10; i++)
+ {
+ size_t p1 = (size_t) math::RandInt(rdata.n_cols);
+ size_t p2 = (size_t) math::RandInt(rdata.n_cols);
+
+ hashWidth += metric::EuclideanDistance::Evaluate(rdata.unsafe_col(p1),
+ rdata.unsafe_col(p2));
+ }
+ hashWidth /= 10.0;
+
+ Log::Info << "Hash width: " << hashWidth << endl;
+
+ arma::mat qdata(2, 3);
+ qdata << 3 << 2 << 0 << arma::endr << 5 << 3 << 4 << arma::endr;
+
+ // INPUT TO LSH:
+ // Number of points: 10
+ // Number of dimensions: 2
+ // Number of projections per table: 'numProj' = 3
+ // Number of hash tables: 'numTables' = 2
+ // hashWidth (computed): 'hashWidth' = 4.24777
+ // Second hash size: 'secondHashSize' = 11
+ // Size of the bucket: 'bucketSize' = 3
+
+ // Randomness present in LSH -- seed = 0
+ // Things obtained by random sampling, listed in the sequence in which
+ // they will be obtained in the 'LSHSearch::BuildHash()' private function
+ // in 'LSHSearch' class.
+ //
+ // 1. The weights of the second hash obtained as:
+ // secondHashWeights = arma::floor(arma::randu(3) * 11.0);
+ // COR.SOL.: secondHashWeights = [9, 4, 8];
+ //
+ // 2. The offsets for all the 3 projections in each of the 2 tables:
+ // offsets.randu(3, 2)
+ // COR.SOL.: [0.7984 0.3352; 0.9116 0.7682; 0.1976 0.2778]
+ // offsets *= hashWidth
+ // COR.SOL.: [3.3916 1.4240; 3.8725 3.2633; 0.8392 1.1799]
+ //
+ // 3. The (2 x 3) projection matrices for the 2 tables:
+ // projMat.randn(2, 3)
+ // COR.SOL.: Proj. Mat 1: [2.7020 0.0187 0.4355; 1.3692 0.6933 0.0416]
+ // COR.SOL.: Proj. Mat 2: [-0.3961 -0.2666 1.1001; 0.3895 -1.5118 -1.3964]
+ LSHSearch<> *lsh_test = new LSHSearch<>(rdata, qdata, 3,2, hashWidth, 11,3);
+
+ // Given this, the 'LSHSearch::bucketRowInHashTable' should be:
+ // COR.SOL.: [2 11 4 7 6 3 11 0 5 1 8]
+ //
+ // The 'LSHSearch::bucketContentSize' should be:
+ // COR.SOL.: [2 0 1 1 3 1 0 3 3 3 1]
+ //
+ // The final hash table 'LSHSearch::secondHashTable' should be
+ // of size (9 x 3) with the following content:
+ // COR.SOL.:
+ // [0 2 4; 1 7 8; 3 9 10; 5 10 10; 6 10 10; 0 5 6; 1 2 8; 3 10 10; 4 10 10]
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ lsh_test->Search(2, neighbors, distances);
+
+ // The private function 'LSHSearch::ReturnIndicesFromTable(0, refInds)'
+ // should hash the query 0 into the following buckets:
+ // COR.SOL.: Table 1 Bucket 7, Table 2 Bucket 0, refInds = [0 2 3 4 9]
+ //
+ // The private function 'LSHSearch::ReturnIndicesFromTable(1, refInds)'
+ // should hash the query 1 into the following buckets:
+ // COR.SOL.: Table 1 Bucket 9, Table 2 Bucket 4, refInds = [1 2 7 8]
+ //
+ // The private function 'LSHSearch::ReturnIndicesFromTable(2, refInds)'
+ // should hash the query 2 into the following buckets:
+ // COR.SOL.: Table 1 Bucket 0, Table 2 Bucket 7, refInds = [0 2 3 4 9]
+
+ // After search
+ // COR.SOL.: 'neighbors' = [2 1 9; 3 8 2]
+ // COR.SOL.: 'distances' = [2 0 2; 4 2 16]
+
+// Log::Info << "Neighbors: " << std::endl << neighbors << std::endl <<
+// "Distances: " << std::endl << distances << std::endl;
+
+ arma::Mat<size_t> true_neighbors(2, 3);
+ true_neighbors << 2 << 1 << 9 << arma::endr << 3 << 8 << 2 << arma::endr;
+ arma::mat true_distances(2, 3);
+ true_distances << 2 << 0 << 2 << arma::endr << 4 << 2 << 16 << arma::endr;
+
+ for (size_t i = 0; i < 3; i++)
+ {
+ for (size_t j = 0; j < 2; j++)
+ {
+ assert(neighbors(j, i) == true_neighbors(j, i));
+ assert(distances(j, i) == true_distances(j, i));
+ }
+ }
+
+ Log::Warn << "Expected neighbor results obtained!!" << std::endl;
+
+ delete lsh_test;
+
+ return 0;
+}
More information about the mlpack-svn
mailing list