[mlpack-svn] r14032 - mlpack/trunk/src/mlpack/methods/lsh

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Dec 20 18:45:12 EST 2012


Author: rcurtin
Date: 2012-12-20 18:45:12 -0500 (Thu, 20 Dec 2012)
New Revision: 14032

Added:
   mlpack/trunk/src/mlpack/methods/lsh/lsh_analysis_main.cpp
   mlpack/trunk/src/mlpack/methods/lsh/lsh_test.cpp
Modified:
   mlpack/trunk/src/mlpack/methods/lsh/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp
   mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
   mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
Log:
Update to newest version (I mistakenly branched an old version).


Modified: mlpack/trunk/src/mlpack/methods/lsh/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/CMakeLists.txt	2012-12-20 23:29:46 UTC (rev 14031)
+++ mlpack/trunk/src/mlpack/methods/lsh/CMakeLists.txt	2012-12-20 23:45:12 UTC (rev 14032)
@@ -19,12 +19,34 @@
 set(MLPACK_CONTRIB_SRCS ${MLPACK_CONTRIB_SRCS} ${DIR_SRCS} PARENT_SCOPE)
 
 
-# The code to compute the rank-approximate neighbor
+# The code to compute the approximate nearest neighbors
 # for the given query and reference sets
-add_executable(allkdann
+# with p-stable LSH
+add_executable(lsh
   lsh_main.cpp
 )
-target_link_libraries(allkdann
+target_link_libraries(lsh
   ${MLPACK_LIBRARY}
+)
+
+# The code to compute the approximate nearest neighbors
+# for the given query and reference sets 
+# with p-stable LSH and analyze the time-accuracy
+# tradeoff
+add_executable(lsh_analysis
+  lsh_analysis_main.cpp
+)
+target_link_libraries(lsh_analysis
+  ${MLPACK_LIBRARY}
   contrib_pram
 )
+
+
+# Testing the LSHSearch class 
+add_executable(lsh_test
+  lsh_test.cpp
+)
+target_link_libraries(lsh_test
+  ${MLPACK_LIBRARY}
+)
+

Added: mlpack/trunk/src/mlpack/methods/lsh/lsh_analysis_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_analysis_main.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_analysis_main.cpp	2012-12-20 23:45:12 UTC (rev 14032)
@@ -0,0 +1,367 @@
+/**
+ * @file lsh_analysis_main.cpp
+ * @author Parikshit Ram
+ *
+ * This main file computes the accuracy-time tradeoff of 
+ * the 2-stable LSH class 'LSHSearch'
+ */
+#include <time.h>
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include <string>
+#include <fstream>
+#include <iostream>
+
+#include "lsh_search.hpp"
+#include "../utils/utils.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::neighbor;
+
+// Information about the program itself.
+PROGRAM_INFO("LSH accuracy-time tradeoff analysis.",
+    "This program will calculate the k approximate-nearest-neighbors "
+    "of a set of queries under different parameter settings."
+    " You may specify a separate set of reference "
+    "points and query points, or just a reference set which will be "
+    "used as both the reference and query set. "
+    "\n\n"
+    "For a given set of queries and references and 'k', this program "
+    "uses different parameters for LSH and computes the error and "
+    "the query time. For the computation of error, it requires the "
+    "files containing the ranks and the NN distances for each query. "
+    "\n\n"
+    "Sample usage: \n"
+    "./lsh/lsh_analysis -r reference.csv -q queries.csv -k 5 "
+    " -P -W -E query_ranks_file.csv -D query_nn_dist_file.csv "
+    " -F error_v_time_output.csv ");
+
+// Define our input parameters that this program will take.
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
+                 "r");
+PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
+PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
+
+PARAM_STRING_REQ("rank_file", "The file containing the true ranks.", "E");
+PARAM_STRING_REQ("nn_dist_file", "The file containing the true distance"
+                 " errors.", "D");
+PARAM_INT("num_projections", "The number of hash functions for each table",
+          "K", 10);
+PARAM_INT("num_tables", "The number of hash tables to be used.", "L", 30);
+PARAM_INT("second_hash_size", "The size of the second level hash table.", 
+          "M", 8807);
+PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", 
+          "B", 500);
+
+PARAM_FLAG("try_diff_params", "The flag to trigger the search with "
+           "different 'K', 'L' and 'r'.", "P");
+PARAM_FLAG("try_diff_widths", "The flag to trigger the search with "
+           "different hash widths.", "W");
+
+PARAM_STRING("error_time_file", "File to output the error v. time report to.",
+             "F", "");
+
+int main(int argc, char *argv[])
+{
+  // Give CLI the command line parameters the user passed in.
+  CLI::ParseCommandLine(argc, argv);
+  math::RandomSeed(time(NULL)); 
+
+  // Get all the parameters.
+  string referenceFile = CLI::GetParam<string>("reference_file");
+  string distancesFile = CLI::GetParam<string>("distances_file");
+  string neighborsFile = CLI::GetParam<string>("neighbors_file");
+
+  size_t k = CLI::GetParam<int>("k");
+  size_t secondHashSize = CLI::GetParam<int>("second_hash_size");
+  size_t bucketSize = CLI::GetParam<int>("bucket_size");
+
+  bool tryDiffParams = CLI::HasParam("try_diff_params");
+  bool tryDiffWidths = CLI::HasParam("try_diff_widths");
+
+  arma::mat referenceData;
+  arma::mat queryData; // So it doesn't go out of scope.
+  data::Load(referenceFile.c_str(), referenceData, true);
+
+  Log::Info << "Loaded reference data from '" << referenceFile << "' (" << 
+    referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
+
+  // Sanity check on k value: must be greater than 0, must be less than the
+  // number of reference points.
+  if (k > referenceData.n_cols)
+  {
+    Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
+    Log::Fatal << "than or equal to the number of reference points (";
+    Log::Fatal << referenceData.n_cols << ")." << endl;
+  }
+
+
+  // Pick up the 'K' and the 'L' parameter for LSH
+  arma::Col<size_t> numProjs, numTables;
+  if (tryDiffParams) 
+  {
+    numProjs.set_size(4);
+    numProjs << 10 << 25 << 40 << 55;
+    numTables.set_size(4);
+    numTables << 5 << 10 << 15 << 20;
+  }
+  else 
+  {
+    numProjs.set_size(1);
+    numProjs[0] = CLI::GetParam<int>("num_projections");
+    numTables.set_size(1);
+    numTables[0] = CLI::GetParam<int>("num_tables");
+  }
+
+  // Compute the 'width' parameter for LSH
+
+  // Find the average pairwise distance of 25 random pairs
+  double avgDist = 0;
+  for (size_t i = 0; i < 25; i++)
+  {
+    size_t p1 = (size_t) math::RandInt(referenceData.n_cols);
+    size_t p2 = (size_t) math::RandInt(referenceData.n_cols);
+    avgDist += metric::EuclideanDistance::Evaluate(referenceData.unsafe_col(p1),
+                                                   referenceData.unsafe_col(p2));
+  }
+
+  avgDist /= 25;
+
+  Log::Info << "Hash width chosen as: " << avgDist << endl;
+
+  arma::vec hashWidths;
+  if (tryDiffWidths)
+  {
+    arma::vec eps(3);
+    eps << 0.01 << 0.1 << 1.0;
+    hashWidths = avgDist * eps;
+  }
+  else 
+  {
+    hashWidths.set_size(1);
+    hashWidths[0] = avgDist;
+  }
+
+  arma::vec timesTaken(numProjs.n_elem * numTables.n_elem * hashWidths.n_elem);
+  timesTaken.zeros();
+
+  arma::Mat<size_t> allNeighbors;
+  arma::mat allNeighborDistances;
+  
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+
+  if (CLI::GetParam<string>("query_file") != "")
+  {
+    string queryFile = CLI::GetParam<string>("query_file");
+
+    data::Load(queryFile.c_str(), queryData, true);
+    Log::Info << "Loaded query data from '" << queryFile << "' (" << 
+      queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+    allNeighbors.set_size(k * timesTaken.n_elem, queryData.n_cols);
+    allNeighborDistances.set_size(k * timesTaken.n_elem, queryData.n_cols);
+  }
+  else
+  {
+    allNeighbors.set_size(k * timesTaken.n_elem, referenceData.n_cols);
+    allNeighborDistances.set_size(k * timesTaken.n_elem, referenceData.n_cols);
+  }
+
+  size_t exptInd = 0;
+  arma::mat exptParams(timesTaken.n_elem, 3);
+
+  // Looping through all combinations of the parameters and noting the 
+  // runtime and the results.
+  for (size_t widthInd = 0; widthInd < hashWidths.n_elem; widthInd++)
+  {
+    for (size_t projInd = 0; projInd < numProjs.n_elem; projInd++)
+    {
+      for (size_t tableInd = 0; tableInd < numTables.n_elem; tableInd++)
+      {
+        Log::Info << "LSH with K: " << numProjs[projInd] << ", L: " << 
+          numTables[tableInd] << ", r: " << hashWidths[widthInd] << endl;
+
+        Timer::Start("hash_building");
+
+        LSHSearch<>* allkann;
+
+        if (CLI::GetParam<string>("query_file") != "")
+          allkann = new LSHSearch<>(referenceData, queryData, numProjs[projInd],
+                                    numTables[tableInd], hashWidths[widthInd],
+                                    secondHashSize, bucketSize);
+        else
+          allkann = new LSHSearch<>(referenceData, numProjs[projInd],
+                                    numTables[tableInd], hashWidths[widthInd],
+                                    secondHashSize, bucketSize);
+
+        Timer::Stop("hash_building");
+
+        timeval start_tv = Timer::Get("computing_neighbors");
+        double startTime 
+          = (double) start_tv.tv_sec + (double) start_tv.tv_usec / 1.0e6;
+
+        Log::Info << "Computing " << k << " approx. nearest neighbors" << endl;
+        allkann->Search(k, neighbors, distances);
+        Log::Info << "Neighbors computed." << endl;
+
+        timeval stop_tv = Timer::Get("computing_neighbors");
+        double stopTime 
+          = (double) stop_tv.tv_sec + (double) stop_tv.tv_usec / 1.0e6;
+
+        exptParams(exptInd, 0) = (double) numProjs[projInd];
+        exptParams(exptInd, 1) = (double) numTables[tableInd];
+        exptParams(exptInd, 2) = hashWidths[widthInd];
+        timesTaken[exptInd] = stopTime - startTime;
+
+        // add results to big matrix
+        allNeighbors.rows(exptInd * k, (exptInd + 1) * k - 1) = neighbors;
+        allNeighborDistances.rows(exptInd * k, (exptInd + 1) * k - 1) 
+          = distances;
+
+        exptInd++;
+
+        neighbors.reset();
+        distances.reset();
+
+        delete allkann;
+      } // diff. L
+    } // diff. K
+  } // diff. 'width'
+
+
+  // Computing the errors
+  string rankFile = CLI::GetParam<string>("rank_file");
+  Log::Warn << "Computing error..." << endl;
+
+  contrib_utils::LineReader lr(rankFile);
+
+  arma::mat allANNErrors(timesTaken.n_elem, 8);
+  // 0 - K
+  // 1 - L
+  // 2 - width
+  // 3 - Time taken
+  // 4 - Mean Rank/Recall
+  // 5 - Median Rank/Recall
+  // 6 - StdDev Rank/Recall
+  // 7 - MaxRank / MinRecall
+
+
+  // If k == 1, compute the rank and distance errors
+  if (k == 1) 
+  {
+    string distFile = CLI::GetParam<string>("nn_dist_file");
+    arma::mat nnDistsMat;
+
+    if (!data::Load(distFile, nnDistsMat))
+      Log::Fatal << "Dist file " << distFile << " cannot be loaded." << endl;
+
+    arma::vec nnDists(nnDistsMat.row(0).t());
+
+    allANNErrors.resize(timesTaken.n_elem, 12);
+    // 8 - Mean DE
+    // 9 - Median DE
+    // 10 - StdDev DE
+    // 11 - Max DE
+
+    arma::mat ranks(timesTaken.n_elem, allNeighbors.n_cols);
+    arma::mat des(timesTaken.n_elem, allNeighbors.n_cols);
+
+    for (size_t i = 0; i < allNeighbors.n_cols; i++)
+    {
+      arma::Col<size_t> true_ranks(referenceData.n_cols);
+      lr.ReadLine(&true_ranks);
+
+      for (size_t j = 0; j < timesTaken.n_elem; j++) 
+      {
+        if (allNeighbors(j, i) < referenceData.n_cols)
+        {
+          ranks(j, i) = (double) true_ranks[allNeighbors(j, i)];
+          des(j, i) = (allNeighborDistances(j, i) - nnDists[i]) / nnDists[i];
+        }
+        // not sure what to do in terms of distance error 
+        // in case no result is returned
+        else 
+        {
+          ranks(j, i) = (double) referenceData.n_cols;
+          des(j, i) = 1000;
+        }
+      }
+    }
+
+    // Saving the LSH parameters and the query times
+    allANNErrors.cols(0, 2) = exptParams;
+    allANNErrors.col(3) = timesTaken;
+
+    // Saving the rank and distance-error statistics over all queries.
+    allANNErrors.col(4) = arma::mean(ranks, 1);
+    allANNErrors.col(5) = arma::median(ranks, 1);
+    allANNErrors.col(6) = arma::stddev(ranks, 1, 1);
+    allANNErrors.col(7) = arma::max(ranks, 1);
+
+    allANNErrors.col(8) = arma::mean(des, 1);
+    allANNErrors.col(9) = arma::median(des, 1);
+    allANNErrors.col(10) = arma::stddev(des, 1, 1);
+    allANNErrors.col(11) = arma::max(des, 1);
+
+  } // if k == 1
+  // if k > 1, compute the recall of the k-nearest-neighbors.
+  else
+  {
+    arma::mat recalls(timesTaken.n_elem, allNeighbors.n_cols);
+    recalls.zeros();
+
+    for (size_t i = 0; i < allNeighbors.n_cols; i++)
+    {
+      arma::Col<size_t> true_ranks(referenceData.n_cols);
+      lr.ReadLine(&true_ranks);
+
+      for (size_t j = 0; j < timesTaken.n_elem; j++)
+      {
+        for (size_t ind = 0; ind < k; ind++)
+        {
+          if (allNeighbors(j * k + ind, i) < referenceData.n_cols)
+          {
+            if (true_ranks[allNeighbors(j * k + ind, i)] <= k)
+              recalls(j, i)++;
+          }
+        }
+      }
+    }
+
+    recalls /= k;
+    
+    // Saving the LSH parameters and the query times
+    allANNErrors.cols(0, 2) = exptParams;
+    allANNErrors.col(3) = timesTaken;
+
+    // Saving the k-nearest-neighbor recall statistics over all queries.
+    allANNErrors.col(4) = arma::mean(recalls, 1);
+    allANNErrors.col(5) = arma::median(recalls, 1);
+    allANNErrors.col(6) = arma::stddev(recalls, 1, 1);
+    allANNErrors.col(7) = arma::min(recalls, 1);
+    
+  } // if k > 1, compute recall of k-NN
+
+  Log::Warn << allANNErrors;
+
+
+  // Saving the output in a file
+  string annErrorOutputFile = CLI::GetParam<string>("error_time_file");
+
+  if (annErrorOutputFile != "")
+  {    
+    allANNErrors = allANNErrors.t();
+    data::Save(annErrorOutputFile, allANNErrors);
+  }    
+  else 
+  {
+    Log::Warn << "Params: " << endl << exptParams.t() << "Times Taken: " <<
+      endl << timesTaken.t();
+  }
+
+  return 0;
+}

Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp	2012-12-20 23:29:46 UTC (rev 14031)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp	2012-12-20 23:45:12 UTC (rev 14032)
@@ -2,8 +2,8 @@
  * @file lsh_main.cpp
  * @author Parikshit Ram
  *
- * Implementation of LSH with a 2-stable distribution of 
- * nearest neighbor search in Euclidean space. 
+ * This file computes the approximate nearest-neighbors using 2-stable 
+ * Locality-sensitive Hashing. 
  */
 #include <time.h>
 
@@ -15,15 +15,14 @@
 #include <iostream>
 
 #include "lsh_search.hpp"
-#include "../utils/utils.hpp"
 
 using namespace std;
 using namespace mlpack;
 using namespace mlpack::neighbor;
 
 // Information about the program itself.
-PROGRAM_INFO("All K-distance-Approximate-Nearest-Neighbors with LSH",
-    "This program will calculate the k distance-approximate-nearest-neighbors "
+PROGRAM_INFO("All K-Approximate-Nearest-Neighbor Search with LSH",
+    "This program will calculate the k approximate-nearest-neighbors "
     "of a set of points. You may specify a separate set of reference "
     "points and query points, or just a reference set which will be "
     "used as both the reference and query set. "
@@ -33,8 +32,7 @@
     "and store the distances in 'distances.csv' and the neighbors in the "
     "file 'neighbors.csv':"
     "\n\n"
-    "$ allkdann --k=5 --reference_file=input.csv --distances_file=distances.csv\n"
-    "  --neighbors_file=neighbors.csv"
+    "$ ./lsh/lsh -k 5 -r input.csv -d distances.csv -n neighbors.csv "
     "\n\n"
     "The output files are organized such that row i and column j in the "
     "neighbors output file corresponds to the index of the point in the "
@@ -46,25 +44,17 @@
 PARAM_STRING_REQ("reference_file", "File containing the reference dataset.", "r");
 PARAM_STRING("distances_file", "File to output distances into.", "d", "");
 PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
-
 PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
-
 PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
 
-PARAM_INT("num_projections", "The number of hash functions for each table", "K", 10);
+PARAM_INT("num_projections", "The number of hash functions for each table", 
+          "K", 10);
 PARAM_INT("num_tables", "The number of hash tables to be used.", "L", 30);
-PARAM_INT("second_hash_size", "The size of the second level hash table.", "M", 8807);
-PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", "B", 500);
+PARAM_INT("second_hash_size", "The size of the second level hash table.", 
+          "M", 99901);
+PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", 
+          "B", 500);
 
-PARAM_FLAG("try_diff_params", "The flag to trigger the search with "
-           "different 'K', 'L' and 'r'.", "P");
-PARAM_FLAG("try_diff_widths", "The flag to trigger the search with "
-           "different hash widths.", "W");
-
-PARAM_STRING("rank_file", "The file containing the true ranks.", "E", "");
-PARAM_STRING("de_file", "The file containing the true distance errors.", "D", "");
-PARAM_STRING("ann_error_file", "File to output the RANN errors to.", "F", "");
-
 int main(int argc, char *argv[])
 {
   // Give CLI the command line parameters the user passed in.
@@ -80,9 +70,6 @@
   size_t secondHashSize = CLI::GetParam<int>("second_hash_size");
   size_t bucketSize = CLI::GetParam<int>("bucket_size");
 
-  bool tryDiffParams = CLI::HasParam("try_diff_params");
-  bool tryDiffWidths = CLI::HasParam("try_diff_widths");
-
   arma::mat referenceData;
   arma::mat queryData; // So it doesn't go out of scope.
   data::Load(referenceFile.c_str(), referenceData, true);
@@ -101,58 +88,27 @@
 
 
   // Pick up the 'K' and the 'L' parameter for LSH
-  arma::Col<size_t> numProjs, numTables;
-  if (tryDiffParams) 
-  {
-    numProjs.set_size(3);
-    numProjs << 10 << 25 << 50;
-    numTables.set_size(5);
-    numTables << 5 << 10 << 25 << 50 << 100;
-  }
-  else 
-  {
-    numProjs.set_size(1);
-    numProjs[0] = CLI::GetParam<int>("num_projections");
-    numTables.set_size(1);
-    numTables[0] = CLI::GetParam<int>("num_tables");
-  }
-
+  size_t numProj = CLI::GetParam<int>("num_projections");
+  size_t numTables = CLI::GetParam<int>("num_tables");
+  
   // Compute the 'width' parameter from LSH
 
-  // Find the average pairwise distance of 25 random pairs
-  double avgDist = 0;
+  // Find the average pairwise distance of 25 random pairs and use that 
+  // as the hash-width
+  double hashWidth = 0;
   for (size_t i = 0; i < 25; i++)
   {
-    size_t p1 = (size_t) math::RandInt(referenceData.n_cols),
-      p2 = (size_t) math::RandInt(referenceData.n_cols);
+    size_t p1 = (size_t) math::RandInt(referenceData.n_cols);
+    size_t p2 = (size_t) math::RandInt(referenceData.n_cols);
 
-    avgDist += metric::EuclideanDistance::Evaluate(referenceData.unsafe_col(p1),
-                                                   referenceData.unsafe_col(p2));
+    hashWidth += metric::EuclideanDistance::Evaluate(referenceData.unsafe_col(p1),
+                                                     referenceData.unsafe_col(p2));
   }
 
-  avgDist /= 25;
+  hashWidth /= 25;
 
-  Log::Info << "Hash width chosen as: " << avgDist << endl;
+  Log::Info << "Hash width chosen as: " << hashWidth << endl;
 
-  arma::vec hashWidths;
-  if (tryDiffWidths)
-  {
-    arma::vec eps(5);
-    eps << 0.001 << 0.01 << 0.1 << 1.0 << 10.0;
-    hashWidths = avgDist * eps;
-
-  }
-  else 
-  {
-    hashWidths.set_size(1);
-    hashWidths[0] = avgDist;
-  }
-
-  arma::vec timesTaken(numProjs.n_elem * numTables.n_elem * hashWidths.n_elem);
-  timesTaken.zeros();
-
-  arma::Mat<size_t> allNeighbors;
-  
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
@@ -163,73 +119,29 @@
     data::Load(queryFile.c_str(), queryData, true);
     Log::Info << "Loaded query data from '" << queryFile << "' ("
               << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
-
-    allNeighbors.set_size(k * timesTaken.n_elem, queryData.n_cols);
   }
-  else
-    allNeighbors.set_size(k * timesTaken.n_elem, referenceData.n_cols);
 
-  size_t exptInd = 0;
-  arma::mat exptParams(timesTaken.n_elem, 3);
+  Log::Info << "LSH with " << numProj << " projections(K) and " << numTables << 
+    " tables(L) with hash width(r): " << hashWidth << endl;
 
-  for (size_t widthInd = 0; widthInd < hashWidths.n_elem; widthInd++)
-  {
-    for (size_t projInd = 0; projInd < numProjs.n_elem; projInd++)
-    {
-      for (size_t tableInd = 0; tableInd < numTables.n_elem; tableInd++)
-      {
-        Log::Info << "LSH with K: " << numProjs[projInd] << ", L: " 
-                  << numTables[tableInd] << ", r: " << hashWidths[widthInd] << endl;
+  Timer::Start("hash_building");
 
-        Timer::Start("hash_building");
+  LSHSearch<>* allkann;
 
-        LSHSearch<>* allkdann;
+  if (CLI::GetParam<string>("query_file") != "")
+    allkann = new LSHSearch<>(referenceData, queryData, numProj, numTables,
+                              hashWidth, secondHashSize, bucketSize);
+  else
+    allkann = new LSHSearch<>(referenceData, numProj, numTables, hashWidth,
+                              secondHashSize, bucketSize);
 
-        if (CLI::GetParam<string>("query_file") != "")
-          allkdann = new LSHSearch<>(referenceData, queryData, numProjs[projInd],
-                                     numTables[tableInd], hashWidths[widthInd],
-                                     secondHashSize, bucketSize);
-        else
-          allkdann = new LSHSearch<>(referenceData, numProjs[projInd],
-                                     numTables[tableInd], hashWidths[widthInd],
-                                     secondHashSize, bucketSize);
+  Timer::Stop("hash_building");
 
-        Timer::Stop("hash_building");
+  Log::Info << "Computing " << k << " distance approx. nearest neighbors " << endl;
+  allkann->Search(k, neighbors, distances);
 
-        timeval start_tv = Timer::Get("computing_neighbors");
-        double startTime = (double) start_tv.tv_sec + (double) start_tv.tv_usec / 1.0e6;
+  Log::Info << "Neighbors computed." << endl;
 
-
-        Log::Info << "Computing " << k << " distance approx. nearest neighbors " << endl;
-        allkdann->Search(k, neighbors, distances);
-
-        Log::Info << "Neighbors computed." << endl;
-
-        timeval stop_tv = Timer::Get("computing_neighbors");
-        double stopTime = (double) stop_tv.tv_sec + (double) stop_tv.tv_usec / 1.0e6;
-
-
-        exptParams(exptInd, 0) = (double) numProjs[projInd];
-        exptParams(exptInd, 1) = (double) numTables[tableInd];
-        exptParams(exptInd, 2) = hashWidths[widthInd];
-        timesTaken[exptInd] = stopTime - startTime;
-
-        // add results to big matrix
-        allNeighbors.rows(exptInd * k, (exptInd + 1) * k - 1) = neighbors;
-
-        exptInd++;
-
-        neighbors.reset();
-        distances.reset();
-
-        delete allkdann;
-
-      } // diff. L
-    } // diff. K
-  } // diff. 'width'
-
-
-  // TO FIX: Have to fix this since these matrices are getting reset.
   // Save output.
   if (distancesFile != "") 
     data::Save(distancesFile, distances);
@@ -237,138 +149,7 @@
   if (neighborsFile != "")
     data::Save(neighborsFile, neighbors);
 
+  delete allkann;
 
-  // Compute the error if the error file is provided
-  string rankFile = CLI::GetParam<string>("rank_file");
-  if (rankFile != "")
-  {
-    Log::Warn << "Computing error..." << endl;
-
-    contrib_utils::LineReader lr(rankFile);
-
-    arma::mat allDANNErrors(timesTaken.n_elem, 8);
-    // 0 - K
-    // 1 - L
-    // 2 - width
-    // 3 - Time taken
-    // 4 - Mean Rank/Recall
-    // 5 - Median Rank/Recall
-    // 6 - StdDev Rank/Recall
-    // 7 - MaxRank / MinRecall
-
-    if (k == 1) 
-    {
-      string deFile = CLI::GetParam<string>("de_file");
-
-      contrib_utils::LineReader *de_lr = NULL;
-
-      if (deFile != "")
-      {
-        de_lr = new contrib_utils::LineReader(deFile);
-        allDANNErrors.resize(timesTaken.n_elem, 12);
-        // 8 - Mean DE
-        // 9 - Median DE
-        // 10 - StdDev DE
-        // 11 - Max DE
-      }
-
-      arma::mat ranks(timesTaken.n_elem, allNeighbors.n_cols);
-      arma::mat des(timesTaken.n_elem, allNeighbors.n_cols);
-
-      for (size_t i = 0; i < allNeighbors.n_cols; i++)
-      {
-        arma::Col<size_t> true_ranks(referenceData.n_cols);
-        lr.ReadLine(&true_ranks);
-
-        arma::vec true_des(referenceData.n_cols);
-        if (de_lr != NULL)
-          de_lr->ReadLine(&true_des);
-
-        for (size_t j = 0; j < timesTaken.n_elem; j++) 
-        {
-          if (allNeighbors(j, i) < referenceData.n_cols)
-          {
-            ranks(j, i) = (double) true_ranks[allNeighbors(j, i)];
-
-            if (de_lr != NULL)
-              des(j, i) = true_des[allNeighbors(j, i)];
-          }
-          else 
-          {
-            ranks(j, i) = (double) referenceData.n_cols;
-
-            if (de_lr != NULL)
-              des(j, i) = arma::max(true_des);
-          }
-        }
-      }
-
-      allDANNErrors.cols(0, 2) = exptParams;
-      allDANNErrors.col(3) = timesTaken;
-      allDANNErrors.col(4) = arma::mean(ranks, 1);
-      allDANNErrors.col(5) = arma::median(ranks, 1);
-      allDANNErrors.col(6) = arma::stddev(ranks, 1, 1);
-      allDANNErrors.col(7) = arma::max(ranks, 1);
-
-      if (de_lr != NULL) 
-      {
-        allDANNErrors.col(8) = arma::mean(des, 1);
-        allDANNErrors.col(9) = arma::median(des, 1);
-        allDANNErrors.col(10) = arma::stddev(des, 1, 1);
-        allDANNErrors.col(11) = arma::max(des, 1);
-
-        delete de_lr;
-      }
-    } // if k == 1, compute rank error and distance error
-    else
-    {
-      arma::mat recalls(timesTaken.n_elem, allNeighbors.n_cols);
-      recalls.zeros();
-
-      for (size_t i = 0; i < allNeighbors.n_cols; i++)
-      {
-        arma::Col<size_t> true_ranks(referenceData.n_cols);
-        lr.ReadLine(&true_ranks);
-
-        for (size_t j = 0; j < timesTaken.n_elem; j++)
-        {
-          for (size_t ind = 0; ind < k; ind++)
-            if (allNeighbors(j * k + ind, i) < referenceData.n_cols)
-            {
-              if (true_ranks[allNeighbors(j * k + ind, i)] <= k)
-                recalls(j, i)++;
-            }
-          
-        }
-      }
-
-      recalls /= k;
-    
-      allDANNErrors.cols(0, 2) = exptParams;
-      allDANNErrors.col(3) = timesTaken;
-      allDANNErrors.col(4) = arma::mean(recalls, 1);
-      allDANNErrors.col(5) = arma::median(recalls, 1);
-      allDANNErrors.col(6) = arma::stddev(recalls, 1, 1);
-      allDANNErrors.col(7) = arma::min(recalls, 1);
-    
-    } // if k > 1, compute recall of k-NN
-
-    Log::Warn << allDANNErrors;
-
-    string annErrorOutputFile = CLI::GetParam<string>("ann_error_file");
-
-    if (annErrorOutputFile != "")
-    {    
-      allDANNErrors = allDANNErrors.t();
-      data::Save(annErrorOutputFile, allDANNErrors);
-    }    
-  }
-  else 
-  {
-    Log::Warn << "Params: " << endl << exptParams.t()
-              << "Times Taken: "  << endl << timesTaken.t();
-
-  }
-
   return 0;
 }

Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp	2012-12-20 23:29:46 UTC (rev 14031)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp	2012-12-20 23:45:12 UTC (rev 14032)
@@ -2,10 +2,24 @@
  * @file lsh_search.hpp
  * @author Parikshit Ram
  *
- * Defines the LSHSearch class, which performs an abstract
- * distance-approximate nearest neighbor query on two datasets
- * using Locality-sensitive hashing with 2-stable distributions
+ * Defines the LSHSearch class, which performs an approximate
+ * nearest neighbor search for the queries in a query set
+ * over a given dataset using Locality-sensitive hashing 
+ * with 2-stable distributions.
+ *
+ * The details of this method can be found in the following paper:
+ *
+ * @inproceedings{datar2004locality,
+ *  title={Locality-sensitive hashing scheme based on p-stable distributions},
+ *  author={Datar, M. and Immorlica, N. and Indyk, P. and Mirrokni, V.S.},
+ *  booktitle={Proceedings of the 12th Annual Symposium on Computational Geometry},
+ *  pages={253--262},
+ *  year={2004},
+ *  organization={ACM}
+ * }
+ *
  */
+
 #ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
 #define __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
 
@@ -22,57 +36,39 @@
                     * searches. */ {
 
 /**
- * The RASearch class is a template class for performing distance-based
- * neighbor searches.  It takes a query dataset and a reference dataset (or just
- * a reference dataset) and, for each point in the query dataset, finds the k
- * neighbors in the reference dataset which have the 'best' distance according
- * to a given sorting policy.  A constructor is given which takes only a
- * reference dataset, and if that constructor is used, the given reference
- * dataset is also used as the query dataset.
- *
- * The template parameters SortPolicy and Metric define the sort function used
- * and the metric (distance function) used.  More information on those classes
- * can be found in the NearestNeighborSort class and the kernel::ExampleKernel
- * class.
- *
+ * The LSHSearch class -- TBD
+ * 
  * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
  * @tparam MetricType The metric to use for computation.
- * @tparam TreeType The tree type to use.
  */
 template<typename SortPolicy = NearestNeighborSort,
-         typename MetricType = mlpack::metric::SquaredEuclideanDistance,
-         typename eT = double>
+         typename MetricType = mlpack::metric::SquaredEuclideanDistance>
 
 class LSHSearch
 {
  public:
 
-  typedef arma::Mat<eT> MatType;
-  typedef arma::Col<eT> ColType;
-  typedef arma::Row<eT> RowType;
-
   /**
-   * Initialize the RASearch object, passing both a query and reference
-   * dataset.  Optionally, perform the computation in naive mode or single-tree
-   * mode, and set the leaf size used for tree-building.  An initialized
-   * distance metric can be given, for cases where the metric has internal data
-   * (i.e. the distance::MahalanobisDistance class).
+   * Initialize -- TBD
    *
-   * This method will copy the matrices to internal copies, which are rearranged
-   * during tree-building.  You can avoid this extra copy by pre-constructing
-   * the trees and passing them using a diferent constructor.
-   *
    * @param referenceSet Set of reference points.
    * @param querySet Set of query points.
-   * @param naive If true, O(n^2) naive search will be used (as opposed to
-   *      dual-tree search).  This overrides singleMode (if it is set to true).
-   * @param singleMode If true, single-tree search will be used (as opposed to
-   *      dual-tree search).
-   * @param leafSize Leaf size for tree construction (ignored if tree is given).
+   * @param numProj Number of projections in each hash table (anything between 
+   *     10-50 might be a decent choice).
+   * @param numTables Total number of hash tables (anything between 10-20
+   *     should suffice).
+   * @param hashWidth The width of hash for every table (currently automatically 
+   *     chosen from the main function). This should be a reasonable upper bound 
+   *     on the nearest-neighbor distance in general.
+   * @param secondHashSize The size of the second hash table. This should be a 
+   *     large prime number.
+   * @param bucketSize The size of the bucket in the second hash table. This is 
+   *     the maximum number of points that can be hashed into a single bucket.
+   *     Default values are already provided here.
    * @param metric An optional instance of the MetricType class.
    */
-  LSHSearch(const MatType& referenceSet,
-            const MatType& querySet,
+  LSHSearch(const arma::mat& referenceSet,
+            const arma::mat& querySet,
             const size_t numProj,
             const size_t numTables,
             const double hashWidth,
@@ -81,27 +77,24 @@
             const MetricType metric = MetricType());
 
   /**
-   * Initialize the RASearch object, passing only one dataset, which is
-   * used as both the query and the reference dataset.  Optionally, perform the
-   * computation in naive mode or single-tree mode, and set the leaf size used
-   * for tree-building.  An initialized distance metric can be given, for cases
-   * where the metric has internal data (i.e. the distance::MahalanobisDistance
-   * class).
+   * Initialize -- TBD
    *
-   * If naive mode is being used and a pre-built tree is given, it may not work:
-   * naive mode operates by building a one-node tree (the root node holds all
-   * the points).  If that condition is not satisfied with the pre-built tree,
-   * then naive mode will not work.
-   *
-   * @param referenceSet Set of reference points.
-   * @param naive If true, O(n^2) naive search will be used (as opposed to
-   *      dual-tree search).  This overrides singleMode (if it is set to true).
-   * @param singleMode If true, single-tree search will be used (as opposed to
-   *      dual-tree search).
-   * @param leafSize Leaf size for tree construction (ignored if tree is given).
+   * @param referenceSet Set of reference points and the set of queries.
+   * @param numProj Number of projections in each hash table (anything between 
+   *     10-50 might be a decent choice).
+   * @param numTables Total number of hash tables (anything between 10-20
+   *     should suffice).
+   * @param hashWidth The width of hash for every table (currently automatically 
+   *     chosen from the main function). This should be a reasonable upper bound 
+   *     on the nearest-neighbor distance in general.
+   * @param secondHashSize The size of the second hash table. This should be a 
+   *     large prime number.
+   * @param bucketSize The size of the bucket in the second hash table. This is 
+   *     the maximum number of points that can be hashed into a single bucket.
+   *     Default values are already provided here.
    * @param metric An optional instance of the MetricType class.
    */
-  LSHSearch(const MatType& referenceSet,
+  LSHSearch(const arma::mat& referenceSet,
             const size_t numProj,
             const size_t numTables,
             const double hashWidth,
@@ -109,7 +102,7 @@
             const size_t bucketSize = 500,
             const MetricType metric = MetricType());
   /**
-   * Delete the RASearch object. The tree is the only member we are
+   * Delete the LSHSearch object. The tree is the only member we are
    * responsible for deleting.  The others will take care of themselves.
    */
   ~LSHSearch();
@@ -127,47 +120,86 @@
    *     point.
    */
   void Search(const size_t k,
-              arma::Mat<size_t>& resultingNeighbors,
+              arma::Mat<size_t>& resultingNeighbors, 
               arma::mat& distances);
 
  private:
 
-  void BuildFirstLevelHash(MatType* allKeyPointMat);
+  /**
+   * This function builds a hash table with two levels of hashing 
+   * as presented in the paper. This function first hashes the points
+   * with 'numProj' random projections to a single hash table creating
+   * (key, point ID) pairs where the key is a 'numProj'-dimensional 
+   * integer vector.
+   * 
+   * Then each key in this hash table is hashed into a second hash table
+   * using a standard hash. 
+   *
+   * This function does not have any parameters and relies on parameters 
+   * which are private members of this class, initialized during the
+   * class initialization.
+   */
+  void BuildHash();
 
-  void BuildSecondLevelHash(MatType& allKeyPointMat);
 
-  inline void BaseCase(const size_t queryIndex, 
-                       const size_t referenceIndex);
-
-  void InsertNeighbor(const size_t queryIndex,
-                      const size_t pos,
-                      const size_t neighbor,
-                      const double distance);
-
+  /**
+   * This function takes a query and hashes it into each of the hash tables 
+   * to get keys for the query and then the key is hashed to a bucket of the 
+   * second hash table and all the points (if any) in those buckets 
+   * are collected as the potential neighbor candidates.
+   *
+   * @param queryIndex The index of the query currently being processed.
+   * @param referenceIndices The list of neighbor candidates obtained from 
+   *    hashing the query into all the hash tables and eventually into 
+   *    multiple buckets of the second hash table.
+   */
   void ReturnIndicesFromTable(const size_t queryIndex,
                               arma::uvec& referenceIndices);
+  /**
+   * This is a helper function that computes the distance of the query to the 
+   * neighbor candidates and appropriately stores the best 'k' candidates
+   *
+   * @param queryIndex The index of the query in question
+   * @param referenceIndex The index of the neighbor candidate in question
+   */
+  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
 
+  /**
+   * This is a helper function that efficiently inserts better neighbor 
+   * candidates into an existing set of neighbor candidates. This function 
+   * is only called by the 'BaseCase' function. 
+   *
+   * @param queryIndex This is the index of the query being processed currently
+   * @param pos The position of the neighbor candidate in the current list of 
+   *    neighbor candidates.
+   * @param neighbor The neighbor candidate that is being inserted into the list 
+   *    of the best 'k' candidates for the query in question.
+   * @param distance The distance of the query to the neighbor candidate. 
+   */
+  void InsertNeighbor(const size_t queryIndex, const size_t pos,
+                      const size_t neighbor, const double distance);
 
  private:
   //! Reference dataset.
   const arma::mat& referenceSet;
+
   //! Query dataset (may not be given).
   const arma::mat& querySet;
 
-  //! Instantiation of kernel.
+  //! Instantiation of the metric.
   MetricType metric;
 
   //! The number of projections
   const size_t numProj;
 
-  //! The number of tables
+  //! The number of hash tables
   const size_t numTables;
 
   //! The std::vector containing the projection matrix of each table
-  std::vector<MatType> projections; // should be [numProj x dims] x numTables
+  std::vector<arma::mat> projections; // should be [numProj x dims] x numTables
 
   //! The list of the offset 'b' for each of the projection for each table
-  MatType offsets; // should be numProj x numTables
+  arma::mat offsets; // should be numProj x numTables
 
   //! The hash width
   const double hashWidth;
@@ -176,7 +208,7 @@
   const size_t secondHashSize;
 
   //! The weights of the second hash
-  ColType secondHashWeights;
+  arma::vec secondHashWeights;
 
   //! The bucket size of the second hash
   const size_t bucketSize;
@@ -206,7 +238,4 @@
 // Include implementation.
 #include "lsh_search_impl.hpp"
 
-// Include convenience typedefs.
-//#include "lsh_typedef.hpp"
-
 #endif

Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp	2012-12-20 23:29:46 UTC (rev 14031)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp	2012-12-20 23:45:12 UTC (rev 14032)
@@ -2,7 +2,7 @@
  * @file lsh_search_impl.hpp
  * @author Parikshit Ram
  *
- * Implementation of LSHSearch class.
+ * Implementation of the LSHSearch class.
  */
 #ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_IMPL_HPP
 #define __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_IMPL_HPP
@@ -14,10 +14,10 @@
 using namespace mlpack::neighbor;
 
 // Construct the object.
-template<typename SortPolicy, typename MetricType, typename eT>
-LSHSearch<SortPolicy, MetricType, eT>::
-LSHSearch(const MatType& referenceSet,
-          const MatType& querySet,
+template<typename SortPolicy, typename MetricType>
+LSHSearch<SortPolicy, MetricType>::
+LSHSearch(const arma::mat& referenceSet,
+          const arma::mat& querySet,
           const size_t numProj,
           const size_t numTables,
           const double hashWidth,
@@ -33,17 +33,12 @@
   bucketSize(bucketSize),
   metric(metric)
 {
-  // Get a (N^K key, point index) pair for all tables and all points
-  MatType allKeysPointsMat(numProj + 1, referenceSet.n_cols * numTables);
-  BuildFirstLevelHash(&allKeysPointsMat);
-
-  // Condense the (N^K key, point index) pairs into a single table
-  BuildSecondLevelHash(allKeysPointsMat);
+  BuildHash();
 }
 
-template<typename SortPolicy, typename MetricType, typename eT>
-LSHSearch<SortPolicy, MetricType, eT>::
-LSHSearch(const MatType& referenceSet,
+template<typename SortPolicy, typename MetricType>
+LSHSearch<SortPolicy, MetricType>::
+LSHSearch(const arma::mat& referenceSet,
           const size_t numProj,
           const size_t numTables,
           const double hashWidth,
@@ -59,23 +54,18 @@
   bucketSize(bucketSize),
   metric(metric)
 {
-  // Get a (N^K key, point index) pair for all tables and all points
-  MatType allKeysPointsMat(numProj + 1, referenceSet.n_cols * numTables);
-  BuildFirstLevelHash(&allKeysPointsMat);
-
-  // Condense the (N^K key, point index) pairs into a single table
-  BuildSecondLevelHash(allKeysPointsMat);
+  BuildHash();
 }
 
 
-template<typename SortPolicy, typename MetricType, typename eT>
-LSHSearch<SortPolicy, MetricType, eT>::
+template<typename SortPolicy, typename MetricType>
+LSHSearch<SortPolicy, MetricType>::
 ~LSHSearch()
 { }
 
 
-template<typename SortPolicy, typename MetricType, typename eT>
-void LSHSearch<SortPolicy, MetricType, eT>::
+template<typename SortPolicy, typename MetricType>
+void LSHSearch<SortPolicy, MetricType>::
 InsertNeighbor(const size_t queryIndex,
                const size_t pos,
                const size_t neighbor,
@@ -100,18 +90,18 @@
 
 
 
-template<typename SortPolicy, typename MetricType, typename eT>
-inline void LSHSearch<SortPolicy, MetricType, eT>::
-BaseCase(const size_t queryIndex,
-         const size_t referenceIndex)
+template<typename SortPolicy, typename MetricType>
+inline //force_inline 
+double LSHSearch<SortPolicy, MetricType>::
+BaseCase(const size_t queryIndex, const size_t referenceIndex)
 {
   // If the datasets are the same, then this search is only using one dataset
   // and we should not return identical points.
   if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
-    return;
+    return 0.0;
 
-  double distance = metric.Evaluate(querySet.col(queryIndex),
-                                    referenceSet.col(referenceIndex));
+  double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
+                                    referenceSet.unsafe_col(referenceIndex));
 
   // If this distance is better than any of the current candidates, the
   // SortDistance() function will give us the position to insert it into.
@@ -122,33 +112,39 @@
   if (insertPosition != (size_t() - 1))
     InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
 
-  return;
+  return distance;
 }
 
 
-template<typename SortPolicy, typename MetricType, typename eT>
-void LSHSearch<SortPolicy, MetricType, eT>::
+template<typename SortPolicy, typename MetricType>
+void LSHSearch<SortPolicy, MetricType>::
 ReturnIndicesFromTable(const size_t queryIndex,
                        arma::uvec& referenceIndices)
 {
+  // Hash the query in each of the 'numTables' hash tables using the 
+  // 'numProj' projections for each table.
+  // This gives us 'numTables' keys for the query where each key
+  // is a 'numProj' dimensional integer vector
+  //
   // compute the projection of the query in each table
-  MatType allProjInTables(numProj, numTables);
-
+  arma::mat allProjInTables(numProj, numTables);
   for (size_t i = 0; i < numTables; i++)
-    allProjInTables.col(i) = projections[i].t() * querySet.col(queryIndex);
-
+    allProjInTables.unsafe_col(i) 
+      = projections[i].t() * querySet.unsafe_col(queryIndex);
   allProjInTables += offsets;
   allProjInTables /= hashWidth;
 
-  // compute the hash value of each projection of the query
-  // in the second hash table
-  RowType hashVec = secondHashWeights.t() * arma::floor(allProjInTables);
+  // compute the hash value of each key of the query into a bucket of the 
+  // 'secondHashTable' using the 'secondHashWeights'. 
+  arma::rowvec hashVec = secondHashWeights.t() * arma::floor(allProjInTables);
 
-  assert(hashVec.n_elem == numTables);
-
   for (size_t i = 0; i < hashVec.n_elem; i++)
     hashVec[i] = (double)((size_t) hashVec[i] % secondHashSize);
 
+  assert(hashVec.n_elem == numTables);
+
+  // For all the buckets that the query is hashed into, sequentially 
+  // collect the indices in those buckets. 
   arma::Col<size_t> refPointsConsidered;
   refPointsConsidered.zeros(referenceSet.n_cols);
 
@@ -158,7 +154,7 @@
 
     if (bucketContentSize[hashInd] > 0)
     {
-      // Pick the indices in that 'hashInd'
+      // Pick the indices in the bucket corresponding to 'hashInd'
       size_t tableRow = bucketRowInHashTable[hashInd];
       assert(tableRow < secondHashSize);
       assert(tableRow < secondHashTable.n_rows);
@@ -169,14 +165,12 @@
   } // for all tables
 
   referenceIndices = arma::find(refPointsConsidered > 0);
-
   return;
 }
 
 
-
-template<typename SortPolicy, typename MetricType, typename eT>
-void LSHSearch<SortPolicy, MetricType, eT>::
+template<typename SortPolicy, typename MetricType>
+void LSHSearch<SortPolicy, MetricType>::
 Search(const size_t k,
        arma::Mat<size_t>& resultingNeighbors,
        arma::mat& distances)
@@ -188,22 +182,26 @@
   neighborPtr->set_size(k, querySet.n_cols);
   distancePtr->set_size(k, querySet.n_cols);
   distancePtr->fill(SortPolicy::WorstDistance());
-  neighborPtr->fill((size_t) -1);
+  neighborPtr->fill(referenceSet.n_cols);
 
-
   size_t avgIndicesReturned = 0;
 
-
   Timer::Start("computing_neighbors");
 
-  // go through every query point
+  // go through every query point sequentially
   for (size_t i = 0; i < querySet.n_cols; i++)
   {
+    // Hash every query into every hash table and eventually
+    // into the 'secondHashTable' to obtain the neighbor candidates
     arma::uvec refIndices;
     ReturnIndicesFromTable(i, refIndices);
 
+    // Just an informative book-keeping for the number of neighbor candidates
+    // returned on average
     avgIndicesReturned += refIndices.n_elem;
 
+    // Sequentially go through all the candidates and save the best 'k'
+    // candidates
     for (size_t j = 0; j < refIndices.n_elem; j++)
       BaseCase(i, (size_t) refIndices[j]);
   }
@@ -211,106 +209,173 @@
   Timer::Stop("computing_neighbors");
 
   avgIndicesReturned /= querySet.n_cols;
-  Log::Info << avgIndicesReturned << " distinct indices returned on average." 
-            << std::endl;
+  Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
+    std::endl;
 
   return;
 }
 
-
-template<typename SortPolicy, typename MetricType, typename eT>
-void LSHSearch<SortPolicy, MetricType, eT>::
-BuildFirstLevelHash(MatType* allKeysPointsMat)
+template<typename SortPolicy, typename MetricType>
+void LSHSearch<SortPolicy, MetricType>::
+BuildHash()
 {
-  // A row with all the indices
-  RowType allIndRow(referenceSet.n_cols);
-  for (size_t i = 0; i < allIndRow.n_elem; i++)
-    allIndRow[i] = i;
+  // The first level hash for a single table outputs a 'numProj'-dimensional
+  // integer key for each point in the set -- (key, pointID)
+  // The key creation details are presented below
+  // 
+  // The second level hash is performed by hashing the key to 
+  // an integer in the range [0, 'secondHashSize').
+  //
+  // This is done by creating a weight vector 'secondHashWeights' of 
+  // length 'numProj' with each entry an integer randomly chosen
+  // between [0, 'secondHashSize'). 
+  //
+  // Then the bucket for any key and its corresponding point is 
+  // given by <key, 'secondHashWeights'> % 'secondHashSize'
+  // and the corresponding point ID is put into that bucket.  
 
-  // Obtain all the projection matrices and the offset matrix
-  offsets.randu(numProj, numTables);
-  offsets *= hashWidth;
+  //////////////////////////////////////////
+  // Step I: Preparing the second level hash
+  ///////////////////////////////////////////
 
-  for(size_t i = 0; i < numTables; i++)
-  {
-    MatType projMat;
-    projMat.randn(referenceSet.n_rows, numProj);
+  // obtain the weights for the second hash
+  secondHashWeights = arma::floor(arma::randu(numProj) 
+                                  * (double) secondHashSize);
 
-    MatType offsetMat = arma::repmat(offsets.col(i), 1, referenceSet.n_cols);
-    MatType hashMat = projMat.t() * referenceSet;
+  // The 'secondHashTable' is initially an empty matrix of size
+  // ('secondHashSize' x 'bucketSize'). Filling the buckets only
+  // as points land in them allows us to shrink the size of the
+  // 'secondHashTable' at the end of the hashing.
 
-    hashMat += offsetMat;
-    hashMat /= hashWidth;
+  // Start filling up the second hash table
+  secondHashTable.set_size(secondHashSize, bucketSize);
 
-    hashMat.resize(hashMat.n_rows + 1, hashMat.n_cols);
-    hashMat.row(hashMat.n_rows - 1) = allIndRow;
+  // Fill the second hash table with n = referenceSet.n_cols.
+  // This is because no point has index 'n' so the presence of 
+  // this in the bucket denotes that there are no more points 
+  // in this bucket. 
+  secondHashTable.fill(referenceSet.n_cols);
 
-    allKeysPointsMat->cols(i * referenceSet.n_cols, (i + 1) * referenceSet.n_cols - 1)
-      = arma::floor(hashMat);
+  // Keeping track of the size of each bucket in the hash.
+  // At the end of hashing most buckets will be empty.
+  bucketContentSize.zeros(secondHashSize);
 
-    projections.push_back(projMat);
-  } // loop over tables
+  // Instead of putting the points in the row corresponding to 
+  // the bucket, we choose the next empty row and keep track of
+  // the row in which the bucket lies. This allows us to 
+  // stack together and slice out the empty buckets at the 
+  // end of the hashing.
+  bucketRowInHashTable.set_size(secondHashSize);
+  bucketRowInHashTable.fill(secondHashSize);
 
-  return;
-}
+  // keeping track of number of non-empty rows in the 'secondHashTable'
+  size_t numRowsInTable = 0;
 
 
-template<typename SortPolicy, typename MetricType, typename eT>
-void LSHSearch<SortPolicy, MetricType, eT>::
-BuildSecondLevelHash(MatType& allKeysPointsMat)
-{
-  // obtain the hash weights for the second hash
-  secondHashWeights = arma::floor(arma::randu(numProj) 
-                                  * (double) secondHashSize);
+  /////////////////////////////////////////////////////////
+  // Step II: The offsets for all projections in all tables
+  /////////////////////////////////////////////////////////
 
-  RowType hashVec = secondHashWeights.t() 
-    * allKeysPointsMat.rows(0, numProj - 1);
+  // Since the 'offsets' are in [0, hashWidth], we obtain the 'offsets'
+  // as randu(numProj, numTables) * hashWidth
+  offsets.randu(numProj, numTables);
+  offsets *= hashWidth;
 
-  for (size_t i = 0; i < hashVec.n_elem; i++)
-    hashVec[i] = (double)((size_t) hashVec[i] % secondHashSize);
+  /////////////////////////////////////////////////////////////////
+  // Step III: Creating each hash table in the first level hash
+  // one by one and putting them directly into the 'secondHashTable'
+  // for memory efficiency.
+  /////////////////////////////////////////////////////////////////
 
-  assert(hashVec.n_elem == referenceSet.n_cols * numTables);
+  for(size_t i = 0; i < numTables; i++)
+  {
+    //////////////////////////////////////////////////////////////
+    // Step IV: Obtaining the 'numProj' projections for each table
+    //////////////////////////////////////////////////////////////
+    //
+    // For L2 metric, 2-stable distributions are used, and 
+    // the normal Z ~ N(0, 1) is a 2-stable distribution.
+    arma::mat projMat;
+    projMat.randn(referenceSet.n_rows, numProj);
 
-  // start filling up the second hash table;
-  secondHashTable.set_size(secondHashSize, bucketSize);
-  secondHashTable.fill(referenceSet.n_cols);
-  bucketContentSize.zeros(secondHashSize);
+    // save the projection matrix for querying
+    projections.push_back(projMat);
 
-  // Initializing to nothing
-  bucketRowInHashTable.set_size(secondHashSize);
-  bucketRowInHashTable.fill(secondHashSize);
+    ///////////////////////////////////////////////////////////////
+    // Step V: create the 'numProj'-dimensional key for each point 
+    // in each table.
+    //////////////////////////////////////////////////////////////
 
-  size_t numRowsInTable = 0;
+    // The following set of lines performs the task of 
+    // hashing each point to a 'numProj'-dimensional integer key. 
+    // Hence you get a ('numProj' x 'referenceSet.n_cols') key matrix
+    //
+    // For a single table, let the 'numProj' projections be denoted 
+    // by 'proj_i' and the corresponding offset be 'offset_i'. 
+    // Then the key of a single point is obtained as:
+    // key = { floor( (<proj_i, point> + offset_i) / 'hashWidth' ) forall i }
+    arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i),
+                                       1, referenceSet.n_cols);
+    arma::mat hashMat = projMat.t() * referenceSet;
+    hashMat += offsetMat;
+    hashMat /= hashWidth;
 
-  for (size_t i = 0; i < hashVec.n_elem; i++)
-  {
-    size_t hashInd = (size_t) hashVec[i];
-    size_t pointInd = (size_t) allKeysPointsMat(numProj, i);
+    ////////////////////////////////////////////////////////////
+    // Step VI: Putting the points in the 'secondHashTable' by 
+    // hashing the key.
+    ///////////////////////////////////////////////////////////
 
-    if (bucketContentSize[hashInd] == 0)
+    // Now we hash every key, point ID to its corresponding bucket
+    arma::rowvec secondHashVec = secondHashWeights.t()
+      * arma::floor(hashMat);
+
+    // This gives us the bucket for the corresponding point ID
+    for (size_t j = 0; j < secondHashVec.n_elem; j++)
+      secondHashVec[j] = (double)((size_t) secondHashVec[j] % secondHashSize);
+
+    assert(secondHashVec.n_elem == referenceSet.n_cols);
+
+    // Inserting the point in the corresponding row to its bucket 
+    // in the 'secondHashTable'.
+    for (size_t j = 0; j < secondHashVec.n_elem; j++)
     {
-      // start a new row for hash
-      bucketRowInHashTable[hashInd] = numRowsInTable;
-      secondHashTable(numRowsInTable, 0) = pointInd;
+      // This is the bucket number
+      size_t hashInd = (size_t) secondHashVec[j];
+      // The point ID is 'j'
 
-      numRowsInTable++;
-    }
-    else 
-    {
-      if (bucketContentSize[hashInd] < bucketSize)
+      // If this is currently an empty bucket, start a new row and
+      // keep track of which row corresponds to the bucket.
+      if (bucketContentSize[hashInd] == 0)
       {
-        // continue with an existing row
-        size_t tableRow = bucketRowInHashTable[hashInd];
-        secondHashTable(tableRow, bucketContentSize[hashInd]) = pointInd;
+        // start a new row for hash
+        bucketRowInHashTable[hashInd] = numRowsInTable;
+        secondHashTable(numRowsInTable, 0) = j;
+
+        numRowsInTable++;
       }
-      // else just ignore as suggested
-    }
+      // If bucket already present in the 'secondHashTable', find 
+      // the corresponding row and insert the point ID in this row 
+      // unless the bucket is full, in which case, do nothing.
+      else 
+      {
+        // if bucket not full, insert point here        
+        if (bucketContentSize[hashInd] < bucketSize)
+          secondHashTable(bucketRowInHashTable[hashInd], 
+                          bucketContentSize[hashInd]) = j;
+        // else just ignore as suggested
+      }
 
-    if (bucketContentSize[hashInd] < bucketSize)
-      bucketContentSize[hashInd]++;
-  }
+      // increment the count of the points in this bucket
+      if (bucketContentSize[hashInd] < bucketSize)
+        bucketContentSize[hashInd]++;
+    } // loop over all points in the reference set
+  } // loop over tables
 
-  // condense the second hash table
+
+  /////////////////////////////////////////////////
+  // Step VII: Condensing the 'secondHashTable'
+  /////////////////////////////////////////////////
+
   size_t maxBucketSize = 0; 
   for (size_t i = 0; i < bucketContentSize.n_elem; i++)
     if (bucketContentSize[i] > maxBucketSize)
@@ -323,4 +388,5 @@
   return;
 }
 
+
 #endif

Added: mlpack/trunk/src/mlpack/methods/lsh/lsh_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_test.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_test.cpp	2012-12-20 23:45:12 UTC (rev 14032)
@@ -0,0 +1,131 @@
+/**
+ * @file lsh_test.cpp
+ *
+ * Unit tests for the 'LSHSearch' class.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include "lsh_search.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::neighbor;
+
+PROGRAM_INFO("LSH test", " ");
+
+
+int main (int argc, char *argv[])
+{
+  CLI::ParseCommandLine(argc, argv);
+  math::RandomSeed(0);
+
+  arma::mat rdata(2, 10);
+  rdata << 3 << 2 << 4 << 3 << 5 << 6 << 0 << 8 << 3 << 1 << arma::endr << 
+    0 << 3 << 4 << 7 << 8 << 4 << 1 << 0 << 4 << 3 << arma::endr;
+
+
+  // Randomness present here -- seed = 0
+  // Computing the hashwidth here.
+  // CORRECT ANSWER: 'hashWidth' = 4.24777
+  double hashWidth = 0;
+  for (size_t i = 0; i < 10; i++)
+  {
+    size_t p1 = (size_t) math::RandInt(rdata.n_cols);
+    size_t p2 = (size_t) math::RandInt(rdata.n_cols);
+
+    hashWidth += metric::EuclideanDistance::Evaluate(rdata.unsafe_col(p1),
+                                                     rdata.unsafe_col(p2));
+  }
+  hashWidth /= 10.0;
+
+  Log::Info << "Hash width: " << hashWidth << endl;
+
+  arma::mat qdata(2, 3);
+  qdata << 3 << 2 << 0 << arma::endr << 5 << 3 << 4 << arma::endr;
+
+  // INPUT TO LSH:
+  // Number of points: 10
+  // Number of dimensions: 2
+  // Number of projections per table: 'numProj' = 3
+  // Number of hash tables: 'numTables' = 2
+  // hashWidth (computed): 'hashWidth' = 4.24777
+  // Second hash size: 'secondHashSize' = 11
+  // Size of the bucket: 'bucketSize' = 3
+
+  // Randomness present in LSH -- seed = 0
+  // Things obtained by random sampling are listed in the sequence
+  // in which they will be obtained in the 'LSHSearch::BuildHash()' private
+  // function in the 'LSHSearch' class.
+  //
+  // 1. The weights of the second hash obtained as:
+  //    secondHashWeights = arma::floor(arma::randu(3) * 11.0);
+  //    COR.SOL.: secondHashWeights = [9, 4, 8];
+  //
+  // 2. The offsets for all the 3 projections in each of the 2 tables:
+  //    offsets.randu(3, 2)
+  //    COR.SOL.: [0.7984 0.3352; 0.9116 0.7682; 0.1976 0.2778]
+  //    offsets *= hashWidth
+  //    COR.SOL.: [3.3916 1.4240; 3.8725 3.2633; 0.8392 1.1799]
+  //
+  // 3. The  (2 x 3) projection matrices for the 2 tables:
+  //    projMat.randn(2, 3)
+  //    COR.SOL.: Proj. Mat 1: [2.7020 0.0187 0.4355; 1.3692 0.6933 0.0416]
+  //    COR.SOL.: Proj. Mat 2: [-0.3961 -0.2666 1.1001; 0.3895 -1.5118 -1.3964]
+  LSHSearch<> *lsh_test = new LSHSearch<>(rdata, qdata, 3,2, hashWidth, 11,3);
+
+  // Given this, the 'LSHSearch::bucketRowInHashTable' should be:
+  // COR.SOL.: [2 11 4 7 6 3 11 0 5 1 8]
+  //
+  // The 'LSHSearch::bucketContentSize' should be:
+  // COR.SOL.: [2 0 1 1 3 1 0 3 3 3 1]
+  // 
+  // The final hash table 'LSHSearch::secondHashTable' should be
+  // of size (9 x 3) with the following content:
+  // COR.SOL.: 
+  // [0 2 4; 1 7 8; 3 9 10; 5 10 10; 6 10 10; 0 5 6; 1 2 8; 3 10 10; 4 10 10]
+
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+
+  lsh_test->Search(2, neighbors, distances);
+
+  // The private function 'LSHSearch::ReturnIndicesFromTable(0, refInds)'
+  // should hash the query 0 into the following buckets:
+  // COR.SOL.: Table 1 Bucket 7, Table 2 Bucket 0, refInds = [0 2 3 4 9]
+  //
+  // The private function 'LSHSearch::ReturnIndicesFromTable(1, refInds)'
+  // should hash the query 1 into the following buckets:
+  // COR.SOL.: Table 1 Bucket 9, Table 2 Bucket 4, refInds = [1 2 7 8]
+  //
+  // The private function 'LSHSearch::ReturnIndicesFromTable(2, refInds)'
+  // should hash the query 2 into the following buckets:
+  // COR.SOL.: Table 1 Bucket 0, Table 2 Bucket 7, refInds = [0 2 3 4 9]
+
+  // After search
+  // COR.SOL.: 'neighbors' = [2 1 9; 3 8 2]
+  // COR.SOL.: 'distances' = [2 0 2; 4 2 16]
+
+//   Log::Info << "Neighbors: " << std::endl << neighbors << std::endl <<
+//     "Distances: " << std::endl << distances << std::endl;
+
+  arma::Mat<size_t> true_neighbors(2, 3);
+  true_neighbors << 2 << 1 << 9 << arma::endr << 3 << 8 << 2 << arma::endr;
+  arma::mat true_distances(2, 3);
+  true_distances << 2 << 0 << 2 << arma::endr << 4 << 2 << 16 << arma::endr;
+
+  for (size_t i = 0; i < 3; i++)
+  {
+    for (size_t j = 0; j < 2; j++)
+    {
+      assert(neighbors(j, i) == true_neighbors(j, i));
+      assert(distances(j, i) == true_distances(j, i));
+    }
+  }
+
+  Log::Warn << "Expected neighbor results obtained!!" << std::endl;
+
+  delete lsh_test;
+
+  return 0;
+}




More information about the mlpack-svn mailing list