[mlpack-git] master: Refactor LSH. Don't use squared distances. (9c34c98)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue May 5 15:38:24 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/6137e52d32c1338b28853afd059b67cf68a50270...9c34c98c8dad01b4ad659c37367a8b20287b6b5a
>---------------------------------------------------------------
commit 9c34c98c8dad01b4ad659c37367a8b20287b6b5a
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue May 5 15:37:59 2015 -0400
Refactor LSH. Don't use squared distances.
>---------------------------------------------------------------
9c34c98c8dad01b4ad659c37367a8b20287b6b5a
src/mlpack/methods/lsh/lsh_search.hpp | 44 +++++++++----------
src/mlpack/methods/lsh/lsh_search_impl.hpp | 69 +++++++++++++++---------------
2 files changed, 56 insertions(+), 57 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 83c6bd6..578c449 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -166,16 +166,23 @@ class LSHSearch
* This is a helper function that computes the distance of the query to the
* neighbor candidates and appropriately stores the best 'k' candidates
*
+ * @param distances Best distances found so far, one column per query point; the query's column is updated in place.
+ * @param neighbors Reference indices matching 'distances', same layout; updated in place alongside it.
* @param queryIndex The index of the query in question
* @param referenceIndex The index of the neighbor candidate in question
*/
- double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+ double BaseCase(arma::mat& distances,
+ arma::Mat<size_t>& neighbors,
+ const size_t queryIndex,
+ const size_t referenceIndex);
/**
* This is a helper function that efficiently inserts better neighbor
* candidates into an existing set of neighbor candidates. This function is
* only called by the 'BaseCase' function.
*
+ * @param distances Best distances found so far, one column per query point; the query's column is shifted and written in place.
+ * @param neighbors Reference indices matching 'distances', same layout; shifted and written in place alongside it.
* @param queryIndex This is the index of the query being processed currently
* @param pos The position of the neighbor candidate in the current list of
* neighbor candidates.
@@ -183,42 +190,41 @@ class LSHSearch
* of the best 'k' candidates for the query in question.
* @param distance The distance of the query to the neighbor candidate.
*/
- void InsertNeighbor(const size_t queryIndex, const size_t pos,
- const size_t neighbor, const double distance);
+ void InsertNeighbor(arma::mat& distances,
+ arma::Mat<size_t>& neighbors,
+ const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance);
//! Reference dataset.
const arma::mat& referenceSet;
-
//! Query dataset (may not be given).
const arma::mat& querySet;
- //! The number of projections
+ //! The number of projections.
const size_t numProj;
-
- //! The number of hash tables
+ //! The number of hash tables.
const size_t numTables;
- //! The std::vector containing the projection matrix of each table
+ //! The std::vector containing the projection matrix of each table.
std::vector<arma::mat> projections; // should be [numProj x dims] x numTables
- //! The list of the offset 'b' for each of the projection for each table
+ //! The offsets 'b' of every projection in every table.
arma::mat offsets; // should be numProj x numTables
- //! The hash width
+ //! The hash width.
double hashWidth;
- //! The big prime representing the size of the second hash
+ //! The big prime representing the size of the second hash.
const size_t secondHashSize;
- //! The weights of the second hash
+ //! The weights of the second hash.
arma::vec secondHashWeights;
- //! The bucket size of the second hash
+ //! The bucket size of the second hash.
const size_t bucketSize;
- //! Instantiation of the metric.
- metric::SquaredEuclideanDistance metric;
-
//! The final hash table; should be (< secondHashSize) x bucketSize.
arma::Mat<size_t> secondHashTable;
@@ -230,12 +236,6 @@ class LSHSearch
//! corresponding to this value. Should be secondHashSize.
arma::Col<size_t> bucketRowInHashTable;
- //! The pointer to the nearest neighbor distances.
- arma::mat* distancePtr;
-
- //! The pointer to the nearest neighbor indices.
- arma::Mat<size_t>* neighborPtr;
-
//! The number of distance evaluations.
size_t distanceEvaluations;
}; // class LSHSearch
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index b6110a3..cd13557 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -39,8 +39,8 @@ LSHSearch(const arma::mat& referenceSet,
size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
- hashWidth += std::sqrt(metric.Evaluate(referenceSet.unsafe_col(p1),
- referenceSet.unsafe_col(p2)));
+ hashWidth += metric::EuclideanDistance::Evaluate( // already unsquared, so no std::sqrt
+ referenceSet.unsafe_col(p1), referenceSet.unsafe_col(p2));
}
hashWidth /= 25;
@@ -76,8 +76,8 @@ LSHSearch(const arma::mat& referenceSet,
size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
- hashWidth += std::sqrt(metric.Evaluate(referenceSet.unsafe_col(p1),
- referenceSet.unsafe_col(p2)));
+ hashWidth += metric::EuclideanDistance::Evaluate( // already unsquared, so no std::sqrt
+ referenceSet.unsafe_col(p1), referenceSet.unsafe_col(p2));
}
hashWidth /= 25;
@@ -89,53 +89,56 @@ LSHSearch(const arma::mat& referenceSet,
}
template<typename SortPolicy>
-void LSHSearch<SortPolicy>::
-InsertNeighbor(const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance)
+void LSHSearch<SortPolicy>::InsertNeighbor(arma::mat& distances,
+ arma::Mat<size_t>& neighbors,
+ const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance)
{
// We only memmove() if there is actually a need to shift something.
- if (pos < (distancePtr->n_rows - 1))
+ if (pos < (distances.n_rows - 1))
{
- int len = (distancePtr->n_rows - 1) - pos;
- memmove(distancePtr->colptr(queryIndex) + (pos + 1),
- distancePtr->colptr(queryIndex) + pos,
+ const size_t len = (distances.n_rows - 1) - pos; // entries pos..n_rows-2 slide down one slot; the last entry is dropped
+ memmove(distances.colptr(queryIndex) + (pos + 1),
+ distances.colptr(queryIndex) + pos,
sizeof(double) * len);
- memmove(neighborPtr->colptr(queryIndex) + (pos + 1),
- neighborPtr->colptr(queryIndex) + pos,
+ memmove(neighbors.colptr(queryIndex) + (pos + 1),
+ neighbors.colptr(queryIndex) + pos,
sizeof(size_t) * len);
}
// Now put the new information in the right index.
- (*distancePtr)(pos, queryIndex) = distance;
- (*neighborPtr)(pos, queryIndex) = neighbor;
+ distances(pos, queryIndex) = distance;
+ neighbors(pos, queryIndex) = neighbor;
}
template<typename SortPolicy>
inline force_inline
-double LSHSearch<SortPolicy>::
-BaseCase(const size_t queryIndex, const size_t referenceIndex)
+double LSHSearch<SortPolicy>::BaseCase(arma::mat& distances,
+ arma::Mat<size_t>& neighbors,
+ const size_t queryIndex,
+ const size_t referenceIndex)
{
// If the datasets are the same, then this search is only using one dataset
// and we should not return identical points.
if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
return 0.0;
- double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceIndex));
- ++distanceEvaluations;
+ const double distance = metric::EuclideanDistance::Evaluate(
+ querySet.unsafe_col(queryIndex), referenceSet.unsafe_col(referenceIndex)); // unsquared distance, replacing the old squared metric
// If this distance is better than any of the current candidates, the
// SortDistance() function will give us the position to insert it into.
- arma::vec queryDist = distancePtr->unsafe_col(queryIndex);
- arma::Col<size_t> queryIndices = neighborPtr->unsafe_col(queryIndex);
+ arma::vec queryDist = distances.unsafe_col(queryIndex);
+ arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
distance);
// SortDistance() returns (size_t() - 1) if we shouldn't add it.
if (insertPosition != (size_t() - 1))
- InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
+ InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
+ referenceIndex, distance);
return distance;
}
@@ -211,14 +214,11 @@ Search(const size_t k,
arma::mat& distances,
const size_t numTablesToSearch)
{
- neighborPtr = &resultingNeighbors;
- distancePtr = &distances;
-
// Set the size of the neighbor and distance matrices.
- neighborPtr->set_size(k, querySet.n_cols);
- distancePtr->set_size(k, querySet.n_cols);
- distancePtr->fill(SortPolicy::WorstDistance());
- neighborPtr->fill(referenceSet.n_cols);
+ resultingNeighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
+ distances.fill(SortPolicy::WorstDistance());
+ resultingNeighbors.fill(referenceSet.n_cols); // out-of-range index; presumably means "no neighbor found yet"
size_t avgIndicesReturned = 0;
@@ -239,11 +239,12 @@ Search(const size_t k,
// Sequentially go through all the candidates and save the best 'k'
// candidates.
for (size_t j = 0; j < refIndices.n_elem; j++)
- BaseCase(i, (size_t) refIndices[j]);
+ BaseCase(distances, resultingNeighbors, i, (size_t) refIndices[j]);
}
Timer::Stop("computing_neighbors");
+ distanceEvaluations += avgIndicesReturned; // NOTE(review): the removed per-call ++ in BaseCase() skipped self-matches; this bulk add presumably counts them too -- confirm how avgIndicesReturned is accumulated
avgIndicesReturned /= querySet.n_cols;
Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
std::endl;
@@ -405,8 +406,6 @@ std::string LSHSearch<SortPolicy>::ToString() const
convert << " Number of Projections: " << numProj << std::endl;
convert << " Number of Tables: " << numTables << std::endl;
convert << " Hash Width: " << hashWidth << std::endl;
- convert << " Metric: " << std::endl;
- convert << mlpack::util::Indent(metric.ToString(),2);
return convert.str();
}
More information about the mlpack-git
mailing list