[mlpack-git] master: Use a pointer to the reference set, so it can be changed. (9b75b38)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Nov 20 17:33:22 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/962a37fe8374913c435054aa50e12d912bdfa01c...a7d8231fe7526dcfaadae0bf37d67b50d286e45d
>---------------------------------------------------------------
commit 9b75b383d4a2f86666871a36ae63e5c3371f70c7
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Nov 9 22:48:18 2015 +0000
Use a pointer to the reference set, so it can be changed.
>---------------------------------------------------------------
9b75b383d4a2f86666871a36ae63e5c3371f70c7
src/mlpack/methods/lsh/lsh_search.hpp | 9 ++++-
src/mlpack/methods/lsh/lsh_search_impl.hpp | 56 ++++++++++++++++++------------
2 files changed, 41 insertions(+), 24 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index c113b10..e07559c 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -72,6 +72,11 @@ class LSHSearch
const size_t bucketSize = 500);
/**
+ * Clean memory: delete the reference set if (and only if) this object owns it.
+ */
+ ~LSHSearch();
+
+ /**
* Compute the nearest neighbors of the points in the given query set and
* store the output in the given matrices. The matrices will be set to the
* size of n columns by k rows, where n is the number of points in the query
@@ -215,7 +220,9 @@ class LSHSearch
const double distance) const;
//! Reference dataset.
- const arma::mat& referenceSet;
+ const arma::mat* referenceSet;
+ //! If true, we own the reference set.
+ bool ownsSet;
//! The number of projections.
const size_t numProj;
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index 5760f89..2463751 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -21,7 +21,8 @@ LSHSearch(const arma::mat& referenceSet,
const double hashWidthIn,
const size_t secondHashSize,
const size_t bucketSize) :
- referenceSet(referenceSet),
+ referenceSet(&referenceSet),
+ ownsSet(false),
numProj(numProj),
numTables(numTables),
hashWidth(hashWidthIn),
@@ -49,6 +50,14 @@ LSHSearch(const arma::mat& referenceSet,
BuildHash();
}
+// Destructor: deletes referenceSet only when ownsSet is true (we own the data).
+template<typename SortPolicy>
+LSHSearch<SortPolicy>::~LSHSearch()
+{
+ if (ownsSet)
+ delete referenceSet;
+}
+
template<typename SortPolicy>
void LSHSearch<SortPolicy>::InsertNeighbor(arma::mat& distances,
arma::Mat<size_t>& neighbors,
@@ -88,8 +97,8 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
return;
const double distance = metric::EuclideanDistance::Evaluate(
- referenceSet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceIndex));
+ referenceSet->unsafe_col(queryIndex),
+ referenceSet->unsafe_col(referenceIndex));
// If this distance is better than any of the current candidates, the
// SortDistance() function will give us the position to insert it into.
@@ -114,7 +123,8 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
arma::mat& distances) const
{
const double distance = metric::EuclideanDistance::Evaluate(
- querySet.unsafe_col(queryIndex), referenceSet.unsafe_col(referenceIndex));
+ querySet.unsafe_col(queryIndex),
+ referenceSet->unsafe_col(referenceIndex));
// If this distance is better than any of the current candidates, the
// SortDistance() function will give us the position to insert it into.
@@ -169,7 +179,7 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
// For all the buckets that the query is hashed into, sequentially
// collect the indices in those buckets.
arma::Col<size_t> refPointsConsidered;
- refPointsConsidered.zeros(referenceSet.n_cols);
+ refPointsConsidered.zeros(referenceSet->n_cols);
for (size_t i = 0; i < hashVec.n_elem; i++) // For all tables.
{
@@ -199,16 +209,16 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
const size_t numTablesToSearch)
{
// Ensure the dimensionality of the query set is correct.
- if (querySet.n_rows != referenceSet.n_rows)
+ if (querySet.n_rows != referenceSet->n_rows)
Log::Fatal << "LSHSearch::Search(): dimensionality of query set ("
<< querySet.n_rows << ") is not equal to the dimensionality the model "
- << "was trained on (" << referenceSet.n_rows << ")!" << std::endl;
+ << "was trained on (" << referenceSet->n_rows << ")!" << std::endl;
// Set the size of the neighbor and distance matrices.
resultingNeighbors.set_size(k, querySet.n_cols);
distances.set_size(k, querySet.n_cols);
distances.fill(SortPolicy::WorstDistance());
- resultingNeighbors.fill(referenceSet.n_cols);
+ resultingNeighbors.fill(referenceSet->n_cols);
size_t avgIndicesReturned = 0;
@@ -250,22 +260,22 @@ Search(const size_t k,
const size_t numTablesToSearch)
{
// This is monochromatic search; the query set is the reference set.
- resultingNeighbors.set_size(k, referenceSet.n_cols);
- distances.set_size(k, referenceSet.n_cols);
+ resultingNeighbors.set_size(k, referenceSet->n_cols);
+ distances.set_size(k, referenceSet->n_cols);
distances.fill(SortPolicy::WorstDistance());
- resultingNeighbors.fill(referenceSet.n_cols);
+ resultingNeighbors.fill(referenceSet->n_cols);
size_t avgIndicesReturned = 0;
Timer::Start("computing_neighbors");
// Go through every query point sequentially.
- for (size_t i = 0; i < referenceSet.n_cols; i++)
+ for (size_t i = 0; i < referenceSet->n_cols; i++)
{
// Hash every query into every hash table and eventually into the
// 'secondHashTable' to obtain the neighbor candidates.
arma::uvec refIndices;
- ReturnIndicesFromTable(referenceSet.col(i), refIndices, numTablesToSearch);
+ ReturnIndicesFromTable(referenceSet->col(i), refIndices, numTablesToSearch);
// An informative book-keeping for the number of neighbor candidates
// returned on average.
@@ -280,7 +290,7 @@ Search(const size_t k,
Timer::Stop("computing_neighbors");
distanceEvaluations += avgIndicesReturned;
- avgIndicesReturned /= referenceSet.n_cols;
+ avgIndicesReturned /= referenceSet->n_cols;
Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
std::endl;
}
@@ -318,7 +328,7 @@ void LSHSearch<SortPolicy>::BuildHash()
// point has index 'n' so the presence of this in the bucket denotes that
// there are no more points in this bucket.
secondHashTable.set_size(secondHashSize, bucketSize);
- secondHashTable.fill(referenceSet.n_cols);
+ secondHashTable.fill(referenceSet->n_cols);
// Keep track of the size of each bucket in the hash. At the end of hashing
// most buckets will be empty.
@@ -349,7 +359,7 @@ void LSHSearch<SortPolicy>::BuildHash()
// For L2 metric, 2-stable distributions are used, and
// the normal Z ~ N(0, 1) is a 2-stable distribution.
arma::mat projMat;
- projMat.randn(referenceSet.n_rows, numProj);
+ projMat.randn(referenceSet->n_rows, numProj);
// Save the projection matrix for querying.
projections.push_back(projMat);
@@ -366,8 +376,8 @@ void LSHSearch<SortPolicy>::BuildHash()
// point is obtained as:
// key = { floor( (<proj_i, point> + offset_i) / 'hashWidth' ) forall i }
arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i), 1,
- referenceSet.n_cols);
- arma::mat hashMat = projMat.t() * referenceSet;
+ referenceSet->n_cols);
+ arma::mat hashMat = projMat.t() * (*referenceSet);
hashMat += offsetMat;
hashMat /= hashWidth;
@@ -380,7 +390,7 @@ void LSHSearch<SortPolicy>::BuildHash()
for (size_t j = 0; j < secondHashVec.n_elem; j++)
secondHashVec[j] = (double)((size_t) secondHashVec[j] % secondHashSize);
- Log::Assert(secondHashVec.n_elem == referenceSet.n_cols);
+ Log::Assert(secondHashVec.n_elem == referenceSet->n_cols);
// Insert the point in the corresponding row to its bucket in the
// 'secondHashTable'.
@@ -433,15 +443,15 @@ std::string LSHSearch<SortPolicy>::ToString() const
{
std::ostringstream convert;
convert << "LSHSearch [" << this << "]" << std::endl;
- convert << " Reference Set: " << referenceSet.n_rows << "x" ;
- convert << referenceSet.n_cols << std::endl;
+ convert << " Reference set: " << referenceSet->n_rows << "x" ;
+ convert << referenceSet->n_cols << std::endl;
convert << " Number of Projections: " << numProj << std::endl;
convert << " Number of Tables: " << numTables << std::endl;
convert << " Hash Width: " << hashWidth << std::endl;
return convert.str();
}
-}; // namespace neighbor
-}; // namespace mlpack
+} // namespace neighbor
+} // namespace mlpack
#endif
More information about the mlpack-git mailing list