[mlpack-git] master: Make query set a parameter to Search(). (78cc694)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Nov 9 16:30:41 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/9bd2063f96de9430b387974e7ce7204a1e57a803...78cc694a4fd50a68a24f5ab9af7531873566b3ba
>---------------------------------------------------------------
commit 78cc694a4fd50a68a24f5ab9af7531873566b3ba
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Nov 9 20:53:54 2015 +0000
Make query set a parameter to Search().
No need to hold it internally.
>---------------------------------------------------------------
78cc694a4fd50a68a24f5ab9af7531873566b3ba
src/mlpack/methods/lsh/lsh_main.cpp | 13 ++-
src/mlpack/methods/lsh/lsh_search.hpp | 97 ++++++++++-------
src/mlpack/methods/lsh/lsh_search_impl.hpp | 168 +++++++++++++++++------------
src/mlpack/tests/lsh_test.cpp | 4 +-
4 files changed, 164 insertions(+), 118 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_main.cpp b/src/mlpack/methods/lsh/lsh_main.cpp
index 9a5d37d..6fe63e3 100644
--- a/src/mlpack/methods/lsh/lsh_main.cpp
+++ b/src/mlpack/methods/lsh/lsh_main.cpp
@@ -128,18 +128,17 @@ int main(int argc, char *argv[])
LSHSearch<>* allkann;
- if (CLI::GetParam<string>("query_file") != "")
- allkann = new LSHSearch<>(referenceData, queryData, numProj, numTables,
- hashWidth, secondHashSize, bucketSize);
- else
- allkann = new LSHSearch<>(referenceData, numProj, numTables, hashWidth,
- secondHashSize, bucketSize);
+ allkann = new LSHSearch<>(referenceData, numProj, numTables, hashWidth,
+ secondHashSize, bucketSize);
Timer::Stop("hash_building");
Log::Info << "Computing " << k << " distance approximate nearest neighbors "
<< endl;
- allkann->Search(k, neighbors, distances);
+ if (CLI::HasParam("query_file"))
+ allkann->Search(queryData, k, neighbors, distances);
+ else
+ allkann->Search(k, neighbors, distances);
Log::Info << "Neighbors computed." << endl;
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 578c449..c113b10 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -34,9 +34,9 @@ namespace mlpack {
namespace neighbor {
/**
- * The LSHSearch class -- This class builds a hash on the reference set
- * and uses this hash to compute the distance-approximate nearest-neighbors
- * of the given queries.
+ * The LSHSearch class; this class builds a hash on the reference set and uses
+ * this hash to compute the distance-approximate nearest-neighbors of the given
+ * queries.
*
* @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
*/
@@ -49,8 +49,7 @@ class LSHSearch
* reference set with 2-stable distributions. See the individual functions
* performing the hashing for details on how the hashing is done.
*
- * @param referenceSet Set of reference points.
- * @param querySet Set of query points.
+ * @param referenceSet Set of reference points.
* @param numProj Number of projections in each hash table (anything between
* 10-50 might be a decent choice).
* @param numTables Total number of hash tables (anything between 10-20
@@ -66,7 +65,6 @@ class LSHSearch
* Default values are already provided here.
*/
LSHSearch(const arma::mat& referenceSet,
- const arma::mat& querySet,
const size_t numProj,
const size_t numTables,
const double hashWidth = 0.0,
@@ -74,31 +72,29 @@ class LSHSearch
const size_t bucketSize = 500);
/**
- * This function initializes the LSH class. It builds the hash on the
- * reference set with 2-stable distributions. See the individual functions
- * performing the hashing for details on how the hashing is done.
+ * Compute the nearest neighbors of the points in the given query set and
+ * store the output in the given matrices. The matrices will be set to the
+ * size of n columns by k rows, where n is the number of points in the query
+ * dataset and k is the number of neighbors being searched for.
*
- * @param referenceSet Set of reference points and the set of queries.
- * @param numProj Number of projections in each hash table (anything between
- * 10-50 might be a decent choice).
- * @param numTables Total number of hash tables (anything between 10-20
- * should suffice).
- * @param hashWidth The width of hash for every table. If 0 (the default) is
- * provided, then the hash width is automatically obtained by computing
- * the average pairwise distance of 25 pairs. This should be a reasonable
- * upper bound on the nearest-neighbor distance in general.
- * @param secondHashSize The size of the second hash table. This should be a
- * large prime number.
- * @param bucketSize The size of the bucket in the second hash table. This is
- * the maximum number of points that can be hashed into single bucket.
- * Default values are already provided here.
+ * @param querySet Set of query points.
+ * @param k Number of neighbors to search for.
+ * @param resultingNeighbors Matrix storing lists of neighbors for each query
+ * point.
+ * @param distances Matrix storing distances of neighbors for each query
+ * point.
+ * @param numTablesToSearch This parameter allows the user to have control
+ * over the number of hash tables to be searched. This allows
+ * the user to pick the number of tables it can afford for the time
+ * available without having to build hashing for every table size.
+ * By default, this is set to zero in which case all tables are
+ * considered.
*/
- LSHSearch(const arma::mat& referenceSet,
- const size_t numProj,
- const size_t numTables,
- const double hashWidth = 0.0,
- const size_t secondHashSize = 99901,
- const size_t bucketSize = 500);
+ void Search(const arma::mat& querySet,
+ const size_t k,
+ arma::Mat<size_t>& resultingNeighbors,
+ arma::mat& distances,
+ const size_t numTablesToSearch = 0);
/**
* Compute the nearest neighbors and store the output in the given matrices.
@@ -153,28 +149,49 @@ class LSHSearch
* hash table and all the points (if any) in those buckets are collected as
* the potential neighbor candidates.
*
- * @param queryIndex The index of the query currently being processed.
+ * @param queryPoint The query point currently being processed.
* @param referenceIndices The list of neighbor candidates obtained from
* hashing the query into all the hash tables and eventually into
* multiple buckets of the second hash table.
*/
- void ReturnIndicesFromTable(const size_t queryIndex,
+ template<typename VecType>
+ void ReturnIndicesFromTable(const VecType& queryPoint,
arma::uvec& referenceIndices,
- size_t numTablesToSearch);
+ size_t numTablesToSearch) const;
/**
* This is a helper function that computes the distance of the query to the
- * neighbor candidates and appropriately stores the best 'k' candidates
+ * neighbor candidates and appropriately stores the best 'k' candidates. This
+ * is specific to the monochromatic search case, where the query set is the
+ * reference set.
*
- * @param distances Matrix holding output distances.
+ * @param queryIndex The index of the query in question
+ * @param referenceIndex The index of the neighbor candidate in question
* @param neighbors Matrix holding output neighbors.
+ * @param distances Matrix holding output distances.
+ */
+ void BaseCase(const size_t queryIndex,
+ const size_t referenceIndex,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances) const;
+
+ /**
+ * This is a helper function that computes the distance of the query to the
+ * neighbor candidates and appropriately stores the best 'k' candidates. This
+ * is specific to bichromatic search, where the query set is not the same as
+ * the reference set.
+ *
* @param queryIndex The index of the query in question
* @param referenceIndex The index of the neighbor candidate in question
+ * @param querySet Set of query points.
+ * @param neighbors Matrix holding output neighbors.
+ * @param distances Matrix holding output distances.
*/
- double BaseCase(arma::mat& distances,
- arma::Mat<size_t>& neighbors,
- const size_t queryIndex,
- const size_t referenceIndex);
+ void BaseCase(const size_t queryIndex,
+ const size_t referenceIndex,
+ const arma::mat& querySet,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances) const;
/**
* This is a helper function that efficiently inserts better neighbor
@@ -195,12 +212,10 @@ class LSHSearch
const size_t queryIndex,
const size_t pos,
const size_t neighbor,
- const double distance);
+ const double distance) const;
//! Reference dataset.
const arma::mat& referenceSet;
- //! Query dataset (may not be given).
- const arma::mat& querySet;
//! The number of projections.
const size_t numProj;
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index cd13557..5760f89 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -16,51 +16,12 @@ namespace neighbor {
template<typename SortPolicy>
LSHSearch<SortPolicy>::
LSHSearch(const arma::mat& referenceSet,
- const arma::mat& querySet,
const size_t numProj,
const size_t numTables,
const double hashWidthIn,
const size_t secondHashSize,
const size_t bucketSize) :
referenceSet(referenceSet),
- querySet(querySet),
- numProj(numProj),
- numTables(numTables),
- hashWidth(hashWidthIn),
- secondHashSize(secondHashSize),
- bucketSize(bucketSize),
- distanceEvaluations(0)
-{
- if (hashWidth == 0.0) // The user has not provided any value.
- {
- // Compute a heuristic hash width from the data.
- for (size_t i = 0; i < 25; i++)
- {
- size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
- size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
-
- hashWidth += std::sqrt(metric::EuclideanDistance::Evaluate(
- referenceSet.unsafe_col(p1), referenceSet.unsafe_col(p2)));
- }
-
- hashWidth /= 25;
- }
-
- Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
-
- BuildHash();
-}
-
-template<typename SortPolicy>
-LSHSearch<SortPolicy>::
-LSHSearch(const arma::mat& referenceSet,
- const size_t numProj,
- const size_t numTables,
- const double hashWidthIn,
- const size_t secondHashSize,
- const size_t bucketSize) :
- referenceSet(referenceSet),
- querySet(referenceSet),
numProj(numProj),
numTables(numTables),
hashWidth(hashWidthIn),
@@ -94,7 +55,7 @@ void LSHSearch<SortPolicy>::InsertNeighbor(arma::mat& distances,
const size_t queryIndex,
const size_t pos,
const size_t neighbor,
- const double distance)
+ const double distance) const
{
// We only memmove() if there is actually a need to shift something.
if (pos < (distances.n_rows - 1))
@@ -113,20 +74,22 @@ void LSHSearch<SortPolicy>::InsertNeighbor(arma::mat& distances,
neighbors(pos, queryIndex) = neighbor;
}
+// Base case where the query set is the reference set. (So, we can't return
+// ourselves as the nearest neighbor.)
template<typename SortPolicy>
inline force_inline
-double LSHSearch<SortPolicy>::BaseCase(arma::mat& distances,
- arma::Mat<size_t>& neighbors,
- const size_t queryIndex,
- const size_t referenceIndex)
+void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
+ const size_t referenceIndex,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances) const
{
- // If the datasets are the same, then this search is only using one dataset
- // and we should not return identical points.
- if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
- return 0.0;
+ // If the points are the same, we can't continue.
+ if (queryIndex == referenceIndex)
+ return;
const double distance = metric::EuclideanDistance::Evaluate(
- querySet.unsafe_col(queryIndex), referenceSet.unsafe_col(referenceIndex));
+ referenceSet.unsafe_col(queryIndex),
+ referenceSet.unsafe_col(referenceIndex));
// If this distance is better than any of the current candidates, the
// SortDistance() function will give us the position to insert it into.
@@ -139,15 +102,39 @@ double LSHSearch<SortPolicy>::BaseCase(arma::mat& distances,
if (insertPosition != (size_t() - 1))
InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
referenceIndex, distance);
+}
- return distance;
+// Base case for bichromatic search.
+template<typename SortPolicy>
+inline force_inline
+void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
+ const size_t referenceIndex,
+ const arma::mat& querySet,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances) const
+{
+ const double distance = metric::EuclideanDistance::Evaluate(
+ querySet.unsafe_col(queryIndex), referenceSet.unsafe_col(referenceIndex));
+
+ // If this distance is better than any of the current candidates, the
+ // SortDistance() function will give us the position to insert it into.
+ arma::vec queryDist = distances.unsafe_col(queryIndex);
+ arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
+ size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
+ distance);
+
+ // SortDistance() returns (size_t() - 1) if we shouldn't add it.
+ if (insertPosition != (size_t() - 1))
+ InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
+ referenceIndex, distance);
}
template<typename SortPolicy>
-void LSHSearch<SortPolicy>::
-ReturnIndicesFromTable(const size_t queryIndex,
- arma::uvec& referenceIndices,
- size_t numTablesToSearch)
+template<typename VecType>
+void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
+ const VecType& queryPoint,
+ arma::uvec& referenceIndices,
+ size_t numTablesToSearch) const
{
// Decide on the number of tables to look into.
if (numTablesToSearch == 0) // If no user input is given, search all.
@@ -166,10 +153,7 @@ ReturnIndicesFromTable(const size_t queryIndex,
// Compute the projection of the query in each table.
arma::mat allProjInTables(numProj, numTablesToSearch);
for (size_t i = 0; i < numTablesToSearch; i++)
- {
- allProjInTables.unsafe_col(i) = projections[i].t() *
- querySet.unsafe_col(queryIndex);
- }
+ allProjInTables.unsafe_col(i) = projections[i].t() * queryPoint;
allProjInTables += offsets.cols(0, numTablesToSearch - 1);
allProjInTables /= hashWidth;
@@ -206,7 +190,58 @@ ReturnIndicesFromTable(const size_t queryIndex,
referenceIndices = arma::find(refPointsConsidered > 0);
}
+// Search for nearest neighbors in a given query set.
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
+ const size_t k,
+ arma::Mat<size_t>& resultingNeighbors,
+ arma::mat& distances,
+ const size_t numTablesToSearch)
+{
+ // Ensure the dimensionality of the query set is correct.
+ if (querySet.n_rows != referenceSet.n_rows)
+ Log::Fatal << "LSHSearch::Search(): dimensionality of query set ("
+ << querySet.n_rows << ") is not equal to the dimensionality the model "
+ << "was trained on (" << referenceSet.n_rows << ")!" << std::endl;
+ // Set the size of the neighbor and distance matrices.
+ resultingNeighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
+ distances.fill(SortPolicy::WorstDistance());
+ resultingNeighbors.fill(referenceSet.n_cols);
+
+ size_t avgIndicesReturned = 0;
+
+ Timer::Start("computing_neighbors");
+
+ // Go through every query point sequentially.
+ for (size_t i = 0; i < querySet.n_cols; i++)
+ {
+ // Hash every query into every hash table and eventually into the
+ // 'secondHashTable' to obtain the neighbor candidates.
+ arma::uvec refIndices;
+ ReturnIndicesFromTable(querySet.col(i), refIndices, numTablesToSearch);
+
+ // An informative book-keeping for the number of neighbor candidates
+ // returned on average.
+ avgIndicesReturned += refIndices.n_elem;
+
+ // Sequentially go through all the candidates and save the best 'k'
+ // candidates.
+ for (size_t j = 0; j < refIndices.n_elem; j++)
+ BaseCase(i, (size_t) refIndices[j], querySet, resultingNeighbors,
+ distances);
+ }
+
+ Timer::Stop("computing_neighbors");
+
+ distanceEvaluations += avgIndicesReturned;
+ avgIndicesReturned /= querySet.n_cols;
+ Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
+ std::endl;
+}
+
+// Search for approximate neighbors of the reference set.
template<typename SortPolicy>
void LSHSearch<SortPolicy>::
Search(const size_t k,
@@ -214,9 +249,9 @@ Search(const size_t k,
arma::mat& distances,
const size_t numTablesToSearch)
{
- // Set the size of the neighbor and distance matrices.
- resultingNeighbors.set_size(k, querySet.n_cols);
- distances.set_size(k, querySet.n_cols);
+ // This is monochromatic search; the query set is the reference set.
+ resultingNeighbors.set_size(k, referenceSet.n_cols);
+ distances.set_size(k, referenceSet.n_cols);
distances.fill(SortPolicy::WorstDistance());
resultingNeighbors.fill(referenceSet.n_cols);
@@ -225,12 +260,12 @@ Search(const size_t k,
Timer::Start("computing_neighbors");
// Go through every query point sequentially.
- for (size_t i = 0; i < querySet.n_cols; i++)
+ for (size_t i = 0; i < referenceSet.n_cols; i++)
{
// Hash every query into every hash table and eventually into the
// 'secondHashTable' to obtain the neighbor candidates.
arma::uvec refIndices;
- ReturnIndicesFromTable(i, refIndices, numTablesToSearch);
+ ReturnIndicesFromTable(referenceSet.col(i), refIndices, numTablesToSearch);
// An informative book-keeping for the number of neighbor candidates
// returned on average.
@@ -239,13 +274,13 @@ Search(const size_t k,
// Sequentially go through all the candidates and save the best 'k'
// candidates.
for (size_t j = 0; j < refIndices.n_elem; j++)
- BaseCase(distances, resultingNeighbors, i, (size_t) refIndices[j]);
+ BaseCase(i, (size_t) refIndices[j], resultingNeighbors, distances);
}
Timer::Stop("computing_neighbors");
distanceEvaluations += avgIndicesReturned;
- avgIndicesReturned /= querySet.n_cols;
+ avgIndicesReturned /= referenceSet.n_cols;
Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
std::endl;
}
@@ -400,9 +435,6 @@ std::string LSHSearch<SortPolicy>::ToString() const
convert << "LSHSearch [" << this << "]" << std::endl;
convert << " Reference Set: " << referenceSet.n_rows << "x" ;
convert << referenceSet.n_cols << std::endl;
- if (&referenceSet != &querySet)
- convert << " QuerySet: " << querySet.n_rows << "x" << querySet.n_cols
- << std::endl;
convert << " Number of Projections: " << numProj << std::endl;
convert << " Number of Tables: " << numTables << std::endl;
convert << " Hash Width: " << hashWidth << std::endl;
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index 52afc25..341da69 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -58,7 +58,7 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
// projMat.randn(2, 3)
// COR.SOL.: Proj. Mat 1: [2.7020 0.0187 0.4355; 1.3692 0.6933 0.0416]
// COR.SOL.: Proj. Mat 2: [-0.3961 -0.2666 1.1001; 0.3895 -1.5118 -1.3964]
- LSHSearch<> lsh_test(rdata, qdata, 3, 2, hashWidth, 11, 3);
+ LSHSearch<> lsh_test(rdata, 3, 2, hashWidth, 11, 3);
// LSHSearch<> lsh_test(rdata, qdata, 3, 2, 0.0, 11, 3);
// Given this, the 'LSHSearch::bucketRowInHashTable' should be:
@@ -75,7 +75,7 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
arma::Mat<size_t> neighbors;
arma::mat distances;
- lsh_test.Search(2, neighbors, distances);
+ lsh_test.Search(qdata, 2, neighbors, distances);
// The private function 'LSHSearch::ReturnIndicesFromTable(0, refInds)'
// should hash the query 0 into the following buckets:
More information about the mlpack-git mailing list