[mlpack-git] master: Make query set a parameter to Search(). (78cc694)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Nov 9 16:30:41 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/9bd2063f96de9430b387974e7ce7204a1e57a803...78cc694a4fd50a68a24f5ab9af7531873566b3ba

>---------------------------------------------------------------

commit 78cc694a4fd50a68a24f5ab9af7531873566b3ba
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Nov 9 20:53:54 2015 +0000

    Make query set a parameter to Search().
    
    No need to hold it internally.


>---------------------------------------------------------------

78cc694a4fd50a68a24f5ab9af7531873566b3ba
 src/mlpack/methods/lsh/lsh_main.cpp        |  13 ++-
 src/mlpack/methods/lsh/lsh_search.hpp      |  97 ++++++++++-------
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 168 +++++++++++++++++------------
 src/mlpack/tests/lsh_test.cpp              |   4 +-
 4 files changed, 164 insertions(+), 118 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_main.cpp b/src/mlpack/methods/lsh/lsh_main.cpp
index 9a5d37d..6fe63e3 100644
--- a/src/mlpack/methods/lsh/lsh_main.cpp
+++ b/src/mlpack/methods/lsh/lsh_main.cpp
@@ -128,18 +128,17 @@ int main(int argc, char *argv[])
 
   LSHSearch<>* allkann;
 
-  if (CLI::GetParam<string>("query_file") != "")
-    allkann = new LSHSearch<>(referenceData, queryData, numProj, numTables,
-                              hashWidth, secondHashSize, bucketSize);
-  else
-    allkann = new LSHSearch<>(referenceData, numProj, numTables, hashWidth,
-                              secondHashSize, bucketSize);
+  allkann = new LSHSearch<>(referenceData, numProj, numTables, hashWidth,
+                            secondHashSize, bucketSize);
 
   Timer::Stop("hash_building");
 
   Log::Info << "Computing " << k << " distance approximate nearest neighbors "
       << endl;
-  allkann->Search(k, neighbors, distances);
+  if (CLI::HasParam("query_file"))
+    allkann->Search(queryData, k, neighbors, distances);
+  else
+    allkann->Search(k, neighbors, distances);
 
   Log::Info << "Neighbors computed." << endl;
 
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 578c449..c113b10 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -34,9 +34,9 @@ namespace mlpack {
 namespace neighbor {
 
 /**
- * The LSHSearch class -- This class builds a hash on the reference set
- * and uses this hash to compute the distance-approximate nearest-neighbors
- * of the given queries.
+ * The LSHSearch class; this class builds a hash on the reference set and uses
+ * this hash to compute the distance-approximate nearest-neighbors of the given
+ * queries.
  *
  * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
  */
@@ -49,8 +49,7 @@ class LSHSearch
    * reference set with 2-stable distributions. See the individual functions
    * performing the hashing for details on how the hashing is done.
    *
-   * @param referenceSet Set of reference points.
-   * @param querySet Set of query points.
+   * @param referenceSet Set of reference points (and default query set).
    * @param numProj Number of projections in each hash table (anything between
    *     10-50 might be a decent choice).
    * @param numTables Total number of hash tables (anything between 10-20
@@ -66,7 +65,6 @@ class LSHSearch
    *     Default values are already provided here.
    */
   LSHSearch(const arma::mat& referenceSet,
-            const arma::mat& querySet,
             const size_t numProj,
             const size_t numTables,
             const double hashWidth = 0.0,
@@ -74,31 +72,29 @@ class LSHSearch
             const size_t bucketSize = 500);
 
   /**
-   * This function initializes the LSH class. It builds the hash on the
-   * reference set with 2-stable distributions. See the individual functions
-   * performing the hashing for details on how the hashing is done.
+   * Compute the nearest neighbors of the points in the given query set and
+   * store the output in the given matrices.  The matrices will be set to the
+   * size of n columns by k rows, where n is the number of points in the query
+   * dataset and k is the number of neighbors being searched for.
    *
-   * @param referenceSet Set of reference points and the set of queries.
-   * @param numProj Number of projections in each hash table (anything between
-   *     10-50 might be a decent choice).
-   * @param numTables Total number of hash tables (anything between 10-20
-   *     should suffice).
-   * @param hashWidth The width of hash for every table. If 0 (the default) is
-   *     provided, then the hash width is automatically obtained by computing
-   *     the average pairwise distance of 25 pairs.  This should be a reasonable
-   *     upper bound on the nearest-neighbor distance in general.
-   * @param secondHashSize The size of the second hash table. This should be a
-   *     large prime number.
-   * @param bucketSize The size of the bucket in the second hash table. This is
-   *     the maximum number of points that can be hashed into single bucket.
-   *     Default values are already provided here.
+   * @param querySet Set of query points.
+   * @param k Number of neighbors to search for.
+   * @param resultingNeighbors Matrix storing lists of neighbors for each query
+   *     point.
+   * @param distances Matrix storing distances of neighbors for each query
+   *     point.
+   * @param numTablesToSearch This parameter allows the user to have control
+   *     over the number of hash tables to be searched. This allows
+   *     the user to pick the number of tables they can afford for the time
+   *     available without having to build hashing for every table size.
+   *     By default, this is set to zero in which case all tables are
+   *     considered.
    */
-  LSHSearch(const arma::mat& referenceSet,
-            const size_t numProj,
-            const size_t numTables,
-            const double hashWidth = 0.0,
-            const size_t secondHashSize = 99901,
-            const size_t bucketSize = 500);
+  void Search(const arma::mat& querySet,
+              const size_t k,
+              arma::Mat<size_t>& resultingNeighbors,
+              arma::mat& distances,
+              const size_t numTablesToSearch = 0);
 
   /**
    * Compute the nearest neighbors and store the output in the given matrices.
@@ -153,28 +149,49 @@ class LSHSearch
    * hash table and all the points (if any) in those buckets are collected as
    * the potential neighbor candidates.
    *
-   * @param queryIndex The index of the query currently being processed.
+   * @param queryPoint The query point currently being processed.
    * @param referenceIndices The list of neighbor candidates obtained from
    *    hashing the query into all the hash tables and eventually into
    *    multiple buckets of the second hash table.
    */
-  void ReturnIndicesFromTable(const size_t queryIndex,
+  template<typename VecType>
+  void ReturnIndicesFromTable(const VecType& queryPoint,
                               arma::uvec& referenceIndices,
-                              size_t numTablesToSearch);
+                              size_t numTablesToSearch) const;
 
   /**
    * This is a helper function that computes the distance of the query to the
-   * neighbor candidates and appropriately stores the best 'k' candidates
+   * neighbor candidates and appropriately stores the best 'k' candidates.  This
+   * is specific to the monochromatic search case, where the query set is the
+   * reference set.
    *
-   * @param distances Matrix holding output distances.
+   * @param queryIndex The index of the query in question
+   * @param referenceIndex The index of the neighbor candidate in question
    * @param neighbors Matrix holding output neighbors.
+   * @param distances Matrix holding output distances.
+   */
+  void BaseCase(const size_t queryIndex,
+                const size_t referenceIndex,
+                arma::Mat<size_t>& neighbors,
+                arma::mat& distances) const;
+
+  /**
+   * This is a helper function that computes the distance of the query to the
+   * neighbor candidates and appropriately stores the best 'k' candidates.  This
+   * is specific to bichromatic search, where the query set is not the same as
+   * the reference set.
+   *
    * @param queryIndex The index of the query in question
    * @param referenceIndex The index of the neighbor candidate in question
+   * @param querySet Set of query points.
+   * @param neighbors Matrix holding output neighbors.
+   * @param distances Matrix holding output distances.
    */
-  double BaseCase(arma::mat& distances,
-                  arma::Mat<size_t>& neighbors,
-                  const size_t queryIndex,
-                  const size_t referenceIndex);
+  void BaseCase(const size_t queryIndex,
+                const size_t referenceIndex,
+                const arma::mat& querySet,
+                arma::Mat<size_t>& neighbors,
+                arma::mat& distances) const;
 
   /**
    * This is a helper function that efficiently inserts better neighbor
@@ -195,12 +212,10 @@ class LSHSearch
                       const size_t queryIndex,
                       const size_t pos,
                       const size_t neighbor,
-                      const double distance);
+                      const double distance) const;
 
   //! Reference dataset.
   const arma::mat& referenceSet;
-  //! Query dataset (may not be given).
-  const arma::mat& querySet;
 
   //! The number of projections.
   const size_t numProj;
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index cd13557..5760f89 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -16,51 +16,12 @@ namespace neighbor {
 template<typename SortPolicy>
 LSHSearch<SortPolicy>::
 LSHSearch(const arma::mat& referenceSet,
-          const arma::mat& querySet,
           const size_t numProj,
           const size_t numTables,
           const double hashWidthIn,
           const size_t secondHashSize,
           const size_t bucketSize) :
   referenceSet(referenceSet),
-  querySet(querySet),
-  numProj(numProj),
-  numTables(numTables),
-  hashWidth(hashWidthIn),
-  secondHashSize(secondHashSize),
-  bucketSize(bucketSize),
-  distanceEvaluations(0)
-{
-  if (hashWidth == 0.0) // The user has not provided any value.
-  {
-    // Compute a heuristic hash width from the data.
-    for (size_t i = 0; i < 25; i++)
-    {
-      size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
-      size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
-
-      hashWidth += std::sqrt(metric::EuclideanDistance::Evaluate(
-          referenceSet.unsafe_col(p1), referenceSet.unsafe_col(p2)));
-    }
-
-    hashWidth /= 25;
-  }
-
-  Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
-
-  BuildHash();
-}
-
-template<typename SortPolicy>
-LSHSearch<SortPolicy>::
-LSHSearch(const arma::mat& referenceSet,
-          const size_t numProj,
-          const size_t numTables,
-          const double hashWidthIn,
-          const size_t secondHashSize,
-          const size_t bucketSize) :
-  referenceSet(referenceSet),
-  querySet(referenceSet),
   numProj(numProj),
   numTables(numTables),
   hashWidth(hashWidthIn),
@@ -94,7 +55,7 @@ void LSHSearch<SortPolicy>::InsertNeighbor(arma::mat& distances,
                                            const size_t queryIndex,
                                            const size_t pos,
                                            const size_t neighbor,
-                                           const double distance)
+                                           const double distance) const
 {
   // We only memmove() if there is actually a need to shift something.
   if (pos < (distances.n_rows - 1))
@@ -113,20 +74,22 @@ void LSHSearch<SortPolicy>::InsertNeighbor(arma::mat& distances,
   neighbors(pos, queryIndex) = neighbor;
 }
 
+// Base case where the query set is the reference set.  (So, we can't return
+// ourselves as the nearest neighbor.)
 template<typename SortPolicy>
 inline force_inline
-double LSHSearch<SortPolicy>::BaseCase(arma::mat& distances,
-                                       arma::Mat<size_t>& neighbors,
-                                       const size_t queryIndex,
-                                       const size_t referenceIndex)
+void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
+                                     const size_t referenceIndex,
+                                     arma::Mat<size_t>& neighbors,
+                                     arma::mat& distances) const
 {
-  // If the datasets are the same, then this search is only using one dataset
-  // and we should not return identical points.
-  if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
-    return 0.0;
+  // If the points are the same, we can't continue.
+  if (queryIndex == referenceIndex)
+    return;
 
   const double distance = metric::EuclideanDistance::Evaluate(
-      querySet.unsafe_col(queryIndex), referenceSet.unsafe_col(referenceIndex));
+      referenceSet.unsafe_col(queryIndex),
+      referenceSet.unsafe_col(referenceIndex));
 
   // If this distance is better than any of the current candidates, the
   // SortDistance() function will give us the position to insert it into.
@@ -139,15 +102,39 @@ double LSHSearch<SortPolicy>::BaseCase(arma::mat& distances,
   if (insertPosition != (size_t() - 1))
     InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
         referenceIndex, distance);
+}
 
-  return distance;
+// Base case for bichromatic search.
+template<typename SortPolicy>
+inline force_inline
+void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
+                                     const size_t referenceIndex,
+                                     const arma::mat& querySet,
+                                     arma::Mat<size_t>& neighbors,
+                                     arma::mat& distances) const
+{
+  const double distance = metric::EuclideanDistance::Evaluate(
+      querySet.unsafe_col(queryIndex), referenceSet.unsafe_col(referenceIndex));
+
+  // If this distance is better than any of the current candidates, the
+  // SortDistance() function will give us the position to insert it into.
+  arma::vec queryDist = distances.unsafe_col(queryIndex);
+  arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
+  size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
+      distance);
+
+  // SortDistance() returns (size_t() - 1) if we shouldn't add it.
+  if (insertPosition != (size_t() - 1))
+    InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
+        referenceIndex, distance);
 }
 
 template<typename SortPolicy>
-void LSHSearch<SortPolicy>::
-ReturnIndicesFromTable(const size_t queryIndex,
-                       arma::uvec& referenceIndices,
-                       size_t numTablesToSearch)
+template<typename VecType>
+void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
+    const VecType& queryPoint,
+    arma::uvec& referenceIndices,
+    size_t numTablesToSearch) const
 {
   // Decide on the number of tables to look into.
   if (numTablesToSearch == 0) // If no user input is given, search all.
@@ -166,10 +153,7 @@ ReturnIndicesFromTable(const size_t queryIndex,
   // Compute the projection of the query in each table.
   arma::mat allProjInTables(numProj, numTablesToSearch);
   for (size_t i = 0; i < numTablesToSearch; i++)
-  {
-    allProjInTables.unsafe_col(i) = projections[i].t() *
-        querySet.unsafe_col(queryIndex);
-  }
+    allProjInTables.unsafe_col(i) = projections[i].t() * queryPoint;
   allProjInTables += offsets.cols(0, numTablesToSearch - 1);
   allProjInTables /= hashWidth;
 
@@ -206,7 +190,58 @@ ReturnIndicesFromTable(const size_t queryIndex,
   referenceIndices = arma::find(refPointsConsidered > 0);
 }
 
+// Search for nearest neighbors in a given query set.
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
+                                   const size_t k,
+                                   arma::Mat<size_t>& resultingNeighbors,
+                                   arma::mat& distances,
+                                   const size_t numTablesToSearch)
+{
+  // Ensure the dimensionality of the query set is correct.
+  if (querySet.n_rows != referenceSet.n_rows)
+    Log::Fatal << "LSHSearch::Search(): dimensionality of query set ("
+        << querySet.n_rows << ") is not equal to the dimensionality the model "
+        << "was trained on (" << referenceSet.n_rows << ")!" << std::endl;
 
+  // Set the size of the neighbor and distance matrices.
+  resultingNeighbors.set_size(k, querySet.n_cols);
+  distances.set_size(k, querySet.n_cols);
+  distances.fill(SortPolicy::WorstDistance());
+  resultingNeighbors.fill(referenceSet.n_cols);
+
+  size_t avgIndicesReturned = 0;
+
+  Timer::Start("computing_neighbors");
+
+  // Go through every query point sequentially.
+  for (size_t i = 0; i < querySet.n_cols; i++)
+  {
+    // Hash every query into every hash table and eventually into the
+    // 'secondHashTable' to obtain the neighbor candidates.
+    arma::uvec refIndices;
+    ReturnIndicesFromTable(querySet.col(i), refIndices, numTablesToSearch);
+
+    // An informative book-keeping for the number of neighbor candidates
+    // returned on average.
+    avgIndicesReturned += refIndices.n_elem;
+
+    // Sequentially go through all the candidates and save the best 'k'
+    // candidates.
+    for (size_t j = 0; j < refIndices.n_elem; j++)
+      BaseCase(i, (size_t) refIndices[j], querySet, resultingNeighbors,
+          distances);
+  }
+
+  Timer::Stop("computing_neighbors");
+
+  distanceEvaluations += avgIndicesReturned;
+  avgIndicesReturned /= querySet.n_cols;
+  Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
+      std::endl;
+}
+
+// Search for approximate neighbors of the reference set.
 template<typename SortPolicy>
 void LSHSearch<SortPolicy>::
 Search(const size_t k,
@@ -214,9 +249,9 @@ Search(const size_t k,
        arma::mat& distances,
        const size_t numTablesToSearch)
 {
-  // Set the size of the neighbor and distance matrices.
-  resultingNeighbors.set_size(k, querySet.n_cols);
-  distances.set_size(k, querySet.n_cols);
+  // This is monochromatic search; the query set is the reference set.
+  resultingNeighbors.set_size(k, referenceSet.n_cols);
+  distances.set_size(k, referenceSet.n_cols);
   distances.fill(SortPolicy::WorstDistance());
   resultingNeighbors.fill(referenceSet.n_cols);
 
@@ -225,12 +260,12 @@ Search(const size_t k,
   Timer::Start("computing_neighbors");
 
   // Go through every query point sequentially.
-  for (size_t i = 0; i < querySet.n_cols; i++)
+  for (size_t i = 0; i < referenceSet.n_cols; i++)
   {
     // Hash every query into every hash table and eventually into the
     // 'secondHashTable' to obtain the neighbor candidates.
     arma::uvec refIndices;
-    ReturnIndicesFromTable(i, refIndices, numTablesToSearch);
+    ReturnIndicesFromTable(referenceSet.col(i), refIndices, numTablesToSearch);
 
     // An informative book-keeping for the number of neighbor candidates
     // returned on average.
@@ -239,13 +274,13 @@ Search(const size_t k,
     // Sequentially go through all the candidates and save the best 'k'
     // candidates.
     for (size_t j = 0; j < refIndices.n_elem; j++)
-      BaseCase(distances, resultingNeighbors, i, (size_t) refIndices[j]);
+      BaseCase(i, (size_t) refIndices[j], resultingNeighbors, distances);
   }
 
   Timer::Stop("computing_neighbors");
 
   distanceEvaluations += avgIndicesReturned;
-  avgIndicesReturned /= querySet.n_cols;
+  avgIndicesReturned /= referenceSet.n_cols;
   Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
       std::endl;
 }
@@ -400,9 +435,6 @@ std::string LSHSearch<SortPolicy>::ToString() const
   convert << "LSHSearch [" << this << "]" << std::endl;
   convert << "  Reference Set: " << referenceSet.n_rows << "x" ;
   convert <<  referenceSet.n_cols << std::endl;
-  if (&referenceSet != &querySet)
-    convert << "  QuerySet: " << querySet.n_rows << "x" << querySet.n_cols
-        << std::endl;
   convert << "  Number of Projections: " << numProj << std::endl;
   convert << "  Number of Tables: " << numTables << std::endl;
   convert << "  Hash Width: " << hashWidth << std::endl;
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index 52afc25..341da69 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -58,7 +58,7 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
   //    projMat.randn(2, 3)
   //    COR.SOL.: Proj. Mat 1: [2.7020 0.0187 0.4355; 1.3692 0.6933 0.0416]
   //    COR.SOL.: Proj. Mat 2: [-0.3961 -0.2666 1.1001; 0.3895 -1.5118 -1.3964]
-  LSHSearch<> lsh_test(rdata, qdata, 3, 2, hashWidth, 11, 3);
+  LSHSearch<> lsh_test(rdata, 3, 2, hashWidth, 11, 3);
 //   LSHSearch<> lsh_test(rdata, qdata, 3, 2, 0.0, 11, 3);
 
   // Given this, the 'LSHSearch::bucketRowInHashTable' should be:
@@ -75,7 +75,7 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
-  lsh_test.Search(2, neighbors, distances);
+  lsh_test.Search(qdata, 2, neighbors, distances);
 
   // The private function 'LSHSearch::ReturnIndicesFromTable(0, refInds)'
   // should hash the query 0 into the following buckets:



More information about the mlpack-git mailing list