[mlpack-svn] r14096 - mlpack/trunk/src/mlpack/methods/lsh

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jan 9 15:45:21 EST 2013


Author: pram
Date: 2013-01-09 15:45:20 -0500 (Wed, 09 Jan 2013)
New Revision: 14096

Removed:
   mlpack/trunk/src/mlpack/methods/lsh/lsh_analysis_main.cpp
Modified:
   mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp
   mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
   mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
Log:
LSH class updated: the hash width computation was moved into the class and removed from the main file. More comments were added to the LSH class. The Search function was made tunable, allowing the user to choose how many of the hash tables to search.

Deleted: mlpack/trunk/src/mlpack/methods/lsh/lsh_analysis_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_analysis_main.cpp	2013-01-09 20:44:07 UTC (rev 14095)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_analysis_main.cpp	2013-01-09 20:45:20 UTC (rev 14096)
@@ -1,367 +0,0 @@
-/**
- * @file lsh_analysis_main.cpp
- * @author Parikshit Ram
- *
- * This main file computes the accuracy-time tradeoff of 
- * the 2-stable LSH class 'LSHSearch'
- */
-#include <time.h>
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-
-#include <string>
-#include <fstream>
-#include <iostream>
-
-#include "lsh_search.hpp"
-#include "../utils/utils.hpp"
-
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::neighbor;
-
-// Information about the program itself.
-PROGRAM_INFO("LSH accuracy-time tradeoff analysis.",
-    "This program will calculate the k approximate-nearest-neighbors "
-    "of a set of queries under different parameter settings."
-    " You may specify a separate set of reference "
-    "points and query points, or just a reference set which will be "
-    "used as both the reference and query set. "
-    "\n\n"
-    "For a given set of queries and references and 'k', this program "
-    "uses different parameters for LSH and computes the error and "
-    "the query time. For the computation of error, it requires the "
-    "files containing the ranks and the NN distances for each query. "
-    "\n\n"
-    "Sample usage: \n"
-    "./lsh/lsh_analysis -r reference.csv -q queries.csv -k 5 "
-    " -P -W -E query_ranks_file.csv -D query_nn_dist_file.csv "
-    " -F error_v_time_output.csv ");
-
-// Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
-                 "r");
-PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
-PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
-
-PARAM_STRING_REQ("rank_file", "The file containing the true ranks.", "E");
-PARAM_STRING_REQ("nn_dist_file", "The file containing the true distance"
-                 " errors.", "D");
-PARAM_INT("num_projections", "The number of hash functions for each table",
-          "K", 10);
-PARAM_INT("num_tables", "The number of hash tables to be used.", "L", 30);
-PARAM_INT("second_hash_size", "The size of the second level hash table.", 
-          "M", 8807);
-PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", 
-          "B", 500);
-
-PARAM_FLAG("try_diff_params", "The flag to trigger the search with "
-           "different 'K', 'L' and 'r'.", "P");
-PARAM_FLAG("try_diff_widths", "The flag to trigger the search with "
-           "different hash widths.", "W");
-
-PARAM_STRING("error_time_file", "File to output the error v. time report to.",
-             "F", "");
-
-int main(int argc, char *argv[])
-{
-  // Give CLI the command line parameters the user passed in.
-  CLI::ParseCommandLine(argc, argv);
-  math::RandomSeed(time(NULL)); 
-
-  // Get all the parameters.
-  string referenceFile = CLI::GetParam<string>("reference_file");
-  string distancesFile = CLI::GetParam<string>("distances_file");
-  string neighborsFile = CLI::GetParam<string>("neighbors_file");
-
-  size_t k = CLI::GetParam<int>("k");
-  size_t secondHashSize = CLI::GetParam<int>("second_hash_size");
-  size_t bucketSize = CLI::GetParam<int>("bucket_size");
-
-  bool tryDiffParams = CLI::HasParam("try_diff_params");
-  bool tryDiffWidths = CLI::HasParam("try_diff_widths");
-
-  arma::mat referenceData;
-  arma::mat queryData; // So it doesn't go out of scope.
-  data::Load(referenceFile.c_str(), referenceData, true);
-
-  Log::Info << "Loaded reference data from '" << referenceFile << "' (" << 
-    referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
-
-  // Sanity check on k value: must be greater than 0, must be less than the
-  // number of reference points.
-  if (k > referenceData.n_cols)
-  {
-    Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
-    Log::Fatal << "than or equal to the number of reference points (";
-    Log::Fatal << referenceData.n_cols << ")." << endl;
-  }
-
-
-  // Pick up the 'K' and the 'L' parameter for LSH
-  arma::Col<size_t> numProjs, numTables;
-  if (tryDiffParams) 
-  {
-    numProjs.set_size(4);
-    numProjs << 10 << 25 << 40 << 55;
-    numTables.set_size(4);
-    numTables << 5 << 10 << 15 << 20;
-  }
-  else 
-  {
-    numProjs.set_size(1);
-    numProjs[0] = CLI::GetParam<int>("num_projections");
-    numTables.set_size(1);
-    numTables[0] = CLI::GetParam<int>("num_tables");
-  }
-
-  // Compute the 'width' parameter from LSH
-
-  // Find the average pairwise distance of 25 random pairs
-  double avgDist = 0;
-  for (size_t i = 0; i < 25; i++)
-  {
-    size_t p1 = (size_t) math::RandInt(referenceData.n_cols);
-    size_t p2 = (size_t) math::RandInt(referenceData.n_cols);
-    avgDist += metric::EuclideanDistance::Evaluate(referenceData.unsafe_col(p1),
-                                                   referenceData.unsafe_col(p2));
-  }
-
-  avgDist /= 25;
-
-  Log::Info << "Hash width chosen as: " << avgDist << endl;
-
-  arma::vec hashWidths;
-  if (tryDiffWidths)
-  {
-    arma::vec eps(3);
-    eps << 0.01 << 0.1 << 1.0;
-    hashWidths = avgDist * eps;
-  }
-  else 
-  {
-    hashWidths.set_size(1);
-    hashWidths[0] = avgDist;
-  }
-
-  arma::vec timesTaken(numProjs.n_elem * numTables.n_elem * hashWidths.n_elem);
-  timesTaken.zeros();
-
-  arma::Mat<size_t> allNeighbors;
-  arma::mat allNeighborDistances;
-  
-  arma::Mat<size_t> neighbors;
-  arma::mat distances;
-
-  if (CLI::GetParam<string>("query_file") != "")
-  {
-    string queryFile = CLI::GetParam<string>("query_file");
-
-    data::Load(queryFile.c_str(), queryData, true);
-    Log::Info << "Loaded query data from '" << queryFile << "' (" << 
-      queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
-
-    allNeighbors.set_size(k * timesTaken.n_elem, queryData.n_cols);
-    allNeighborDistances.set_size(k * timesTaken.n_elem, queryData.n_cols);
-  }
-  else
-  {
-    allNeighbors.set_size(k * timesTaken.n_elem, referenceData.n_cols);
-    allNeighborDistances.set_size(k * timesTaken.n_elem, referenceData.n_cols);
-  }
-
-  size_t exptInd = 0;
-  arma::mat exptParams(timesTaken.n_elem, 3);
-
-  // Looping through all combinations of the parameters and noting the 
-  // runtime and the results.
-  for (size_t widthInd = 0; widthInd < hashWidths.n_elem; widthInd++)
-  {
-    for (size_t projInd = 0; projInd < numProjs.n_elem; projInd++)
-    {
-      for (size_t tableInd = 0; tableInd < numTables.n_elem; tableInd++)
-      {
-        Log::Info << "LSH with K: " << numProjs[projInd] << ", L: " << 
-          numTables[tableInd] << ", r: " << hashWidths[widthInd] << endl;
-
-        Timer::Start("hash_building");
-
-        LSHSearch<>* allkann;
-
-        if (CLI::GetParam<string>("query_file") != "")
-          allkann = new LSHSearch<>(referenceData, queryData, numProjs[projInd],
-                                    numTables[tableInd], hashWidths[widthInd],
-                                    secondHashSize, bucketSize);
-        else
-          allkann = new LSHSearch<>(referenceData, numProjs[projInd],
-                                    numTables[tableInd], hashWidths[widthInd],
-                                    secondHashSize, bucketSize);
-
-        Timer::Stop("hash_building");
-
-        timeval start_tv = Timer::Get("computing_neighbors");
-        double startTime 
-          = (double) start_tv.tv_sec + (double) start_tv.tv_usec / 1.0e6;
-
-        Log::Info << "Computing " << k << " approx. nearest neighbors" << endl;
-        allkann->Search(k, neighbors, distances);
-        Log::Info << "Neighbors computed." << endl;
-
-        timeval stop_tv = Timer::Get("computing_neighbors");
-        double stopTime 
-          = (double) stop_tv.tv_sec + (double) stop_tv.tv_usec / 1.0e6;
-
-        exptParams(exptInd, 0) = (double) numProjs[projInd];
-        exptParams(exptInd, 1) = (double) numTables[tableInd];
-        exptParams(exptInd, 2) = hashWidths[widthInd];
-        timesTaken[exptInd] = stopTime - startTime;
-
-        // add results to big matrix
-        allNeighbors.rows(exptInd * k, (exptInd + 1) * k - 1) = neighbors;
-        allNeighborDistances.rows(exptInd * k, (exptInd + 1) * k - 1) 
-          = distances;
-
-        exptInd++;
-
-        neighbors.reset();
-        distances.reset();
-
-        delete allkann;
-      } // diff. L
-    } // diff. K
-  } // diff. 'width'
-
-
-  // Computing the errors
-  string rankFile = CLI::GetParam<string>("rank_file");
-  Log::Warn << "Computing error..." << endl;
-
-  contrib_utils::LineReader lr(rankFile);
-
-  arma::mat allANNErrors(timesTaken.n_elem, 8);
-  // 0 - K
-  // 1 - L
-  // 2 - width
-  // 3 - Time taken
-  // 4 - Mean Rank/Recall
-  // 5 - Median Rank/Recall
-  // 6 - StdDev Rank/Recall
-  // 7 - MaxRank / MinRecall
-
-
-  // If k == 1, compute the rank and distance errors
-  if (k == 1) 
-  {
-    string distFile = CLI::GetParam<string>("nn_dist_file");
-    arma::mat nnDistsMat;
-
-    if (!data::Load(distFile, nnDistsMat))
-      Log::Fatal << "Dist file " << distFile << " cannot be loaded." << endl;
-
-    arma::vec nnDists(nnDistsMat.row(0).t());
-
-    allANNErrors.resize(timesTaken.n_elem, 12);
-    // 8 - Mean DE
-    // 9 - Median DE
-    // 10 - StdDev DE
-    // 11 - Max DE
-
-    arma::mat ranks(timesTaken.n_elem, allNeighbors.n_cols);
-    arma::mat des(timesTaken.n_elem, allNeighbors.n_cols);
-
-    for (size_t i = 0; i < allNeighbors.n_cols; i++)
-    {
-      arma::Col<size_t> true_ranks(referenceData.n_cols);
-      lr.ReadLine(&true_ranks);
-
-      for (size_t j = 0; j < timesTaken.n_elem; j++) 
-      {
-        if (allNeighbors(j, i) < referenceData.n_cols)
-        {
-          ranks(j, i) = (double) true_ranks[allNeighbors(j, i)];
-          des(j, i) = (allNeighborDistances(j, i) - nnDists[i]) / nnDists[i];
-        }
-        // not sure what to do in terms of distance error 
-        // in case no result is returned
-        else 
-        {
-          ranks(j, i) = (double) referenceData.n_cols;
-          des(j, i) = 1000;
-        }
-      }
-    }
-
-    // Saving the LSH parameters and the query times
-    allANNErrors.cols(0, 2) = exptParams;
-    allANNErrors.col(3) = timesTaken;
-
-    // Saving the mean distance error and rank over all queries.
-    allANNErrors.col(4) = arma::mean(ranks, 1);
-    allANNErrors.col(5) = arma::median(ranks, 1);
-    allANNErrors.col(6) = arma::stddev(ranks, 1, 1);
-    allANNErrors.col(7) = arma::max(ranks, 1);
-
-    allANNErrors.col(8) = arma::mean(des, 1);
-    allANNErrors.col(9) = arma::median(des, 1);
-    allANNErrors.col(10) = arma::stddev(des, 1, 1);
-    allANNErrors.col(11) = arma::max(des, 1);
-
-  } // if k == 1
-  // if k > 1, compute the recall of the k-nearest-neighbors.
-  else
-  {
-    arma::mat recalls(timesTaken.n_elem, allNeighbors.n_cols);
-    recalls.zeros();
-
-    for (size_t i = 0; i < allNeighbors.n_cols; i++)
-    {
-      arma::Col<size_t> true_ranks(referenceData.n_cols);
-      lr.ReadLine(&true_ranks);
-
-      for (size_t j = 0; j < timesTaken.n_elem; j++)
-      {
-        for (size_t ind = 0; ind < k; ind++)
-        {
-          if (allNeighbors(j * k + ind, i) < referenceData.n_cols)
-          {
-            if (true_ranks[allNeighbors(j * k + ind, i)] <= k)
-              recalls(j, i)++;
-          }
-        }
-      }
-    }
-
-    recalls /= k;
-    
-    // Saving the LSH parameters and the query times
-    allANNErrors.cols(0, 2) = exptParams;
-    allANNErrors.col(3) = timesTaken;
-
-    // Saving the mean recall of the k-nearest-neighbor over all queries.
-    allANNErrors.col(4) = arma::mean(recalls, 1);
-    allANNErrors.col(5) = arma::median(recalls, 1);
-    allANNErrors.col(6) = arma::stddev(recalls, 1, 1);
-    allANNErrors.col(7) = arma::min(recalls, 1);
-    
-  } // if k > 1, compute recall of k-NN
-
-  Log::Warn << allANNErrors;
-
-
-  // Saving the output in a file
-  string annErrorOutputFile = CLI::GetParam<string>("error_time_file");
-
-  if (annErrorOutputFile != "")
-  {    
-    allANNErrors = allANNErrors.t();
-    data::Save(annErrorOutputFile, allANNErrors);
-  }    
-  else 
-  {
-    Log::Warn << "Params: " << endl << exptParams.t() << "Times Taken: " <<
-      endl << timesTaken.t();
-  }
-
-  return 0;
-}

Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp	2013-01-09 20:44:07 UTC (rev 14095)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp	2013-01-09 20:45:20 UTC (rev 14096)
@@ -50,6 +50,9 @@
 PARAM_INT("num_projections", "The number of hash functions for each table", 
           "K", 10);
 PARAM_INT("num_tables", "The number of hash tables to be used.", "L", 30);
+PARAM_DOUBLE("hash_width", "The hash width for the first-level hashing "
+             "in the LSH preprocessing. By default, the LSH class "
+             "automatically estimates a hash width for its use.", "H", 0.0);
 PARAM_INT("second_hash_size", "The size of the second level hash table.", 
           "M", 99901);
 PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", 
@@ -91,24 +94,10 @@
   size_t numProj = CLI::GetParam<int>("num_projections");
   size_t numTables = CLI::GetParam<int>("num_tables");
   
-  // Compute the 'width' parameter from LSH
+  // Compute the 'hash_width' parameter from LSH
+  double hashWidth = CLI::GetParam<double>("hash_width");
 
-  // Find the average pairwise distance of 25 random pairs and use that 
-  // as the hash-width
-  double hashWidth = 0;
-  for (size_t i = 0; i < 25; i++)
-  {
-    size_t p1 = (size_t) math::RandInt(referenceData.n_cols);
-    size_t p2 = (size_t) math::RandInt(referenceData.n_cols);
 
-    hashWidth += metric::EuclideanDistance::Evaluate(referenceData.unsafe_col(p1),
-                                                     referenceData.unsafe_col(p2));
-  }
-
-  hashWidth /= 25;
-
-  Log::Info << "Hash width chosen as: " << hashWidth << endl;
-
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
@@ -121,8 +110,12 @@
               << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
   }
 
-  Log::Info << "LSH with " << numProj << " projections(K) and " << numTables << 
-    " tables(L) with hash width(r): " << hashWidth << endl;
+  if (hashWidth == 0.0)
+    Log::Info << "LSH with " << numProj << " projections(K) and " << 
+      numTables << " tables(L) with default hash width." << endl;
+  else
+    Log::Info << "LSH with " << numProj << " projections(K) and " << 
+      numTables << " tables(L) with hash width(r): " << hashWidth << endl;
 
   Timer::Start("hash_building");
 
@@ -136,8 +129,9 @@
                               secondHashSize, bucketSize);
 
   Timer::Stop("hash_building");
-
-  Log::Info << "Computing " << k << " distance approx. nearest neighbors " << endl;
+  
+  Log::Info << "Computing " << k << " distance approx. nearest neighbors " << 
+    endl;
   allkann->Search(k, neighbors, distances);
 
   Log::Info << "Neighbors computed." << endl;

Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp	2013-01-09 20:44:07 UTC (rev 14095)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp	2013-01-09 20:45:20 UTC (rev 14096)
@@ -34,7 +34,9 @@
 namespace neighbor {
 
 /**
- * The LSHSearch class -- TBD
+ * The LSHSearch class -- This class builds a hash on the reference set 
+ * and uses this hash to compute the distance-approximate nearest-neighbors 
+ * of the given queries.
  *
  * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
  * @tparam MetricType The metric to use for computation.
@@ -45,17 +47,21 @@
 {
  public:
   /**
-   * Intialize -- TBD
+   * This function initializes the LSH class. It builds the hash on the 
+   * reference set with 2-stable distributions. See the individual functions 
+   * performing the hashing for details on how the hashing is done.
    *
    * @param referenceSet Set of reference points.
    * @param querySet Set of query points.
    * @param numProj Number of projections in each hash table (anything between
    *     10-50 might be a decent choice).
-   * @param numTables Total number of hash tables (anything between 10-20 should
+   * @param numTables Total number of hash tables (anything between 10-20 
    *     should suffice).
-   * @param hashWidth The width of hash for every table (currently automatically
-   *     chosen from the main function). This should be a reasonable upper bound
-   *     on the nearest-neighbor distance in general.
+   * @param hashWidth The width of hash for every table. If the user does not 
+   *     provide a value then the class automatically obtains a hash width
+   *     by computing the average pairwise distance of 25 pairs. This should 
+   *     be a reasonable upper bound on the nearest-neighbor distance 
+   *     in general.
    * @param secondHashSize The size of the second hash table. This should be a
    *     large prime number.
    * @param bucketSize The size of the bucket in the second hash table. This is
@@ -67,22 +73,26 @@
             const arma::mat& querySet,
             const size_t numProj,
             const size_t numTables,
-            const double hashWidth,
+            const double hashWidth = 0.0,
             const size_t secondHashSize = 99901,
             const size_t bucketSize = 500,
             const MetricType metric = MetricType());
 
   /**
-   * Intialize -- TBD
+   * This function initializes the LSH class. It builds the hash on the 
+   * reference set with 2-stable distributions. See the individual functions 
+   * performing the hashing for details on how the hashing is done.
    *
    * @param referenceSet Set of reference points and the set of queries.
    * @param numProj Number of projections in each hash table (anything between
    *     10-50 might be a decent choice).
-   * @param numTables Total number of hash tables (anything between 10-20 should
+   * @param numTables Total number of hash tables (anything between 10-20 
    *     should suffice).
-   * @param hashWidth The width of hash for every table (currently automatically
-   *     chosen from the main function). This should be a reasonable upper bound
-   *     on the nearest-neighbor distance in general.
+   * @param hashWidth The width of hash for every table. If the user does not 
+   *     provide a value then the class automatically obtains a hash width
+   *     by computing the average pairwise distance of 25 pairs. This should 
+   *     be a reasonable upper bound on the nearest-neighbor distance 
+   *     in general.
    * @param secondHashSize The size of the second hash table. This should be a
    *     large prime number.
    * @param bucketSize The size of the bucket in the second hash table. This is
@@ -93,7 +103,7 @@
   LSHSearch(const arma::mat& referenceSet,
             const size_t numProj,
             const size_t numTables,
-            const double hashWidth,
+            const double hashWidth = 0.0,
             const size_t secondHashSize = 99901,
             const size_t bucketSize = 500,
             const MetricType metric = MetricType());
@@ -105,8 +115,8 @@
 
   /**
    * Compute the nearest neighbors and store the output in the given matrices.
-   * The matrices will be set to the size of n columns by k rows, where n is the
-   * number of points in the query dataset and k is the number of neighbors
+   * The matrices will be set to the size of n columns by k rows, where n is 
+   * the number of points in the query dataset and k is the number of neighbors
    * being searched for.
    *
    * @param k Number of neighbors to search for.
@@ -114,10 +124,17 @@
    *     point.
    * @param distances Matrix storing distances of neighbors for each query
    *     point.
+   * @param numTablesToSearch This parameter allows the user to have control
+   *     over the number of hash tables to be searched. This allows 
+   *     the user to pick the number of tables it can afford for the time 
+   *     available without having to build hashing for every table size.
+   *     By default, this is set to zero in which case all tables are 
+   *     considered.
    */
   void Search(const size_t k,
               arma::Mat<size_t>& resultingNeighbors,
-              arma::mat& distances);
+              arma::mat& distances,
+              size_t numTablesToSearch = 0);
 
  private:
   /**
@@ -147,7 +164,8 @@
    *    multiple buckets of the second hash table.
    */
   void ReturnIndicesFromTable(const size_t queryIndex,
-                              arma::uvec& referenceIndices);
+                              arma::uvec& referenceIndices,
+                              size_t numTablesToSearch);
   /**
    * This is a helper function that computes the distance of the query to the
    * neighbor candidates and appropriately stores the best 'k' candidates
@@ -192,7 +210,7 @@
   arma::mat offsets; // should be numProj x numTables
 
   //! The hash width
-  const double hashWidth;
+  double hashWidth;
 
   //! The big prime representing the size of the second hash
   const size_t secondHashSize;

Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp	2013-01-09 20:44:07 UTC (rev 14095)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp	2013-01-09 20:45:20 UTC (rev 14096)
@@ -20,7 +20,7 @@
           const arma::mat& querySet,
           const size_t numProj,
           const size_t numTables,
-          const double hashWidth,
+          const double hashWidthIn,
           const size_t secondHashSize,
           const size_t bucketSize,
           const MetricType metric) :
@@ -28,11 +28,26 @@
   querySet(querySet),
   numProj(numProj),
   numTables(numTables),
-  hashWidth(hashWidth),
+  hashWidth(hashWidthIn),
   secondHashSize(secondHashSize),
   bucketSize(bucketSize),
   metric(metric)
 {
+  if (hashWidth == 0.0) // the user has not provided any value
+  {
+    for (size_t i = 0; i < 25; i++)
+    {
+      size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
+      size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
+
+      hashWidth 
+        += metric::EuclideanDistance::Evaluate(referenceSet.unsafe_col(p1),
+                                               referenceSet.unsafe_col(p2));
+    }
+
+    hashWidth /= 25;
+  } // computing a heuristic hashWidth from the data
+
   BuildHash();
 }
 
@@ -41,7 +56,7 @@
 LSHSearch(const arma::mat& referenceSet,
           const size_t numProj,
           const size_t numTables,
-          const double hashWidth,
+          const double hashWidthIn,
           const size_t secondHashSize,
           const size_t bucketSize,
           const MetricType metric) :
@@ -49,11 +64,26 @@
   querySet(referenceSet),
   numProj(numProj),
   numTables(numTables),
-  hashWidth(hashWidth),
+  hashWidth(hashWidthIn),
   secondHashSize(secondHashSize),
   bucketSize(bucketSize),
   metric(metric)
 {
+  if (hashWidth == 0.0) // the user has not provided any value
+  {
+    for (size_t i = 0; i < 25; i++)
+    {
+      size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
+      size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
+
+      hashWidth 
+        += metric::EuclideanDistance::Evaluate(referenceSet.unsafe_col(p1),
+                                               referenceSet.unsafe_col(p2));
+    }
+
+    hashWidth /= 25;
+  } // computing a heuristic hashWidth from the data
+
   BuildHash();
 }
 
@@ -119,19 +149,30 @@
 template<typename SortPolicy, typename MetricType>
 void LSHSearch<SortPolicy, MetricType>::
 ReturnIndicesFromTable(const size_t queryIndex,
-                       arma::uvec& referenceIndices)
+                       arma::uvec& referenceIndices,
+                       size_t numTablesToSearch)
 {
-  // Hash the query in each of the 'numTables' hash tables using the
+  // deciding on the number of tables to look into.
+  if (numTablesToSearch == 0) // if no user input, search all
+    numTablesToSearch = numTables;
+
+
+  // sanity check to make sure that the existing number of tables is not 
+  // exceeded.
+  if (numTablesToSearch > numTables)
+    numTablesToSearch = numTables;
+
+  // Hash the query in each of the 'numTablesToSearch' hash tables using the
   // 'numProj' projections for each table.
-  // This gives us 'numTables' keys for the query where each key
+  // This gives us 'numTablesToSearch' keys for the query where each key
   // is a 'numProj' dimensional integer vector
   //
   // compute the projection of the query in each table
-  arma::mat allProjInTables(numProj, numTables);
-  for (size_t i = 0; i < numTables; i++)
+  arma::mat allProjInTables(numProj, numTablesToSearch);
+  for (size_t i = 0; i < numTablesToSearch; i++)
     allProjInTables.unsafe_col(i)
       = projections[i].t() * querySet.unsafe_col(queryIndex);
-  allProjInTables += offsets;
+  allProjInTables += offsets.cols(0, numTablesToSearch - 1);
   allProjInTables /= hashWidth;
 
   // compute the hash value of each key of the query into a bucket of the
@@ -141,7 +182,7 @@
   for (size_t i = 0; i < hashVec.n_elem; i++)
     hashVec[i] = (double)((size_t) hashVec[i] % secondHashSize);
 
-  assert(hashVec.n_elem == numTables);
+  assert(hashVec.n_elem == numTablesToSearch);
 
   // For all the buckets that the query is hashed into, sequentially
   // collect the indices in those buckets.
@@ -173,7 +214,8 @@
 void LSHSearch<SortPolicy, MetricType>::
 Search(const size_t k,
        arma::Mat<size_t>& resultingNeighbors,
-       arma::mat& distances)
+       arma::mat& distances,
+       size_t numTablesToSearch)
 {
   neighborPtr = &resultingNeighbors;
   distancePtr = &distances;
@@ -194,7 +236,7 @@
     // For hash every query into every hash tables and eventually
     // into the 'secondHashTable' to obtain the neighbor candidates
     arma::uvec refIndices;
-    ReturnIndicesFromTable(i, refIndices);
+    ReturnIndicesFromTable(i, refIndices, numTablesToSearch);
 
     // Just an informative book-keeping for the number of neighbor candidates
     // returned on average




More information about the mlpack-svn mailing list