[mlpack-git] master: Adds code that gives controllable access to LSH projection tables (e94896d)

gitdub at mlpack.org gitdub at mlpack.org
Tue May 31 16:14:08 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1f562a1aba7ae55475afcc95659511c2b7f694e5...5b8fdce471328f722fcd8c0f22a6d995ce22c98b

>---------------------------------------------------------------

commit e94896d9c720ed706b5ad546be9df13b90631f10
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Tue May 31 23:14:08 2016 +0300

    Adds code that gives controllable access to LSH projection tables


>---------------------------------------------------------------

e94896d9c720ed706b5ad546be9df13b90631f10
 src/mlpack/methods/lsh/lsh_search.hpp      | 30 +++++++++++++++++++----
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 38 ++++++++++++++++++++++++++----
 2 files changed, 60 insertions(+), 8 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 7505f29..94ab452 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -83,15 +83,19 @@ class LSHSearch
   ~LSHSearch();
 
   /**
-   * Train the LSH model on the given dataset.  This means building new hash
-   * tables.
+   * Train the LSH model on the given dataset.  If projection is empty, new
+   * random projection tables are generated; otherwise it must hold exactly
+   * numTables matrices, each with referenceSet.n_rows rows and numProj
+   * columns, which are used as-is (std::invalid_argument is thrown if not).
    */
   void Train(const arma::mat& referenceSet,
              const size_t numProj,
              const size_t numTables,
              const double hashWidth = 0.0,
              const size_t secondHashSize = 99901,
-             const size_t bucketSize = 500);
+             const size_t bucketSize = 500,
+             const std::vector<arma::mat>& projection =
+                 std::vector<arma::mat>());
 
   /**
    * Compute the nearest neighbors of the points in the given query set and
@@ -174,6 +178,24 @@ class LSHSearch
   //! Get the second hash table.
   const arma::Mat<size_t>& SecondHashTable() const { return secondHashTable; }
 
+  //! Get the projection tables (one matrix per hash table).
+  const std::vector<arma::mat>& getProjectionTables() const
+  { return projections; }
+  /**
+   * Replace the projection tables and retrain on the current reference set,
+   * reusing the stored numProj, numTables, hashWidth, secondHashSize and
+   * bucketSize.  Throws std::invalid_argument if the tables are mis-sized.
+   * @param projTables numTables matrices with one row per dimension and
+   *     numProj columns each (or empty, to draw random tables).
+   */
+  void setProjectionTables(const std::vector<arma::mat>& projTables)
+  {
+    // NOTE(review): Train() may delete referenceSet when ownsSet is true;
+    // confirm *referenceSet cannot dangle when it is passed back in here.
+    Train(*referenceSet, numProj, numTables, hashWidth, secondHashSize,
+          bucketSize, projTables);
+  }
+
  private:
   /**
    * This function builds a hash table with two levels of hashing as presented
@@ -188,7 +210,7 @@ class LSHSearch
    * are private members of this class, initialized during the class
    * initialization.
    */
-  void BuildHash();
+  void BuildHash(const std::vector<arma::mat> &projection);
 
   /**
    * This function takes a query and hashes it into each of the hash tables to
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index b956ed2..119eb78 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -65,7 +65,8 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
                                   const size_t numTables,
                                   const double hashWidthIn,
                                   const size_t secondHashSize,
-                                  const size_t bucketSize)
+                                  const size_t bucketSize,
+                                  const std::vector<arma::mat> &projection)
 {
   // Set new reference set.
   if (this->referenceSet && ownsSet)
@@ -97,7 +98,7 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
 
   Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
 
-  BuildHash();
+  BuildHash(projection);
 }
 
 template<typename SortPolicy>
@@ -355,7 +356,7 @@ Search(const size_t k,
 }
 
 template<typename SortPolicy>
-void LSHSearch<SortPolicy>::BuildHash()
+void LSHSearch<SortPolicy>::BuildHash(const std::vector<arma::mat> &projection)
 {
   // The first level hash for a single table outputs a 'numProj'-dimensional
   // integer key for each point in the set -- (key, pointID)
@@ -412,6 +413,13 @@ void LSHSearch<SortPolicy>::BuildHash()
   // Step III: Create each hash table in the first level hash one by one and
   // putting them directly into the 'secondHashTable' for memory efficiency.
   projections.clear(); // Reset projections vector.
+
+
+  if (projection.size() != 0 && projection.size() != numTables)
+    throw std::invalid_argument(
+        "number of projection tables provided must be equal to numProj"
+        );
+
   for (size_t i = 0; i < numTables; i++)
   {
     // Step IV: Obtain the 'numProj' projections for each table.
@@ -419,7 +427,29 @@ void LSHSearch<SortPolicy>::BuildHash()
     // For L2 metric, 2-stable distributions are used, and
     // the normal Z ~ N(0, 1) is a 2-stable distribution.
     arma::mat projMat;
-    projMat.randn(referenceSet->n_rows, numProj);
+
+    if (projection.size() == 0) //random generation of table i
+    {
+
+      // For L2 metric, p-stable distributions are used, and the normal
+      // Z ~ N(0, 1) is p-stable.
+      projMat.randn(referenceSet->n_rows, numProj);
+    }
+    else //user-specified projection tables
+    {
+      //TODO: check that projection.size() == numTables
+
+      projMat = projection[i];
+
+      //make sure specified matrix is of correct size
+      if (projMat.n_rows != referenceSet->n_rows)
+        throw std::invalid_argument( 
+            "projection table dimensionality doesn't"
+            " equal dataset dimensionality" );
+      if (projMat.n_cols != numProj)
+        throw std::invalid_argument(
+            "projection table doesn't have correct number of projections");
+    }
 
     // Save the projection matrix for querying.
     projections.push_back(projMat);




More information about the mlpack-git mailing list