[mlpack-git] master: Changes LSHSearch.projections from std::vector<arma::mat> to arma::cube (b067e89)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Jun 1 03:40:20 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1f562a1aba7ae55475afcc95659511c2b7f694e5...5b8fdce471328f722fcd8c0f22a6d995ce22c98b
>---------------------------------------------------------------
commit b067e89b528d7fb55b3591b64d49531c039d42df
Author: Yannis Mentekidis <mentekid at gmail.com>
Date: Wed Jun 1 10:40:20 2016 +0300
Changes LSHSearch.projections from std::vector<arma::mat> to arma::cube
>---------------------------------------------------------------
b067e89b528d7fb55b3591b64d49531c039d42df
src/mlpack/methods/lsh/lsh_search.hpp | 14 +++----
src/mlpack/methods/lsh/lsh_search_impl.hpp | 66 ++++++++++++++----------------
2 files changed, 37 insertions(+), 43 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index d5389d4..10a6ffc 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -93,8 +93,8 @@ class LSHSearch
const double hashWidth = 0.0,
const size_t secondHashSize = 99901,
const size_t bucketSize = 500,
- const std::vector<arma::mat> &projection
- = std::vector<arma::mat>()
+ const arma::cube &projection
+ = arma::zeros<arma::cube>(0,0,0)
);
/**
@@ -163,8 +163,6 @@ class LSHSearch
//! Get the number of projections.
size_t NumProjections() const { return projections.size(); }
- //! Get the projection matrix of the given table.
- const arma::mat& Projection(const size_t i) const { return projections[i]; }
//! Get the offsets 'b' for each of the projections. (One 'b' per column.)
const arma::mat& Offsets() const { return offsets; }
@@ -179,10 +177,10 @@ class LSHSearch
const arma::Mat<size_t>& SecondHashTable() const { return secondHashTable; }
//! Get the projection tables.
- const std::vector<arma::mat> Projections() { return projections; }
+ const arma::cube Projections() { return projections; }
//! Change the projection tables (Retrains object)
- void Projections(const std::vector<arma::mat> &projTables)
+ void Projections(const arma::cube &projTables)
{
// Simply call Train() with given projection tables
Train(
@@ -210,7 +208,7 @@ class LSHSearch
* are private members of this class, initialized during the class
* initialization.
*/
- void BuildHash(const std::vector<arma::mat> &projection);
+ void BuildHash(const arma::cube &projection);
/**
* This function takes a query and hashes it into each of the hash tables to
@@ -294,7 +292,7 @@ class LSHSearch
size_t numTables;
//! The std::vector containing the projection matrix of each table.
- std::vector<arma::mat> projections; // should be [numProj x dims] x numTables
+ arma::cube projections; // should be [numProj x dims] x numTables
//! The list of the offsets 'b' for each of the projection for each table.
arma::mat offsets; // should be numProj x numTables
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index 7986c07..14a86e1 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -66,7 +66,7 @@ void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
const double hashWidthIn,
const size_t secondHashSize,
const size_t bucketSize,
- const std::vector<arma::mat> &projection)
+ const arma::cube &projection)
{
// Set new reference set.
if (this->referenceSet && ownsSet)
@@ -206,7 +206,8 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
// Compute the projection of the query in each table.
arma::mat allProjInTables(numProj, numTablesToSearch);
for (size_t i = 0; i < numTablesToSearch; i++)
- allProjInTables.unsafe_col(i) = projections[i].t() * queryPoint;
+ //allProjInTables.unsafe_col(i) = projections[i].t() * queryPoint;
+ allProjInTables.unsafe_col(i) = projections.slice(i).t() * queryPoint;
allProjInTables += offsets.cols(0, numTablesToSearch - 1);
allProjInTables /= hashWidth;
@@ -356,7 +357,7 @@ Search(const size_t k,
}
template<typename SortPolicy>
-void LSHSearch<SortPolicy>::BuildHash(const std::vector<arma::mat> &projection)
+void LSHSearch<SortPolicy>::BuildHash(const arma::cube &projection)
{
// The first level hash for a single table outputs a 'numProj'-dimensional
// integer key for each point in the set -- (key, pointID)
@@ -412,45 +413,39 @@ void LSHSearch<SortPolicy>::BuildHash(const std::vector<arma::mat> &projection)
// Step III: Create each hash table in the first level hash one by one and
// putting them directly into the 'secondHashTable' for memory efficiency.
- projections.clear(); // Reset projections vector.
+ //projections.clear(); // Reset projections vector.
- if (projection.size() != 0 && projection.size() != numTables)
- throw std::invalid_argument(
- "number of projection tables provided must be equal to numProj"
- );
- for (size_t i = 0; i < numTables; i++)
- {
- // Step IV: Obtain the 'numProj' projections for each table.
+ // Step IV: Obtain the 'numProj' projections for each table.
+ if (projection.n_slices == 0) //random generation of tables
+ {
// For L2 metric, 2-stable distributions are used, and
// the normal Z ~ N(0, 1) is a 2-stable distribution.
- arma::mat projMat;
- if (projection.size() == 0) //random generation of table i
- {
+ //numTables random tables arranged in a cube
+ projections.randn(
+ referenceSet->n_rows,
+ numProj,
+ numTables
+ );
+ }
+ else if (projection.n_slices == numTables) //user defined tables
+ {
+ projections = projection;
+ }
+ else //invalid argument
+ {
+ throw std::invalid_argument(
+ "number of projection tables provided must be equal to numProj"
+ );
+ }
- // For L2 metric, p-stable distributions are used, and the normal
- // Z ~ N(0, 1) is p-stable.
- projMat.randn(referenceSet->n_rows, numProj);
- }
- else //user-specified projection tables
- {
- projMat = projection[i];
-
- //make sure specified matrix is of correct size
- if ( projMat.n_rows != referenceSet->n_rows )
- throw std::invalid_argument(
- "projection table dimensionality doesn't"
- " equal dataset dimensionality" );
- if ( projMat.n_cols != numProj )
- throw std::invalid_argument(
- "projection table doesn't have correct number of projections");
- }
- // Save the projection matrix for querying.
- projections.push_back(projMat);
+ for (size_t i = 0; i < numTables; i++)
+ {
+
// Step V: create the 'numProj'-dimensional key for each point in each
// table.
@@ -465,7 +460,8 @@ void LSHSearch<SortPolicy>::BuildHash(const std::vector<arma::mat> &projection)
// key = { floor( (<proj_i, point> + offset_i) / 'hashWidth' ) forall i }
arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i), 1,
referenceSet->n_cols);
- arma::mat hashMat = projMat.t() * (*referenceSet);
+ // arma::mat hashMat = projMat.t() * (*referenceSet);
+ arma::mat hashMat = projections.slice(i).t() * (*referenceSet);
hashMat += offsetMat;
hashMat /= hashWidth;
@@ -546,7 +542,7 @@ void LSHSearch<SortPolicy>::Serialize(Archive& ar,
// Delete existing projections, if necessary.
if (Archive::is_loading::value)
- projections.clear();
+ projections.zeros(0, 0, 0); // TODO: correct way to clear this?
ar & CreateNVP(projections, "projections");
ar & CreateNVP(offsets, "offsets");
More information about the mlpack-git mailing list