[mlpack-git] master: Use a pointer to the reference set, so it can be changed. (9b75b38)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Nov 20 17:33:22 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/962a37fe8374913c435054aa50e12d912bdfa01c...a7d8231fe7526dcfaadae0bf37d67b50d286e45d

>---------------------------------------------------------------

commit 9b75b383d4a2f86666871a36ae63e5c3371f70c7
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Nov 9 22:48:18 2015 +0000

    Use a pointer to the reference set, so it can be changed.


>---------------------------------------------------------------

9b75b383d4a2f86666871a36ae63e5c3371f70c7
 src/mlpack/methods/lsh/lsh_search.hpp      |  9 ++++-
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 56 ++++++++++++++++++------------
 2 files changed, 41 insertions(+), 24 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index c113b10..e07559c 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -72,6 +72,11 @@ class LSHSearch
             const size_t bucketSize = 500);
 
   /**
+   * Destructor; deletes the reference set if this object owns it (ownsSet).
+   */
+  ~LSHSearch();
+
+  /**
    * Compute the nearest neighbors of the points in the given query set and
    * store the output in the given matrices.  The matrices will be set to the
    * size of n columns by k rows, where n is the number of points in the query
@@ -215,7 +220,9 @@ class LSHSearch
                       const double distance) const;
 
   //! Reference dataset.
-  const arma::mat& referenceSet;
+  const arma::mat* referenceSet;
+  //! If true, we own the reference set.
+  bool ownsSet;
 
   //! The number of projections.
   const size_t numProj;
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index 5760f89..2463751 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -21,7 +21,8 @@ LSHSearch(const arma::mat& referenceSet,
           const double hashWidthIn,
           const size_t secondHashSize,
           const size_t bucketSize) :
-  referenceSet(referenceSet),
+  referenceSet(&referenceSet),
+  ownsSet(false),
   numProj(numProj),
   numTables(numTables),
   hashWidth(hashWidthIn),
@@ -49,6 +50,14 @@ LSHSearch(const arma::mat& referenceSet,
   BuildHash();
 }
 
+// Destructor.
+template<typename SortPolicy>
+LSHSearch<SortPolicy>::~LSHSearch()
+{
+  if (ownsSet)
+    delete referenceSet;
+}
+
 template<typename SortPolicy>
 void LSHSearch<SortPolicy>::InsertNeighbor(arma::mat& distances,
                                            arma::Mat<size_t>& neighbors,
@@ -88,8 +97,8 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
     return;
 
   const double distance = metric::EuclideanDistance::Evaluate(
-      referenceSet.unsafe_col(queryIndex),
-      referenceSet.unsafe_col(referenceIndex));
+      referenceSet->unsafe_col(queryIndex),
+      referenceSet->unsafe_col(referenceIndex));
 
   // If this distance is better than any of the current candidates, the
   // SortDistance() function will give us the position to insert it into.
@@ -114,7 +123,8 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
                                      arma::mat& distances) const
 {
   const double distance = metric::EuclideanDistance::Evaluate(
-      querySet.unsafe_col(queryIndex), referenceSet.unsafe_col(referenceIndex));
+      querySet.unsafe_col(queryIndex),
+      referenceSet->unsafe_col(referenceIndex));
 
   // If this distance is better than any of the current candidates, the
   // SortDistance() function will give us the position to insert it into.
@@ -169,7 +179,7 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
   // For all the buckets that the query is hashed into, sequentially
   // collect the indices in those buckets.
   arma::Col<size_t> refPointsConsidered;
-  refPointsConsidered.zeros(referenceSet.n_cols);
+  refPointsConsidered.zeros(referenceSet->n_cols);
 
   for (size_t i = 0; i < hashVec.n_elem; i++) // For all tables.
   {
@@ -199,16 +209,16 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
                                    const size_t numTablesToSearch)
 {
   // Ensure the dimensionality of the query set is correct.
-  if (querySet.n_rows != referenceSet.n_rows)
+  if (querySet.n_rows != referenceSet->n_rows)
     Log::Fatal << "LSHSearch::Search(): dimensionality of query set ("
         << querySet.n_rows << ") is not equal to the dimensionality the model "
-        << "was trained on (" << referenceSet.n_rows << ")!" << std::endl;
+        << "was trained on (" << referenceSet->n_rows << ")!" << std::endl;
 
   // Set the size of the neighbor and distance matrices.
   resultingNeighbors.set_size(k, querySet.n_cols);
   distances.set_size(k, querySet.n_cols);
   distances.fill(SortPolicy::WorstDistance());
-  resultingNeighbors.fill(referenceSet.n_cols);
+  resultingNeighbors.fill(referenceSet->n_cols);
 
   size_t avgIndicesReturned = 0;
 
@@ -250,22 +260,22 @@ Search(const size_t k,
        const size_t numTablesToSearch)
 {
   // This is monochromatic search; the query set is the reference set.
-  resultingNeighbors.set_size(k, referenceSet.n_cols);
-  distances.set_size(k, referenceSet.n_cols);
+  resultingNeighbors.set_size(k, referenceSet->n_cols);
+  distances.set_size(k, referenceSet->n_cols);
   distances.fill(SortPolicy::WorstDistance());
-  resultingNeighbors.fill(referenceSet.n_cols);
+  resultingNeighbors.fill(referenceSet->n_cols);
 
   size_t avgIndicesReturned = 0;
 
   Timer::Start("computing_neighbors");
 
   // Go through every query point sequentially.
-  for (size_t i = 0; i < referenceSet.n_cols; i++)
+  for (size_t i = 0; i < referenceSet->n_cols; i++)
   {
     // Hash every query into every hash table and eventually into the
     // 'secondHashTable' to obtain the neighbor candidates.
     arma::uvec refIndices;
-    ReturnIndicesFromTable(referenceSet.col(i), refIndices, numTablesToSearch);
+    ReturnIndicesFromTable(referenceSet->col(i), refIndices, numTablesToSearch);
 
     // An informative book-keeping for the number of neighbor candidates
     // returned on average.
@@ -280,7 +290,7 @@ Search(const size_t k,
   Timer::Stop("computing_neighbors");
 
   distanceEvaluations += avgIndicesReturned;
-  avgIndicesReturned /= referenceSet.n_cols;
+  avgIndicesReturned /= referenceSet->n_cols;
   Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
       std::endl;
 }
@@ -318,7 +328,7 @@ void LSHSearch<SortPolicy>::BuildHash()
   // point has index 'n' so the presence of this in the bucket denotes that
   // there are no more points in this bucket.
   secondHashTable.set_size(secondHashSize, bucketSize);
-  secondHashTable.fill(referenceSet.n_cols);
+  secondHashTable.fill(referenceSet->n_cols);
 
   // Keep track of the size of each bucket in the hash.  At the end of hashing
   // most buckets will be empty.
@@ -349,7 +359,7 @@ void LSHSearch<SortPolicy>::BuildHash()
     // For L2 metric, 2-stable distributions are used, and
     // the normal Z ~ N(0, 1) is a 2-stable distribution.
     arma::mat projMat;
-    projMat.randn(referenceSet.n_rows, numProj);
+    projMat.randn(referenceSet->n_rows, numProj);
 
     // Save the projection matrix for querying.
     projections.push_back(projMat);
@@ -366,8 +376,8 @@ void LSHSearch<SortPolicy>::BuildHash()
     // point is obtained as:
     // key = { floor( (<proj_i, point> + offset_i) / 'hashWidth' ) forall i }
     arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i), 1,
-                                       referenceSet.n_cols);
-    arma::mat hashMat = projMat.t() * referenceSet;
+                                       referenceSet->n_cols);
+    arma::mat hashMat = projMat.t() * (*referenceSet);
     hashMat += offsetMat;
     hashMat /= hashWidth;
 
@@ -380,7 +390,7 @@ void LSHSearch<SortPolicy>::BuildHash()
     for (size_t j = 0; j < secondHashVec.n_elem; j++)
       secondHashVec[j] = (double)((size_t) secondHashVec[j] % secondHashSize);
 
-    Log::Assert(secondHashVec.n_elem == referenceSet.n_cols);
+    Log::Assert(secondHashVec.n_elem == referenceSet->n_cols);
 
     // Insert the point in the corresponding row to its bucket in the
     // 'secondHashTable'.
@@ -433,15 +443,15 @@ std::string LSHSearch<SortPolicy>::ToString() const
 {
   std::ostringstream convert;
   convert << "LSHSearch [" << this << "]" << std::endl;
-  convert << "  Reference Set: " << referenceSet.n_rows << "x" ;
-  convert <<  referenceSet.n_cols << std::endl;
+  convert << "  Reference set: " << referenceSet->n_rows << "x" ;
+  convert <<  referenceSet->n_cols << std::endl;
   convert << "  Number of Projections: " << numProj << std::endl;
   convert << "  Number of Tables: " << numTables << std::endl;
   convert << "  Hash Width: " << hashWidth << std::endl;
   return convert.str();
 }
 
-}; // namespace neighbor
-}; // namespace mlpack
+} // namespace neighbor
+} // namespace mlpack
 
 #endif



More information about the mlpack-git mailing list