[mlpack-git] master: Adds two tests for LSHSearch (0031a82)

gitdub at mlpack.org gitdub at mlpack.org
Sat Mar 26 12:44:04 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/5bc514c122d53590397fdfad42c7845d9ad91fa1...f0675d7789b69746f7c337c3ec4a778cef932924

>---------------------------------------------------------------

commit 0031a82e4025e0f2894d9edd30712d596081dbe1
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Sat Mar 26 18:44:04 2016 +0200

    Adds two tests for LSHSearch


>---------------------------------------------------------------

0031a82e4025e0f2894d9edd30712d596081dbe1
 src/mlpack/tests/lsh_test.cpp | 77 ++++++++++++++++++++++++++++++-------------
 1 file changed, 55 insertions(+), 22 deletions(-)

diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index e55b974..cd22633 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -15,45 +15,63 @@ using namespace std;
 using namespace mlpack;
 using namespace mlpack::neighbor;
 
+// Fraction of LSH-reported neighbors that appear anywhere among the exact
+// kNN of the same query column (order-insensitive). Inputs: k x nQueries.
+double compute_recall(const arma::Mat<size_t>& LSHneighbors,
+                      const arma::Mat<size_t>& groundTruth){
+  BOOST_REQUIRE(LSHneighbors.n_rows == groundTruth.n_rows &&
+                LSHneighbors.n_cols == groundTruth.n_cols);
+  size_t found = 0;
+  for (size_t q = 0; q < LSHneighbors.n_cols; ++q)
+    for (size_t n = 0; n < LSHneighbors.n_rows; ++n)
+      found += arma::any(groundTruth.col(q) == LSHneighbors(n, q));
+  return LSHneighbors.n_elem == 0 ? 0.0 :
+      static_cast<double>(found) / LSHneighbors.n_elem;
+}
+
 BOOST_AUTO_TEST_SUITE(LSHTest);
 
 BOOST_AUTO_TEST_CASE(LSHSearchTest)
 {
+
+  math::RandomSeed(time(0));
   //kNN and LSH parameters
   const int k = 4;
+  //const int numTables = 30;
   const int numProj = 10;
   const double hashWidth = 0;
   const int secondHashSize = 99901;
   const int bucketSize = 500;
   
+  //test parameters
+  const double epsilon = 0.1;
+
   //read iris training and testing data as reference and query
-  string iris_train="iris_train.csv";
-  string iris_test="iris_test.csv";
+  const string iris_train="iris_train.csv";
+  const string iris_test="iris_test.csv";
   arma::mat rdata;
   arma::mat qdata;
   data::Load(iris_train, rdata, true);
   data::Load(iris_test, qdata, true);
-  const int n_queries = qdata.n_cols;
-
-
-  //Run classic knn on reference data
-  AllkNN knn(rdata);
-  arma::Mat<size_t> groundTruth;
-  arma::mat groundDistances;
-  knn.Search(qdata, k, groundTruth, groundDistances);
 
 
   //Test: Run LSH with varying number of tables, keeping all other parameters 
   //constant. Compute the recall, i.e. the number of reported neighbors that
   //are real neighbors of the query.
   //LSH's property is that (with high probability), increasing the number of
-  //tables will increase recall.
+  //tables will increase recall. Epsilon ensures that if noise lightly affects
+  //the projections, the test will not fail.
 
+  //Run classic knn on reference data
+  AllkNN knn(rdata);
+  arma::Mat<size_t> groundTruth;
+  arma::mat groundDistances;
+  knn.Search(qdata, k, groundTruth, groundDistances);
 
-  //Run LSH with varying number of tables and compute recall
-  const int L[] = {1, 4, 16, 32, 64, 128}; //number of tables
+  //Run LSH for different number of tables
+  const int L[] = {1, 8, 16, 32, 64, 128}; //number of tables
   const int Lsize = 6;
-  int recall_L[Lsize] = {0}; //recall of each LSH run
+  double recall_L[Lsize] = {0.0}; //recall of each LSH run
 
   for (int l=0; l < Lsize; ++l){
       //run LSH with only numTables varying (other values default)
@@ -64,17 +82,32 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
       lsh_test.Search(qdata, k, LSHneighbors, LSHdistances);
 
       //compute recall for each query
-      for (int q=0; q < n_queries; ++q){
-          for (int neigh = 0; neigh < k; ++neigh){
-              if (LSHneighbors(neigh, q) == groundTruth(neigh, q))
-                  recall_L[l]++;
-          }
-      }
+      recall_L[l] = compute_recall(LSHneighbors, groundTruth);
 
-      if (l > 0)
-          BOOST_CHECK(recall_L[l] >= recall_L[l-1]);
+      if (l > 0){
+          BOOST_CHECK(recall_L[l] >= recall_L[l-1]-epsilon);
+      }
   }
 
+  //Test: Run a very expensive LSH search, with a large number of hash tables
+  //and a large hash width. This run should return an acceptable recall. We set
+  //the bar very low (recall >= 50%) to make sure that a test fail means bad
+  //implementation.
+  
+  const int H = 10000; //first-level hash width
+  const int K = 128; //projections per table
+  const int T = 128; //number of tables
+  const double recall_thresh_t2 = 0.5;
+
+  LSHSearch<> lsh_test(rdata, K, T, H, secondHashSize, bucketSize);
+  arma::Mat<size_t> LSHneighbors;
+  arma::mat LSHdistances;
+  lsh_test.Search(qdata, k, LSHneighbors, LSHdistances);
+  
+  const double recall_t2 = compute_recall(LSHneighbors, groundTruth);
+
+  BOOST_CHECK(recall_t2 >= recall_thresh_t2);
+
 }
 
 BOOST_AUTO_TEST_CASE(LSHTrainTest)




More information about the mlpack-git mailing list