[mlpack-git] master: Adds code that gives controllable access to LSH projection tables (e94896d)
gitdub at mlpack.org
gitdub at mlpack.org
Tue May 31 16:14:08 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eba4f9924694bc10daec74ff5059dbb8af001416...e3a23c256f017ebb8185b15847c82f51d359cdfd
>---------------------------------------------------------------
commit e94896d9c720ed706b5ad546be9df13b90631f10
Author: Yannis Mentekidis <mentekid at gmail.com>
Date: Tue May 31 23:14:08 2016 +0300
Adds code that gives controllable access to LSH projection tables
>---------------------------------------------------------------
e94896d9c720ed706b5ad546be9df13b90631f10
src/mlpack/methods/lsh/lsh_search.hpp | 30 +++++++++++++++++++----
src/mlpack/methods/lsh/lsh_search_impl.hpp | 38 ++++++++++++++++++++++++++----
2 files changed, 60 insertions(+), 8 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 7505f29..94ab452 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -83,15 +83,19 @@ class LSHSearch
~LSHSearch();
/**
- * Train the LSH model on the given dataset. This means building new hash
- * tables.
+ * Train the LSH model on the given dataset. If no projection tables are
+ * given (empty vector), new random ones are generated. Otherwise, the
+ * numTables tables provided by the user are used.
*/
void Train(const arma::mat& referenceSet,
const size_t numProj,
const size_t numTables,
const double hashWidth = 0.0,
const size_t secondHashSize = 99901,
- const size_t bucketSize = 500);
+ const size_t bucketSize = 500,
+ const std::vector<arma::mat> &projection
+ = std::vector<arma::mat>()
+ );
/**
* Compute the nearest neighbors of the points in the given query set and
@@ -174,6 +178,24 @@ class LSHSearch
//! Get the second hash table.
const arma::Mat<size_t>& SecondHashTable() const { return secondHashTable; }
+ //! Get the projection tables (returned by value, i.e. as a copy).
+ std::vector<arma::mat> getProjectionTables() { return projections; }
+
+ //! Change the projection tables (retrains the object on its current reference set).
+ void setProjectionTables(std::vector<arma::mat> projTables)
+ {
+ // Re-run Train() with the given tables. NOTE(review): *referenceSet aliases the set Train() replaces - confirm this is safe when ownsSet is true.
+ Train(
+ *referenceSet,
+ numProj,
+ numTables,
+ hashWidth,
+ secondHashSize,
+ bucketSize,
+ projTables
+ );
+ };
+
private:
/**
* This function builds a hash table with two levels of hashing as presented
@@ -188,7 +210,7 @@ class LSHSearch
* are private members of this class, initialized during the class
* initialization.
*/
- void BuildHash();
+ void BuildHash(const std::vector<arma::mat> &projection);
/**
* This function takes a query and hashes it into each of the hash tables to
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index b956ed2..119eb78 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -65,7 +65,8 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
const size_t numTables,
const double hashWidthIn,
const size_t secondHashSize,
- const size_t bucketSize)
+ const size_t bucketSize,
+ const std::vector<arma::mat> &projection)
{
// Set new reference set.
if (this->referenceSet && ownsSet)
@@ -97,7 +98,7 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
- BuildHash();
+ BuildHash(projection);
}
template<typename SortPolicy>
@@ -355,7 +356,7 @@ Search(const size_t k,
}
template<typename SortPolicy>
-void LSHSearch<SortPolicy>::BuildHash()
+void LSHSearch<SortPolicy>::BuildHash(const std::vector<arma::mat> &projection)
{
// The first level hash for a single table outputs a 'numProj'-dimensional
// integer key for each point in the set -- (key, pointID)
@@ -412,6 +413,13 @@ void LSHSearch<SortPolicy>::BuildHash()
// Step III: Create each hash table in the first level hash one by one and
// putting them directly into the 'secondHashTable' for memory efficiency.
projections.clear(); // Reset projections vector.
+
+
+ if (projection.size() != 0 && projection.size() != numTables)
+ throw std::invalid_argument(
+ "number of projection tables provided must be equal to numProj"
+ );
+
for (size_t i = 0; i < numTables; i++)
{
// Step IV: Obtain the 'numProj' projections for each table.
@@ -419,7 +427,29 @@ void LSHSearch<SortPolicy>::BuildHash()
// For L2 metric, 2-stable distributions are used, and
// the normal Z ~ N(0, 1) is a 2-stable distribution.
arma::mat projMat;
- projMat.randn(referenceSet->n_rows, numProj);
+
+ if (projection.size() == 0) //random generation of table i
+ {
+
+ // For L2 metric, p-stable distributions are used, and the normal
+ // Z ~ N(0, 1) is p-stable.
+ projMat.randn(referenceSet->n_rows, numProj);
+ }
+ else //user-specified projection tables
+ {
+ // projection.size() == numTables has already been verified before the loop.
+
+ projMat = projection[i];
+
+ //make sure specified matrix is of correct size
+ if (projMat.n_rows != referenceSet->n_rows)
+ throw std::invalid_argument(
+ "projection table dimensionality doesn't"
+ " equal dataset dimensionality" );
+ if (projMat.n_cols != numProj)
+ throw std::invalid_argument(
+ "projection table doesn't have correct number of projections");
+ }
// Save the projection matrix for querying.
projections.push_back(projMat);
More information about the mlpack-git
mailing list