[mlpack-git] master: Changes LSHSearch.projections from std::vector<arma::mat> to arma::cube (b067e89)

gitdub at mlpack.org gitdub at mlpack.org
Wed Jun 1 03:40:20 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1f562a1aba7ae55475afcc95659511c2b7f694e5...5b8fdce471328f722fcd8c0f22a6d995ce22c98b

>---------------------------------------------------------------

commit b067e89b528d7fb55b3591b64d49531c039d42df
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Wed Jun 1 10:40:20 2016 +0300

    Changes LSHSearch.projections from std::vector<arma::mat> to arma::cube


>---------------------------------------------------------------

b067e89b528d7fb55b3591b64d49531c039d42df
 src/mlpack/methods/lsh/lsh_search.hpp      | 14 +++----
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 66 ++++++++++++++----------------
 2 files changed, 37 insertions(+), 43 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index d5389d4..10a6ffc 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -93,8 +93,8 @@ class LSHSearch
              const double hashWidth = 0.0,
              const size_t secondHashSize = 99901,
              const size_t bucketSize = 500,
-             const std::vector<arma::mat> &projection
-             = std::vector<arma::mat>()
+             const arma::cube &projection
+             = arma::zeros<arma::cube>(0,0,0)
              );
 
   /**
@@ -163,8 +163,6 @@ class LSHSearch
 
   //! Get the number of projections.
   size_t NumProjections() const { return projections.size(); }
-  //! Get the projection matrix of the given table.
-  const arma::mat& Projection(const size_t i) const { return projections[i]; }
 
   //! Get the offsets 'b' for each of the projections.  (One 'b' per column.)
   const arma::mat& Offsets() const { return offsets; }
@@ -179,10 +177,10 @@ class LSHSearch
   const arma::Mat<size_t>& SecondHashTable() const { return secondHashTable; }
 
   //! Get the projection tables.
-  const std::vector<arma::mat> Projections() { return projections; }
+  const arma::cube Projections() { return projections; }
 
   //! Change the projection tables (Retrains object)
-  void Projections(const std::vector<arma::mat> &projTables)
+  void Projections(const arma::cube &projTables)
   {
     // Simply call Train() with given projection tables
     Train(
@@ -210,7 +208,7 @@ class LSHSearch
    * are private members of this class, initialized during the class
    * initialization.
    */
-  void BuildHash(const std::vector<arma::mat> &projection);
+  void BuildHash(const arma::cube &projection);
 
   /**
    * This function takes a query and hashes it into each of the hash tables to
@@ -294,7 +292,7 @@ class LSHSearch
   size_t numTables;
 
   //! The std::vector containing the projection matrix of each table.
-  std::vector<arma::mat> projections; // should be [numProj x dims] x numTables
+  arma::cube projections; // should be [numProj x dims] x numTables
 
   //! The list of the offsets 'b' for each of the projection for each table.
   arma::mat offsets; // should be numProj x numTables
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index 7986c07..14a86e1 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -66,7 +66,7 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
                                   const double hashWidthIn,
                                   const size_t secondHashSize,
                                   const size_t bucketSize,
-                                  const std::vector<arma::mat> &projection)
+                                  const arma::cube &projection)
 {
   // Set new reference set.
   if (this->referenceSet && ownsSet)
@@ -206,7 +206,8 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
   // Compute the projection of the query in each table.
   arma::mat allProjInTables(numProj, numTablesToSearch);
   for (size_t i = 0; i < numTablesToSearch; i++)
-    allProjInTables.unsafe_col(i) = projections[i].t() * queryPoint;
+    //allProjInTables.unsafe_col(i) = projections[i].t() * queryPoint;
+    allProjInTables.unsafe_col(i) = projections.slice(i).t() * queryPoint;
   allProjInTables += offsets.cols(0, numTablesToSearch - 1);
   allProjInTables /= hashWidth;
 
@@ -356,7 +357,7 @@ Search(const size_t k,
 }
 
 template<typename SortPolicy>
-void LSHSearch<SortPolicy>::BuildHash(const std::vector<arma::mat> &projection)
+void LSHSearch<SortPolicy>::BuildHash(const arma::cube &projection)
 {
   // The first level hash for a single table outputs a 'numProj'-dimensional
   // integer key for each point in the set -- (key, pointID)
@@ -412,45 +413,39 @@ void LSHSearch<SortPolicy>::BuildHash(const std::vector<arma::mat> &projection)
 
   // Step III: Create each hash table in the first level hash one by one and
   // putting them directly into the 'secondHashTable' for memory efficiency.
-  projections.clear(); // Reset projections vector.
+  //projections.clear(); // Reset projections vector.
 
 
-  if (projection.size() != 0 && projection.size() != numTables)
-    throw std::invalid_argument(
-        "number of projection tables provided must be equal to numProj"
-        );
 
-  for (size_t i = 0; i < numTables; i++)
-  {
-    // Step IV: Obtain the 'numProj' projections for each table.
+  // Step IV: Obtain the 'numProj' projections for each table.
 
+  if (projection.n_slices == 0) //random generation of tables
+  {
     // For L2 metric, 2-stable distributions are used, and
     // the normal Z ~ N(0, 1) is a 2-stable distribution.
-    arma::mat projMat;
 
-    if (projection.size() == 0) //random generation of table i
-    {
+    //numTables random tables arranged in a cube
+    projections.randn(
+        referenceSet->n_rows,
+        numProj,
+        numTables
+    );
+  }
+  else if (projection.n_slices == numTables) //user defined tables
+  {
+    projections = projection;
+  }
+  else //invalid argument
+  {
+    throw std::invalid_argument(
+        "number of projection tables provided must be equal to numProj"
+        );
+  }
     
-      // For L2 metric, p-stable distributions are used, and the normal
-      // Z ~ N(0, 1) is p-stable.
-      projMat.randn(referenceSet->n_rows, numProj);
-    }
-    else //user-specified projection tables
-    {
-      projMat = projection[i];
-
-      //make sure specified matrix is of correct size
-      if ( projMat.n_rows != referenceSet->n_rows )
-        throw std::invalid_argument( 
-            "projection table dimensionality doesn't"
-            " equal dataset dimensionality" );
-      if ( projMat.n_cols != numProj )
-        throw std::invalid_argument(
-            "projection table doesn't have correct number of projections");
-    }
 
-    // Save the projection matrix for querying.
-    projections.push_back(projMat);
+  for (size_t i = 0; i < numTables; i++)
+  {
+
 
     // Step V: create the 'numProj'-dimensional key for each point in each
     // table.
@@ -465,7 +460,8 @@ void LSHSearch<SortPolicy>::BuildHash(const std::vector<arma::mat> &projection)
     // key = { floor( (<proj_i, point> + offset_i) / 'hashWidth' ) forall i }
     arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i), 1,
                                        referenceSet->n_cols);
-    arma::mat hashMat = projMat.t() * (*referenceSet);
+    // arma::mat hashMat = projMat.t() * (*referenceSet);
+    arma::mat hashMat = projections.slice(i).t() * (*referenceSet);
     hashMat += offsetMat;
     hashMat /= hashWidth;
 
@@ -546,7 +542,7 @@ void LSHSearch<SortPolicy>::Serialize(Archive& ar,
 
   // Delete existing projections, if necessary.
   if (Archive::is_loading::value)
-    projections.clear();
+    projections.zeros(0, 0, 0); // TODO: correct way to clear this?
 
   ar & CreateNVP(projections, "projections");
   ar & CreateNVP(offsets, "offsets");




More information about the mlpack-git mailing list