[mlpack-git] master: Refactor LSH. Don't use squared distances. (9c34c98)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue May 5 15:38:24 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6137e52d32c1338b28853afd059b67cf68a50270...9c34c98c8dad01b4ad659c37367a8b20287b6b5a

>---------------------------------------------------------------

commit 9c34c98c8dad01b4ad659c37367a8b20287b6b5a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue May 5 15:37:59 2015 -0400

    Refactor LSH.  Don't use squared distances.


>---------------------------------------------------------------

9c34c98c8dad01b4ad659c37367a8b20287b6b5a
 src/mlpack/methods/lsh/lsh_search.hpp      | 44 +++++++++----------
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 69 +++++++++++++++---------------
 2 files changed, 56 insertions(+), 57 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 83c6bd6..578c449 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -166,16 +166,23 @@ class LSHSearch
    * This is a helper function that computes the distance of the query to the
    * neighbor candidates and appropriately stores the best 'k' candidates
    *
+   * @param distances Output distance matrix (one column per query).
+   * @param neighbors Output neighbor index matrix (one column per query).
    * @param queryIndex The index of the query in question
    * @param referenceIndex The index of the neighbor candidate in question
    */
-  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+  double BaseCase(arma::mat& distances,
+                  arma::Mat<size_t>& neighbors,
+                  const size_t queryIndex,
+                  const size_t referenceIndex);
 
   /**
    * This is a helper function that efficiently inserts better neighbor
    * candidates into an existing set of neighbor candidates. This function is
    * only called by the 'BaseCase' function.
    *
+   * @param distances Output distance matrix (one column per query).
+   * @param neighbors Output neighbor index matrix (one column per query).
    * @param queryIndex This is the index of the query being processed currently
    * @param pos The position of the neighbor candidate in the current list of
    *    neighbor candidates.
@@ -183,42 +190,41 @@ class LSHSearch
    *    of the best 'k' candidates for the query in question.
    * @param distance The distance of the query to the neighbor candidate.
    */
-  void InsertNeighbor(const size_t queryIndex, const size_t pos,
-                      const size_t neighbor, const double distance);
+  void InsertNeighbor(arma::mat& distances,
+                      arma::Mat<size_t>& neighbors,
+                      const size_t queryIndex,
+                      const size_t pos,
+                      const size_t neighbor,
+                      const double distance);
 
   //! Reference dataset.
   const arma::mat& referenceSet;
-
   //! Query dataset (may not be given).
   const arma::mat& querySet;
 
-  //! The number of projections
+  //! The number of projections per hash table.
   const size_t numProj;
-
-  //! The number of hash tables
+  //! The number of hash tables.
   const size_t numTables;
 
-  //! The std::vector containing the projection matrix of each table
+  //! The std::vector containing the projection matrix of each table.
   std::vector<arma::mat> projections; // should be [numProj x dims] x numTables
 
-  //! The list of the offset 'b' for each of the projection for each table
+  //! The offsets 'b' for each projection in each table.
   arma::mat offsets; // should be numProj x numTables
 
-  //! The hash width
+  //! The hash width.
   double hashWidth;
 
-  //! The big prime representing the size of the second hash
+  //! The big prime representing the size of the second hash.
   const size_t secondHashSize;
 
-  //! The weights of the second hash
+  //! The weights of the second hash.
   arma::vec secondHashWeights;
 
-  //! The bucket size of the second hash
+  //! The bucket size of the second hash.
   const size_t bucketSize;
 
-  //! Instantiation of the metric.
-  metric::SquaredEuclideanDistance metric;
-
   //! The final hash table; should be (< secondHashSize) x bucketSize.
   arma::Mat<size_t> secondHashTable;
 
@@ -230,12 +236,6 @@ class LSHSearch
   //! corresponding to this value.  Should be secondHashSize.
   arma::Col<size_t> bucketRowInHashTable;
 
-  //! The pointer to the nearest neighbor distances.
-  arma::mat* distancePtr;
-
-  //! The pointer to the nearest neighbor indices.
-  arma::Mat<size_t>* neighborPtr;
-
   //! The number of distance evaluations.
   size_t distanceEvaluations;
 }; // class LSHSearch
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index b6110a3..cd13557 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -39,8 +39,8 @@ LSHSearch(const arma::mat& referenceSet,
       size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
       size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
 
-      hashWidth += std::sqrt(metric.Evaluate(referenceSet.unsafe_col(p1),
-                                             referenceSet.unsafe_col(p2)));
+      hashWidth += metric::EuclideanDistance::Evaluate(
+          referenceSet.unsafe_col(p1), referenceSet.unsafe_col(p2));
     }
 
     hashWidth /= 25;
@@ -76,8 +76,8 @@ LSHSearch(const arma::mat& referenceSet,
       size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
       size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
 
-      hashWidth += std::sqrt(metric.Evaluate(referenceSet.unsafe_col(p1),
-                                             referenceSet.unsafe_col(p2)));
+      hashWidth += metric::EuclideanDistance::Evaluate(
+          referenceSet.unsafe_col(p1), referenceSet.unsafe_col(p2));
     }
 
     hashWidth /= 25;
@@ -89,53 +89,56 @@ LSHSearch(const arma::mat& referenceSet,
 }
 
 template<typename SortPolicy>
-void LSHSearch<SortPolicy>::
-InsertNeighbor(const size_t queryIndex,
-               const size_t pos,
-               const size_t neighbor,
-               const double distance)
+void LSHSearch<SortPolicy>::InsertNeighbor(arma::mat& distances,
+                                           arma::Mat<size_t>& neighbors,
+                                           const size_t queryIndex,
+                                           const size_t pos,
+                                           const size_t neighbor,
+                                           const double distance)
 {
   // We only memmove() if there is actually a need to shift something.
-  if (pos < (distancePtr->n_rows - 1))
+  if (pos < (distances.n_rows - 1))
   {
-    int len = (distancePtr->n_rows - 1) - pos;
-    memmove(distancePtr->colptr(queryIndex) + (pos + 1),
-        distancePtr->colptr(queryIndex) + pos,
+    const size_t len = (distances.n_rows - 1) - pos;
+    memmove(distances.colptr(queryIndex) + (pos + 1),
+        distances.colptr(queryIndex) + pos,
         sizeof(double) * len);
-    memmove(neighborPtr->colptr(queryIndex) + (pos + 1),
-        neighborPtr->colptr(queryIndex) + pos,
+    memmove(neighbors.colptr(queryIndex) + (pos + 1),
+        neighbors.colptr(queryIndex) + pos,
         sizeof(size_t) * len);
   }
 
   // Now put the new information in the right index.
-  (*distancePtr)(pos, queryIndex) = distance;
-  (*neighborPtr)(pos, queryIndex) = neighbor;
+  distances(pos, queryIndex) = distance;
+  neighbors(pos, queryIndex) = neighbor;
 }
 
 template<typename SortPolicy>
 inline force_inline
-double LSHSearch<SortPolicy>::
-BaseCase(const size_t queryIndex, const size_t referenceIndex)
+double LSHSearch<SortPolicy>::BaseCase(arma::mat& distances,
+                                       arma::Mat<size_t>& neighbors,
+                                       const size_t queryIndex,
+                                       const size_t referenceIndex)
 {
   // If the datasets are the same, then this search is only using one dataset
   // and we should not return identical points.
   if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
     return 0.0;
 
-  double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
-                                    referenceSet.unsafe_col(referenceIndex));
-  ++distanceEvaluations;
+  const double distance = metric::EuclideanDistance::Evaluate(
+      querySet.unsafe_col(queryIndex), referenceSet.unsafe_col(referenceIndex));
 
   // If this distance is better than any of the current candidates, the
   // SortDistance() function will give us the position to insert it into.
-  arma::vec queryDist = distancePtr->unsafe_col(queryIndex);
-  arma::Col<size_t> queryIndices = neighborPtr->unsafe_col(queryIndex);
+  arma::vec queryDist = distances.unsafe_col(queryIndex);
+  arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
   size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
       distance);
 
   // SortDistance() returns (size_t() - 1) if we shouldn't add it.
   if (insertPosition != (size_t() - 1))
-    InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
+    InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
+        referenceIndex, distance);
 
   return distance;
 }
@@ -211,14 +214,11 @@ Search(const size_t k,
        arma::mat& distances,
        const size_t numTablesToSearch)
 {
-  neighborPtr = &resultingNeighbors;
-  distancePtr = &distances;
-
   // Set the size of the neighbor and distance matrices.
-  neighborPtr->set_size(k, querySet.n_cols);
-  distancePtr->set_size(k, querySet.n_cols);
-  distancePtr->fill(SortPolicy::WorstDistance());
-  neighborPtr->fill(referenceSet.n_cols);
+  resultingNeighbors.set_size(k, querySet.n_cols);
+  distances.set_size(k, querySet.n_cols);
+  distances.fill(SortPolicy::WorstDistance());
+  resultingNeighbors.fill(referenceSet.n_cols);
 
   size_t avgIndicesReturned = 0;
 
@@ -239,11 +239,12 @@ Search(const size_t k,
     // Sequentially go through all the candidates and save the best 'k'
     // candidates.
     for (size_t j = 0; j < refIndices.n_elem; j++)
-      BaseCase(i, (size_t) refIndices[j]);
+      BaseCase(distances, resultingNeighbors, i, (size_t) refIndices[j]);
   }
 
   Timer::Stop("computing_neighbors");
 
+  distanceEvaluations += avgIndicesReturned;
   avgIndicesReturned /= querySet.n_cols;
   Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
       std::endl;
@@ -405,8 +406,6 @@ std::string LSHSearch<SortPolicy>::ToString() const
   convert << "  Number of Projections: " << numProj << std::endl;
   convert << "  Number of Tables: " << numTables << std::endl;
   convert << "  Hash Width: " << hashWidth << std::endl;
-  convert << "  Metric: " << std::endl;
-  convert << mlpack::util::Indent(metric.ToString(),2);
   return convert.str();
 }
 
 



More information about the mlpack-git mailing list