[mlpack-git] master: Implements parallel query processing for LSH (72999dd)
gitdub at mlpack.org
gitdub at mlpack.org
Fri Jul 8 14:36:38 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/34cf8d94f79c9a72ff4199676033b060cd039fcd...425324bf7fb7c86c85d10a909d8a59d4f69b7164
>---------------------------------------------------------------
commit 72999dd6730801f011e82d46656e5d33f0cbd20f
Author: Yannis Mentekidis <mentekid at gmail.com>
Date: Sun Jun 19 14:17:56 2016 +0300
Implements parallel query processing for LSH
>---------------------------------------------------------------
72999dd6730801f011e82d46656e5d33f0cbd20f
src/mlpack/methods/lsh/lsh_search.hpp | 14 +++++++++++
src/mlpack/methods/lsh/lsh_search_impl.hpp | 33 ++++++++++++++++++++++---
src/mlpack/tests/lsh_test.cpp | 39 ++++++++++++++++++++++++++++++
3 files changed, 82 insertions(+), 4 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index ad285ea..9898d84 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -189,6 +189,7 @@ class LSHSearch
arma::mat& distances,
const size_t numTablesToSearch = 0);
+
/**
* Serialize the LSH model.
*
@@ -236,6 +237,12 @@ class LSHSearch
//! removed in mlpack 2.1.0!
const arma::mat& Projection(size_t i) { return projections.slice(i); }
+ //! Set the maximum number of threads the object is allowed to use.  A value
+ //! of 0 is stored as 1, because the OpenMP num_threads clause used by
+ //! Search() requires a positive thread count.
+ void MaxThreads(size_t numThreads)
+ {
+   maxThreads = (numThreads == 0) ? 1 : numThreads;
+ }
+
+ //! Return the maximum number of threads the object is allowed to use.
+ size_t MaxThreads() const { return maxThreads; }
+
private:
/**
* This function takes a query and hashes it into each of the hash tables to
@@ -350,6 +357,13 @@ class LSHSearch
//! The number of distance evaluations.
size_t distanceEvaluations;
+
+ //! Upper bound on the threads Search() may use; set via MaxThreads().
+ size_t maxThreads;
+
+ //! Threads in use.  NOTE(review): only ever incremented, never reset - confirm.
+ size_t numThreadsUsed;
+
}; // class LSHSearch
} // namespace neighbor
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index b34d22c..bd98f64 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -9,6 +9,8 @@
#include <mlpack/core.hpp>
+using std::cout; using std::endl; // TODO(review): remove; leaks into includers.
+
namespace mlpack {
namespace neighbor {
@@ -28,7 +30,9 @@ LSHSearch(const arma::mat& referenceSet,
hashWidth(hashWidthIn),
secondHashSize(secondHashSize),
bucketSize(bucketSize),
- distanceEvaluations(0)
+ distanceEvaluations(0),
+ maxThreads(omp_get_max_threads()),
+ numThreadsUsed(0)
{
// Pass work to training function.
Train(referenceSet, numProj, numTables, hashWidthIn, secondHashSize,
@@ -50,7 +54,9 @@ LSHSearch(const arma::mat& referenceSet,
hashWidth(hashWidthIn),
secondHashSize(secondHashSize),
bucketSize(bucketSize),
- distanceEvaluations(0)
+ distanceEvaluations(0),
+ maxThreads(omp_get_max_threads()),
+ numThreadsUsed(0)
{
// Pass work to training function
Train(referenceSet, numProj, numTables, hashWidthIn, secondHashSize,
@@ -67,7 +73,9 @@ LSHSearch<SortPolicy>::LSHSearch() :
hashWidth(0),
secondHashSize(99901),
bucketSize(500),
- distanceEvaluations(0)
+ distanceEvaluations(0),
+ maxThreads(omp_get_max_threads()),
+ numThreadsUsed(0)
{
// Nothing to do.
}
@@ -522,13 +530,28 @@ Search(const size_t k,
distances.fill(SortPolicy::WorstDistance());
resultingNeighbors.fill(referenceSet->n_cols);
+
size_t avgIndicesReturned = 0;
Timer::Start("computing_neighbors");
- // Go through every query point sequentially.
+ // Parallelization allows us to process more than one query at a time. To
+ // control workload and thread access, we use numThreadsUsed and maxThreads to
+ // make sure we only use as many threads as the user specified.
+ #pragma omp parallel for \
+ if (numThreadsUsed <= maxThreads) \
+ num_threads (maxThreads-numThreadsUsed)\
+ shared(avgIndicesReturned, resultingNeighbors, distances) \
+ schedule(dynamic)
+ // Go through every query point.
for (size_t i = 0; i < referenceSet->n_cols; i++)
{
+ // Master thread updates the number of threads used
+ if (i == 0 && omp_get_thread_num() == 0)
+ {
+ numThreadsUsed+=omp_get_num_threads();
+ cout<<"Using "<<numThreadsUsed<<endl;
+ }
// Hash every query into every hash table and eventually into the
// 'secondHashTable' to obtain the neighbor candidates.
arma::uvec refIndices;
@@ -536,6 +559,8 @@ Search(const size_t k,
// An informative book-keeping for the number of neighbor candidates
// returned on average.
+ // Make atomic to avoid race conditions when multiple threads are running
+ #pragma omp atomic
avgIndicesReturned += refIndices.n_elem;
// Sequentially go through all the candidates and save the best 'k'
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index da489ab..8ddb7d3 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -479,6 +479,45 @@ BOOST_AUTO_TEST_CASE(DeterministicNoMerge)
}
}
+/**
+ * Test: This test verifies that parallel query processing returns the same
+ * neighbors as single-threaded processing.  Parallel query processing is
+ * applied in the monochromatic Search(k, ...) overload, so that overload is
+ * checked in addition to the query-set overload.
+ */
+BOOST_AUTO_TEST_CASE(ParallelQueryTest)
+{
+ // kNN and LSH parameters; the remaining LSH parameters use their defaults.
+ const int k = 4;
+ const int numTables = 16;
+ const int numProj = 3;
+
+ // Read iris training and testing data as reference and query sets.
+ const string trainSet = "iris_train.csv";
+ const string testSet = "iris_test.csv";
+ arma::mat rdata;
+ arma::mat qdata;
+ data::Load(trainSet, rdata, true);
+ data::Load(testSet, qdata, true);
+
+ // Where to store neighbors and distances.
+ arma::Mat<size_t> sequentialNeighbors, parallelNeighbors;
+ arma::Mat<size_t> sequentialMonoNeighbors, parallelMonoNeighbors;
+ arma::mat distances;
+
+ // Construct an LSH object.  By default it may use up to
+ // omp_get_max_threads() threads.
+ LSHSearch<> lshTest(rdata, numProj, numTables);
+ lshTest.Search(qdata, k, parallelNeighbors, distances);
+ lshTest.Search(k, parallelMonoNeighbors, distances);
+
+ // Repeat both searches on the same hash tables, but with a single thread.
+ lshTest.MaxThreads(1);
+ lshTest.Search(qdata, k, sequentialNeighbors, distances);
+ lshTest.Search(k, sequentialMonoNeighbors, distances);
+
+ // Require that both runs return the same neighbors for every point.
+ double recall = ComputeRecall(sequentialNeighbors, parallelNeighbors);
+ BOOST_REQUIRE_EQUAL(recall, 1.0);
+ recall = ComputeRecall(sequentialMonoNeighbors, parallelMonoNeighbors);
+ BOOST_REQUIRE_EQUAL(recall, 1.0);
+}
+
+
BOOST_AUTO_TEST_CASE(LSHTrainTest)
{
// This is a not very good test that simply checks that the re-trained LSH
More information about the mlpack-git
mailing list