[mlpack-git] master: Implements parallel query processing for LSH (72999dd)

gitdub at mlpack.org gitdub at mlpack.org
Fri Jul 8 14:36:38 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/34cf8d94f79c9a72ff4199676033b060cd039fcd...425324bf7fb7c86c85d10a909d8a59d4f69b7164

>---------------------------------------------------------------

commit 72999dd6730801f011e82d46656e5d33f0cbd20f
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Sun Jun 19 14:17:56 2016 +0300

    Implements parallel query processing for LSH


>---------------------------------------------------------------

72999dd6730801f011e82d46656e5d33f0cbd20f
 src/mlpack/methods/lsh/lsh_search.hpp      | 14 +++++++++++
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 33 ++++++++++++++++++++++---
 src/mlpack/tests/lsh_test.cpp              | 39 ++++++++++++++++++++++++++++++
 3 files changed, 82 insertions(+), 4 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index ad285ea..9898d84 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -189,6 +189,7 @@ class LSHSearch
               arma::mat& distances,
               const size_t numTablesToSearch = 0);
 
+
   /**
    * Serialize the LSH model.
    *
@@ -236,6 +237,12 @@ class LSHSearch
   //! removed in mlpack 2.1.0!
   const arma::mat& Projection(size_t i) { return projections.slice(i); }
 
+  //! Set the maximum number of threads the object is allowed to use.
+  void MaxThreads(size_t numThreads) { maxThreads = numThreads;}
+
+  //! Return the maximum number of threads the object is allowed to use.
+  size_t MaxThreads(void) const { return maxThreads; }
+
  private:
   /**
    * This function takes a query and hashes it into each of the hash tables to
@@ -350,6 +357,13 @@ class LSHSearch
 
   //! The number of distance evaluations.
   size_t distanceEvaluations;
+
+  //! The maximum number of threads Search() may use (set via MaxThreads()).
+  size_t maxThreads;
+
+  //! Threads claimed by Search() so far; this patch only ever increments it.
+  size_t numThreadsUsed;
+
 }; // class LSHSearch
 
 } // namespace neighbor
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index b34d22c..bd98f64 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -9,6 +9,8 @@
 
 #include <mlpack/core.hpp>
 
+using std::cout; using std::endl; // TODO: remove (debug only; namespace-scope using in a header leaks into every includer).
+
 namespace mlpack {
 namespace neighbor {
 
@@ -28,7 +30,9 @@ LSHSearch(const arma::mat& referenceSet,
   hashWidth(hashWidthIn),
   secondHashSize(secondHashSize),
   bucketSize(bucketSize),
-  distanceEvaluations(0)
+  distanceEvaluations(0),
+  maxThreads(omp_get_max_threads()),
+  numThreadsUsed(0)
 {
   // Pass work to training function.
   Train(referenceSet, numProj, numTables, hashWidthIn, secondHashSize,
@@ -50,7 +54,9 @@ LSHSearch(const arma::mat& referenceSet,
   hashWidth(hashWidthIn),
   secondHashSize(secondHashSize),
   bucketSize(bucketSize),
-  distanceEvaluations(0)
+  distanceEvaluations(0),
+  maxThreads(omp_get_max_threads()),
+  numThreadsUsed(0)
 {
   // Pass work to training function
   Train(referenceSet, numProj, numTables, hashWidthIn, secondHashSize,
@@ -67,7 +73,9 @@ LSHSearch<SortPolicy>::LSHSearch() :
     hashWidth(0),
     secondHashSize(99901),
     bucketSize(500),
-    distanceEvaluations(0)
+    distanceEvaluations(0),
+    maxThreads(omp_get_max_threads()),
+    numThreadsUsed(0)
 {
   // Nothing to do.
 }
@@ -522,13 +530,28 @@ Search(const size_t k,
   distances.fill(SortPolicy::WorstDistance());
   resultingNeighbors.fill(referenceSet->n_cols);
 
+
   size_t avgIndicesReturned = 0;
 
   Timer::Start("computing_neighbors");
 
-  // Go through every query point sequentially.
+  // Process several queries concurrently, capping the team at
+  // maxThreads - numThreadsUsed threads. NOTE(review): equal counts request
+  // num_threads(0), and this patch never decrements numThreadsUsed.
+  #pragma omp parallel for \
+    if (numThreadsUsed <= maxThreads) \
+    num_threads (maxThreads-numThreadsUsed)\
+    shared(avgIndicesReturned, resultingNeighbors, distances) \
+    schedule(dynamic)
+  // Go through every query point.
   for (size_t i = 0; i < referenceSet->n_cols; i++)
   {
+    // Thread 0 records the team size. NOTE(review): under schedule(dynamic) iteration 0 may be run by another thread, skipping this; the cout is debug output.
+    if (i == 0 && omp_get_thread_num() == 0)
+    {
+      numThreadsUsed+=omp_get_num_threads();
+      cout<<"Using "<<numThreadsUsed<<endl;
+    }
     // Hash every query into every hash table and eventually into the
     // 'secondHashTable' to obtain the neighbor candidates.
     arma::uvec refIndices;
@@ -536,6 +559,8 @@ Search(const size_t k,
 
     // An informative book-keeping for the number of neighbor candidates
     // returned on average.
+    // Atomic: avgIndicesReturned is shared by every thread in the team.
+    #pragma omp atomic
     avgIndicesReturned += refIndices.n_elem;
 
     // Sequentially go through all the candidates and save the best 'k'
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index da489ab..8ddb7d3 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -479,6 +479,45 @@ BOOST_AUTO_TEST_CASE(DeterministicNoMerge)
   }
 }
 
+/**
+ * Test: parallel query processing must return the same neighbors as a
+ * single-threaded (MaxThreads(1)) search on the same LSH object.
+ */
+BOOST_AUTO_TEST_CASE(ParallelQueryTest)
+{
+  // kNN and LSH parameters; other constructor arguments keep their defaults.
+  const int k = 4;
+  const int numTables = 16;
+  const int numProj = 3;
+
+  // Read iris training and testing data as reference and query sets.
+  const string trainSet = "iris_train.csv";
+  const string testSet = "iris_test.csv";
+  arma::mat rdata;
+  arma::mat qdata;
+  data::Load(trainSet, rdata, true);
+  data::Load(testSet, qdata, true);
+
+  // Where to store neighbors and distances.
+  arma::Mat<size_t> sequentialNeighbors;
+  arma::Mat<size_t> parallelNeighbors;
+  arma::mat distances;
+
+  // Construct an LSH object; maxThreads starts at omp_get_max_threads().
+  LSHSearch<> lshTest(rdata, numProj, numTables); // remaining args default
+  lshTest.Search(qdata, k, parallelNeighbors, distances);
+
+  // Repeat the search with one thread. NOTE(review): confirm this query-set Search() overload is the one parallelized in lsh_search_impl.hpp.
+  lshTest.MaxThreads(1);
+  lshTest.Search(qdata, k, sequentialNeighbors, distances);
+
+  // Require identical results: recall of one run against the other must be 1.
+  double recall = ComputeRecall(sequentialNeighbors, parallelNeighbors);
+  BOOST_REQUIRE_EQUAL(recall, 1);
+
+}
+
+
 BOOST_AUTO_TEST_CASE(LSHTrainTest)
 {
   // This is a not very good test that simply checks that the re-trained LSH




More information about the mlpack-git mailing list