[mlpack-svn] r14136 - mlpack/trunk/src/mlpack/methods/lsh

fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jan 18 17:00:49 EST 2013


Author: rcurtin
Date: 2013-01-18 17:00:49 -0500 (Fri, 18 Jan 2013)
New Revision: 14136

Modified:
   mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp
   mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
   mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
Log:
Clean up LSH code and add random seed parameter.


Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp	2013-01-18 21:46:28 UTC (rev 14135)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp	2013-01-18 22:00:49 UTC (rev 14136)
@@ -2,8 +2,8 @@
  * @file lsh_main.cpp
  * @author Parikshit Ram
  *
- * This file computes the approximate nearest-neighbors using 2-stable 
- * Locality-sensitive Hashing. 
+ * This file computes the approximate nearest-neighbors using 2-stable
+ * Locality-sensitive Hashing.
  */
 #include <time.h>
 
@@ -22,48 +22,59 @@
 
 // Information about the program itself.
 PROGRAM_INFO("All K-Approximate-Nearest-Neighbor Search with LSH",
-    "This program will calculate the k approximate-nearest-neighbors "
-    "of a set of points. You may specify a separate set of reference "
-    "points and query points, or just a reference set which will be "
-    "used as both the reference and query set. "
+    "This program will calculate the k approximate-nearest-neighbors of a set "
+    "of points using locality-sensitive hashing. You may specify a separate set"
+    " of reference points and query points, or just a reference set which will "
+    "be used as both the reference and query set. "
     "\n\n"
-    "For example, the following will return 5 neighbors from the "
-    "data for each point in 'input.csv' "
-    "and store the distances in 'distances.csv' and the neighbors in the "
-    "file 'neighbors.csv':"
+    "For example, the following will return 5 neighbors from the data for each "
+    "point in 'input.csv' and store the distances in 'distances.csv' and the "
+    "neighbors in the file 'neighbors.csv':"
     "\n\n"
-    "$ ./lsh/lsh -k 5 -r input.csv -d distances.csv -n neighbors.csv "
+    "$ lsh -k 5 -r input.csv -d distances.csv -n neighbors.csv "
     "\n\n"
     "The output files are organized such that row i and column j in the "
     "neighbors output file corresponds to the index of the point in the "
     "reference set which is the i'th nearest neighbor from the point in the "
     "query set with index j.  Row i and column j in the distances output file "
-    "corresponds to the distance between those two points.");
+    "corresponds to the distance between those two points."
+    "\n\n"
+    "Because this is approximate-nearest-neighbors search, results may be "
+    "different from run to run.  Thus, the --seed option can be specified to "
+    "set the random seed.");
 
 // Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.", "r");
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
+    "r");
 PARAM_STRING("distances_file", "File to output distances into.", "d", "");
 PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
+
 PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
+
 PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
 
-PARAM_INT("num_projections", "The number of hash functions for each table", 
-          "K", 10);
-PARAM_INT("num_tables", "The number of hash tables to be used.", "L", 30);
-PARAM_DOUBLE("hash_width", "The hash width for the first-level hashing "
-             "in the LSH preprocessing. By default, the LSH class "
-             "automatically estimates a hash width for its use.", "H", 0.0);
-PARAM_INT("second_hash_size", "The size of the second level hash table.", 
-          "M", 99901);
-PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", 
-          "B", 500);
+PARAM_INT("projections", "The number of hash functions for each table", "K",
+    10);
+PARAM_INT("tables", "The number of hash tables to be used.", "L", 30);
+PARAM_DOUBLE("hash_width", "The hash width for the first-level hashing in the "
+    "LSH preprocessing. By default, the LSH class automatically estimates a "
+    "hash width for its use.", "H", 0.0);
+PARAM_INT("second_hash_size", "The size of the second level hash table.", "M",
+    99901);
+PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", "B",
+    500);
+PARAM_INT("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
 
 int main(int argc, char *argv[])
 {
   // Give CLI the command line parameters the user passed in.
   CLI::ParseCommandLine(argc, argv);
-  math::RandomSeed(time(NULL)); 
 
+  if (CLI::GetParam<int>("seed") != 0)
+    math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+  else
+    math::RandomSeed((size_t) time(NULL));
+
   // Get all the parameters.
   string referenceFile = CLI::GetParam<string>("reference_file");
   string distancesFile = CLI::GetParam<string>("distances_file");
@@ -89,15 +100,11 @@
     Log::Fatal << referenceData.n_cols << ")." << endl;
   }
 
+  // Pick up the LSH-specific parameters.
+  const size_t numProj = CLI::GetParam<int>("projections");
+  const size_t numTables = CLI::GetParam<int>("tables");
+  const double hashWidth = CLI::GetParam<double>("hash_width");
 
-  // Pick up the 'K' and the 'L' parameter for LSH
-  size_t numProj = CLI::GetParam<int>("num_projections");
-  size_t numTables = CLI::GetParam<int>("num_tables");
-  
-  // Compute the 'hash_width' parameter from LSH
-  double hashWidth = CLI::GetParam<double>("hash_width");
-
-
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
@@ -111,11 +118,11 @@
   }
 
   if (hashWidth == 0.0)
-    Log::Info << "LSH with " << numProj << " projections(K) and " << 
-      numTables << " tables(L) with default hash width." << endl;
+    Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
+        numTables << " tables (L) with default hash width." << endl;
   else
-    Log::Info << "LSH with " << numProj << " projections(K) and " << 
-      numTables << " tables(L) with hash width(r): " << hashWidth << endl;
+    Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
+        numTables << " tables (L) with hash width (r): " << hashWidth << endl;
 
   Timer::Start("hash_building");
 
@@ -129,21 +136,19 @@
                               secondHashSize, bucketSize);
 
   Timer::Stop("hash_building");
-  
-  Log::Info << "Computing " << k << " distance approx. nearest neighbors " << 
-    endl;
+
+  Log::Info << "Computing " << k << " distance approximate nearest neighbors "
+      << endl;
   allkann->Search(k, neighbors, distances);
 
   Log::Info << "Neighbors computed." << endl;
 
   // Save output.
-  if (distancesFile != "") 
+  if (distancesFile != "")
     data::Save(distancesFile, distances);
 
   if (neighborsFile != "")
     data::Save(neighborsFile, neighbors);
 
   delete allkann;
-
-  return 0;
 }
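
For context, the programmatic equivalent of the command line above looks roughly like the following.  This is a hedged sketch rather than code from this commit: it assumes LSHSearch has usable default template arguments and uses data::Load()/data::Save() as seen elsewhere in the tree.

#include <mlpack/core.hpp>
#include <mlpack/methods/lsh/lsh_search.hpp>

using namespace mlpack;
using namespace mlpack::neighbor;

int main()
{
  // Fix the random seed for reproducible results; this is what the new
  // --seed option does.
  math::RandomSeed(42);

  arma::mat referenceData;
  data::Load("input.csv", referenceData, true);

  // 10 projections per table (K) and 30 tables (L); leaving the hash width
  // at its default of 0.0 makes the class estimate one from the average
  // distance of 25 random point pairs.
  LSHSearch<> lsh(referenceData, 10, 30);

  arma::Mat<size_t> neighbors;
  arma::mat distances;
  lsh.Search(5, neighbors, distances);

  data::Save("neighbors.csv", neighbors);
  data::Save("distances.csv", distances);
}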

Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp	2013-01-18 21:46:28 UTC (rev 14135)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp	2013-01-18 22:00:49 UTC (rev 14136)
@@ -34,8 +34,8 @@
 namespace neighbor {
 
 /**
- * The LSHSearch class -- This class builds a hash on the reference set 
- * and uses this hash to compute the distance-approximate nearest-neighbors 
+ * The LSHSearch class -- This class builds a hash on the reference set
+ * and uses this hash to compute the distance-approximate nearest-neighbors
  * of the given queries.
  *
  * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
@@ -47,21 +47,20 @@
 {
  public:
   /**
-   * This function initializes the LSH class. It builds the hash on the 
-   * reference set with 2-stable distributions. See the individual functions 
+   * This function initializes the LSH class. It builds the hash on the
+   * reference set with 2-stable distributions. See the individual functions
    * performing the hashing for details on how the hashing is done.
    *
    * @param referenceSet Set of reference points.
    * @param querySet Set of query points.
    * @param numProj Number of projections in each hash table (anything between
    *     10-50 might be a decent choice).
-   * @param numTables Total number of hash tables (anything between 10-20 
+   * @param numTables Total number of hash tables (anything between 10-20
    *     should suffice).
-   * @param hashWidth The width of hash for every table. If the user does not 
-   *     provide a value then the class automatically obtains a hash width
-   *     by computing the average pairwise distance of 25 pairs. This should 
-   *     be a reasonable upper bound on the nearest-neighbor distance 
-   *     in general.
+   * @param hashWidth The width of hash for every table. If 0 (the default) is
+   *     provided, then the hash width is automatically obtained by computing
+   *     the average pairwise distance of 25 pairs.  This should be a reasonable
+   *     upper bound on the nearest-neighbor distance in general.
    * @param secondHashSize The size of the second hash table. This should be a
    *     large prime number.
    * @param bucketSize The size of the bucket in the second hash table. This is
@@ -79,20 +78,19 @@
             const MetricType metric = MetricType());
 
   /**
-   * This function initializes the LSH class. It builds the hash on the 
-   * reference set with 2-stable distributions. See the individual functions 
+   * This function initializes the LSH class. It builds the hash on the
+   * reference set with 2-stable distributions. See the individual functions
    * performing the hashing for details on how the hashing is done.
    *
    * @param referenceSet Set of reference points and the set of queries.
    * @param numProj Number of projections in each hash table (anything between
    *     10-50 might be a decent choice).
-   * @param numTables Total number of hash tables (anything between 10-20 
+   * @param numTables Total number of hash tables (anything between 10-20
    *     should suffice).
-   * @param hashWidth The width of hash for every table. If the user does not 
-   *     provide a value then the class automatically obtains a hash width
-   *     by computing the average pairwise distance of 25 pairs. This should 
-   *     be a reasonable upper bound on the nearest-neighbor distance 
-   *     in general.
+   * @param hashWidth The width of hash for every table. If 0 (the default) is
+   *     provided, then the hash width is automatically obtained by computing
+   *     the average pairwise distance of 25 pairs.  This should be a reasonable
+   *     upper bound on the nearest-neighbor distance in general.
    * @param secondHashSize The size of the second hash table. This should be a
    *     large prime number.
    * @param bucketSize The size of the bucket in the second hash table. This is
@@ -107,15 +105,10 @@
             const size_t secondHashSize = 99901,
             const size_t bucketSize = 500,
             const MetricType metric = MetricType());
-  /**
-   * Delete the LSHSearch object. The tree is the only member we are
-   * responsible for deleting.  The others will take care of themselves.
-   */
-  ~LSHSearch();
 
   /**
    * Compute the nearest neighbors and store the output in the given matrices.
-   * The matrices will be set to the size of n columns by k rows, where n is 
+   * The matrices will be set to the size of n columns by k rows, where n is
    * the number of points in the query dataset and k is the number of neighbors
    * being searched for.
    *
@@ -125,16 +118,16 @@
    * @param distances Matrix storing distances of neighbors for each query
    *     point.
    * @param numTablesToSearch This parameter allows the user to have control
-   *     over the number of hash tables to be searched. This allows 
-   *     the user to pick the number of tables it can afford for the time 
+   *     over the number of hash tables to be searched. This allows
+   *     the user to pick the number of tables it can afford for the time
    *     available without having to build hashing for every table size.
-   *     By default, this is set to zero in which case all tables are 
+   *     By default, this is set to zero in which case all tables are
    *     considered.
    */
   void Search(const size_t k,
               arma::Mat<size_t>& resultingNeighbors,
               arma::mat& distances,
-              size_t numTablesToSearch = 0);
+              const size_t numTablesToSearch = 0);
 
  private:
   /**
@@ -166,6 +159,7 @@
   void ReturnIndicesFromTable(const size_t queryIndex,
                               arma::uvec& referenceIndices,
                               size_t numTablesToSearch);
+
   /**
    * This is a helper function that computes the distance of the query to the
    * neighbor candidates and appropriately stores the best 'k' candidates
@@ -224,23 +218,22 @@
   //! Instantiation of the metric.
   MetricType metric;
 
-  //! The final hash table
-  arma::Mat<size_t> secondHashTable; // should be (< secondHashSize) x bucketSize
+  //! The final hash table; should be (< secondHashSize) x bucketSize.
+  arma::Mat<size_t> secondHashTable;
 
-  //! The number of elements present in each hash bucket
-  arma::Col<size_t> bucketContentSize; // should be secondHashSize
+  //! The number of elements present in each hash bucket; should be
+  //! secondHashSize.
+  arma::Col<size_t> bucketContentSize;
 
   //! For a particular hash value, points to the row in secondHashTable
-  //! corresponding to this value
-  arma::Col<size_t> bucketRowInHashTable; // should be secondHashSize
+  //! corresponding to this value.  Should be secondHashSize.
+  arma::Col<size_t> bucketRowInHashTable;
 
-  //! The pointer to the nearest neighbor distance
+  //! The pointer to the nearest neighbor distances.
   arma::mat* distancePtr;
 
-  //! The pointer to the nearest neighbor indices
+  //! The pointer to the nearest neighbor indices.
   arma::Mat<size_t>* neighborPtr;
-
-
 }; // class LSHSearch
 
 }; // namespace neighbor
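
As an aside on the Search() signature above, numTablesToSearch lets the caller trade accuracy for query time without rebuilding the hash.  Continuing the hypothetical objects from the sketch after lsh_main.cpp:

  // Search only the first 10 of the 30 tables that were built; passing 0
  // (the default) considers all of them.
  lsh.Search(5, neighbors, distances, 10);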

Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp	2013-01-18 21:46:28 UTC (rev 14135)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp	2013-01-18 22:00:49 UTC (rev 14136)
@@ -33,20 +33,20 @@
   bucketSize(bucketSize),
   metric(metric)
 {
-  if (hashWidth == 0.0) // the user has not provided any value
+  if (hashWidth == 0.0) // The user has not provided any value.
   {
+    // Compute a heuristic hash width from the data.
     for (size_t i = 0; i < 25; i++)
     {
       size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
       size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
 
-      hashWidth 
-        += metric::EuclideanDistance::Evaluate(referenceSet.unsafe_col(p1),
-                                               referenceSet.unsafe_col(p2));
+      hashWidth += MetricType::Evaluate(referenceSet.unsafe_col(p1),
+                                        referenceSet.unsafe_col(p2));
     }
 
     hashWidth /= 25;
-  } // computing a heuristic hashWidth from the data
+  }
 
   BuildHash();
 }
@@ -69,32 +69,26 @@
   bucketSize(bucketSize),
   metric(metric)
 {
-  if (hashWidth == 0.0) // the user has not provided any value
+  if (hashWidth == 0.0) // The user has not provided any value.
   {
+    // Compute a heuristic hash width from the data.
     for (size_t i = 0; i < 25; i++)
     {
       size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
       size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
 
-      hashWidth 
-        += metric::EuclideanDistance::Evaluate(referenceSet.unsafe_col(p1),
-                                               referenceSet.unsafe_col(p2));
+      hashWidth += MetricType::Evaluate(referenceSet.unsafe_col(p1),
+                                        referenceSet.unsafe_col(p2));
     }
 
     hashWidth /= 25;
-  } // computing a heuristic hashWidth from the data
+  }
 
   BuildHash();
 }
 
 
 template<typename SortPolicy, typename MetricType>
-LSHSearch<SortPolicy, MetricType>::
-~LSHSearch()
-{ }
-
-
-template<typename SortPolicy, typename MetricType>
 void LSHSearch<SortPolicy, MetricType>::
 InsertNeighbor(const size_t queryIndex,
                const size_t pos,
@@ -152,50 +146,51 @@
                        arma::uvec& referenceIndices,
                        size_t numTablesToSearch)
 {
-  // deciding on the number of tables to look into.
-  if (numTablesToSearch == 0) // if no user input, search all
+  // Decide on the number of tables to look into.
+  if (numTablesToSearch == 0) // If no user input is given, search all.
     numTablesToSearch = numTables;
 
-
-  // sanity check to make sure that the existing number of tables is not 
+  // Sanity check to make sure that the existing number of tables is not
   // exceeded.
   if (numTablesToSearch > numTables)
     numTablesToSearch = numTables;
 
   // Hash the query in each of the 'numTablesToSearch' hash tables using the
-  // 'numProj' projections for each table.
-  // This gives us 'numTablesToSearch' keys for the query where each key
-  // is a 'numProj' dimensional integer vector
-  //
-  // compute the projection of the query in each table
+  // 'numProj' projections for each table. This gives us 'numTablesToSearch'
+  // keys for the query where each key is a 'numProj' dimensional integer
+  // vector.
+
+  // Compute the projection of the query in each table.
   arma::mat allProjInTables(numProj, numTablesToSearch);
   for (size_t i = 0; i < numTablesToSearch; i++)
-    allProjInTables.unsafe_col(i)
-      = projections[i].t() * querySet.unsafe_col(queryIndex);
+  {
+    allProjInTables.unsafe_col(i) = projections[i].t() *
+        querySet.unsafe_col(queryIndex);
+  }
   allProjInTables += offsets.cols(0, numTablesToSearch - 1);
   allProjInTables /= hashWidth;
 
-  // compute the hash value of each key of the query into a bucket of the
+  // Compute the hash value of each key of the query into a bucket of the
   // 'secondHashTable' using the 'secondHashWeights'.
   arma::rowvec hashVec = secondHashWeights.t() * arma::floor(allProjInTables);
 
   for (size_t i = 0; i < hashVec.n_elem; i++)
-    hashVec[i] = (double)((size_t) hashVec[i] % secondHashSize);
+    hashVec[i] = (double) ((size_t) hashVec[i] % secondHashSize);
 
-  assert(hashVec.n_elem == numTablesToSearch);
+  Log::Assert(hashVec.n_elem == numTablesToSearch);
 
   // For all the buckets that the query is hashed into, sequentially
   // collect the indices in those buckets.
   arma::Col<size_t> refPointsConsidered;
   refPointsConsidered.zeros(referenceSet.n_cols);
 
-  for (size_t i = 0; i < hashVec.n_elem; i++)
+  for (size_t i = 0; i < hashVec.n_elem; i++) // For all tables.
   {
     size_t hashInd = (size_t) hashVec[i];
 
     if (bucketContentSize[hashInd] > 0)
     {
-      // Pick the indices in the bucket corresponding to 'hashInd'
+      // Pick the indices in the bucket corresponding to 'hashInd'.
       size_t tableRow = bucketRowInHashTable[hashInd];
       assert(tableRow < secondHashSize);
       assert(tableRow < secondHashTable.n_rows);
@@ -203,10 +198,9 @@
       for (size_t j = 0; j < bucketContentSize[hashInd]; j++)
         refPointsConsidered[secondHashTable(tableRow, j)]++;
     }
-  } // for all tables
+  }
 
   referenceIndices = arma::find(refPointsConsidered > 0);
-  return;
 }
 
 
@@ -215,7 +209,7 @@
 Search(const size_t k,
        arma::Mat<size_t>& resultingNeighbors,
        arma::mat& distances,
-       size_t numTablesToSearch)
+       const size_t numTablesToSearch)
 {
   neighborPtr = &resultingNeighbors;
   distancePtr = &distances;
@@ -230,20 +224,20 @@
 
   Timer::Start("computing_neighbors");
 
-  // go through every query point sequentially
+  // Go through every query point sequentially.
   for (size_t i = 0; i < querySet.n_cols; i++)
   {
-    // For hash every query into every hash tables and eventually
-    // into the 'secondHashTable' to obtain the neighbor candidates
+    // Hash every query into every hash table and eventually into the
+    // 'secondHashTable' to obtain the neighbor candidates.
     arma::uvec refIndices;
     ReturnIndicesFromTable(i, refIndices, numTablesToSearch);
 
-    // Just an informative book-keeping for the number of neighbor candidates
-    // returned on average
+    // An informative book-keeping for the number of neighbor candidates
+    // returned on average.
     avgIndicesReturned += refIndices.n_elem;
 
     // Sequentially go through all the candidates and save the best 'k'
-    // candidates
+    // candidates.
     for (size_t j = 0; j < refIndices.n_elem; j++)
       BaseCase(i, (size_t) refIndices[j]);
   }
@@ -252,9 +246,7 @@
 
   avgIndicesReturned /= querySet.n_cols;
   Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
-    std::endl;
-
-  return;
+      std::endl;
 }
 
 template<typename SortPolicy, typename MetricType>
@@ -276,148 +268,121 @@
   // given by <key, 'secondHashWeights'> % 'secondHashSize'
   // and the corresponding point ID is put into that bucket.
 
-  //////////////////////////////////////////
-  // Step I: Preparing the second level hash
-  ///////////////////////////////////////////
+  // Step I: Prepare the second level hash.
 
-  // obtain the weights for the second hash
-  secondHashWeights = arma::floor(arma::randu(numProj)
-                                  * (double) secondHashSize);
+  // Obtain the weights for the second hash.
+  secondHashWeights = arma::floor(arma::randu(numProj) *
+                                  (double) secondHashSize);
 
   // The 'secondHashTable' is initially an empty matrix of size
   // ('secondHashSize' x 'bucketSize'). But by only filling the buckets
   // as points land in them allows us to shrink the size of the
   // 'secondHashTable' at the end of the hashing.
 
-  // Start filling up the second hash table
+  // Fill the second hash table with n = referenceSet.n_cols.  This is because
+  // no point has index 'n', so the presence of this value in a bucket denotes
+  // that there are no more points in that bucket.
   secondHashTable.set_size(secondHashSize, bucketSize);
-
-  // Fill the second hash table n = referenceSet.n_cols
-  // This is because no point has index 'n' so the presence of
-  // this in the bucket denotes that there are no more points
-  // in this bucket.
   secondHashTable.fill(referenceSet.n_cols);
 
-  // Keeping track of the size of each bucket in the hash.
-  // At the end of hashing most buckets will be empty.
+  // Keep track of the size of each bucket in the hash.  At the end of hashing
+  // most buckets will be empty.
   bucketContentSize.zeros(secondHashSize);
 
-  // Instead of putting the points in the row corresponding to
-  // the bucket, we chose the next empty row and keep track of
-  // the row in which the bucket lies. This allows us to
-  // stack together and slice out the empty buckets at the
-  // end of the hashing.
+  // Instead of putting the points in the row corresponding to the bucket, we
+  // chose the next empty row and keep track of the row in which the bucket
+  // lies. This allows us to stack together and slice out the empty buckets at
+  // the end of the hashing.
   bucketRowInHashTable.set_size(secondHashSize);
   bucketRowInHashTable.fill(secondHashSize);
 
-  // keeping track of number of non-empty rows in the 'secondHashTable'
+  // Keep track of number of non-empty rows in the 'secondHashTable'.
   size_t numRowsInTable = 0;
 
-
-  /////////////////////////////////////////////////////////
-  // Step II: The offsets for all projections in all tables
-  /////////////////////////////////////////////////////////
-
+  // Step II: The offsets for all projections in all tables.
   // Since the 'offsets' are in [0, hashWidth], we obtain the 'offsets'
-  // as randu(numProj, numTables) * hashWidth
+  // as randu(numProj, numTables) * hashWidth.
   offsets.randu(numProj, numTables);
   offsets *= hashWidth;
 
-  /////////////////////////////////////////////////////////////////
-  // Step III: Creating each hash table in the first level hash
-  // one by one and putting them directly into the 'secondHashTable'
-  // for memory efficiency.
-  /////////////////////////////////////////////////////////////////
-
-  for(size_t i = 0; i < numTables; i++)
+  // Step III: Create each hash table in the first level hash one by one and
+  // put them directly into the 'secondHashTable' for memory efficiency.
+  for (size_t i = 0; i < numTables; i++)
   {
-    //////////////////////////////////////////////////////////////
-    // Step IV: Obtaining the 'numProj' projections for each table
-    //////////////////////////////////////////////////////////////
-    //
+    // Step IV: Obtain the 'numProj' projections for each table.
+
     // For L2 metric, 2-stable distributions are used, and
     // the normal Z ~ N(0, 1) is a 2-stable distribution.
     arma::mat projMat;
     projMat.randn(referenceSet.n_rows, numProj);
 
-    // save the projection matrix for querying
+    // Save the projection matrix for querying.
     projections.push_back(projMat);
 
-    ///////////////////////////////////////////////////////////////
-    // Step V: create the 'numProj'-dimensional key for each point
-    // in each table.
-    //////////////////////////////////////////////////////////////
+    // Step V: Create the 'numProj'-dimensional key for each point in each
+    // table.
 
-    // The following set of lines performs the task of
-    // hashing each point to a 'numProj'-dimensional integer key.
-    // Hence you get a ('numProj' x 'referenceSet.n_cols') key matrix
+    // The following code performs the task of hashing each point to a
+    // 'numProj'-dimensional integer key.  Hence you get a ('numProj' x
+    // 'referenceSet.n_cols') key matrix.
     //
-    // For a single table, let the 'numProj' projections be denoted
-    // by 'proj_i' and the corresponding offset be 'offset_i'.
-    // Then the key of a single point is obtained as:
+    // For a single table, let the 'numProj' projections be denoted by 'proj_i'
+    // and the corresponding offset be 'offset_i'.  Then the key of a single
+    // point is obtained as:
     // key = { floor( (<proj_i, point> + offset_i) / 'hashWidth' ) forall i }
-    arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i),
-                                       1, referenceSet.n_cols);
+    arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i), 1,
+                                       referenceSet.n_cols);
     arma::mat hashMat = projMat.t() * referenceSet;
     hashMat += offsetMat;
     hashMat /= hashWidth;
 
-    ////////////////////////////////////////////////////////////
-    // Step VI: Putting the points in the 'secondHashTable' by
-    // hashing the key.
-    ///////////////////////////////////////////////////////////
-
-    // Now we hash every key, point ID to its corresponding bucket
+    // Step VI: Put the points in the 'secondHashTable' by hashing the key.
+    // Now we hash every key, point ID to its corresponding bucket.
     arma::rowvec secondHashVec = secondHashWeights.t()
       * arma::floor(hashMat);
 
-    // This gives us the bucket for the corresponding point ID
+    // This gives us the bucket for the corresponding point ID.
     for (size_t j = 0; j < secondHashVec.n_elem; j++)
       secondHashVec[j] = (double)((size_t) secondHashVec[j] % secondHashSize);
 
-    assert(secondHashVec.n_elem == referenceSet.n_cols);
+    Log::Assert(secondHashVec.n_elem == referenceSet.n_cols);
 
-    // Inserting the point in the corresponding row to its bucket
-    // in the 'secondHashTable'.
+    // Insert the point in the corresponding row to its bucket in the
+    // 'secondHashTable'.
     for (size_t j = 0; j < secondHashVec.n_elem; j++)
     {
-      // This is the bucket number
+      // This is the bucket number.
       size_t hashInd = (size_t) secondHashVec[j];
-      // The point ID is 'j'
+      // The point ID is 'j'.
 
-      // If this is currently an empty bucket, start a new row
-      // keep track of which row corresponds to the bucket.
+      // If this is currently an empty bucket, start a new row and keep track
+      // of which row corresponds to the bucket.
       if (bucketContentSize[hashInd] == 0)
       {
-        // start a new row for hash
+        // Start a new row for hash.
         bucketRowInHashTable[hashInd] = numRowsInTable;
         secondHashTable(numRowsInTable, 0) = j;
 
         numRowsInTable++;
       }
-      // If bucket already present in the 'secondHashTable', find
-      // the corresponding row and insert the point ID in this row
-      // unless the bucket is full, in which case, do nothing.
+
       else
       {
-        // if bucket not full, insert point here
+        // If bucket is already present in the 'secondHashTable', find the
+        // corresponding row and insert the point ID in this row unless the
+        // bucket is full, in which case, do nothing.
         if (bucketContentSize[hashInd] < bucketSize)
           secondHashTable(bucketRowInHashTable[hashInd],
                           bucketContentSize[hashInd]) = j;
-        // else just ignore as suggested
       }
 
-      // increment the count of the points in this bucket
+      // Increment the count of the points in this bucket.
       if (bucketContentSize[hashInd] < bucketSize)
         bucketContentSize[hashInd]++;
-    } // loop over all points in the reference set
-  } // loop over tables
+    } // Loop over all points in the reference set.
+  } // Loop over tables.
 
-
-  /////////////////////////////////////////////////
-  // Step VII: Condensing the 'secondHashTable'
-  /////////////////////////////////////////////////
-
+  // Step VII: Condense the 'secondHashTable'.
   size_t maxBucketSize = 0;
   for (size_t i = 0; i < bucketContentSize.n_elem; i++)
     if (bucketContentSize[i] > maxBucketSize)
@@ -426,9 +391,6 @@
   Log::Info << "Final hash table size: (" << numRowsInTable << " x "
             << maxBucketSize << ")" << std::endl;
   secondHashTable.resize(numRowsInTable, maxBucketSize);
-
-  return;
 }
 
-
 #endif
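
To make the two-level scheme in BuildHash() concrete, here is a minimal self-contained sketch of how a single point is hashed to a bucket, assuming the L2 setup described above.  All names here are hypothetical; this illustrates the formulas in the comments, not the class's actual code path.

#include <armadillo>

// First level: key_i = floor((<proj_i, point> + offset_i) / hashWidth) for
// each of the 'numProj' projections.  Second level: the integer key is
// reduced to a bucket index as <key, secondHashWeights> % secondHashSize.
size_t HashPoint(const arma::vec& point,
                 const arma::mat& projMat,            // d x numProj; N(0, 1)
                 const arma::vec& offsets,            // numProj; in [0, w]
                 const double hashWidth,              // w
                 const arma::vec& secondHashWeights,  // numProj
                 const size_t secondHashSize)         // a large prime
{
  // The 'numProj'-dimensional integer key for this point.
  const arma::vec key = arma::floor((projMat.t() * point + offsets) /
      hashWidth);

  // Hash the key into one bucket of the second-level table (this mirrors the
  // cast-then-modulo done in ReturnIndicesFromTable() and BuildHash()).
  return ((size_t) arma::dot(secondHashWeights, key)) % secondHashSize;
}

Two points land in the same bucket of a table when their integer keys agree (up to second-level collisions), which is why the candidates collected in ReturnIndicesFromTable() are likely near neighbors.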



