[mlpack-git] master: Add testing for LSHSearch and refactors code (799cafe)

gitdub at mlpack.org gitdub at mlpack.org
Thu Mar 31 08:22:23 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/5bc514c122d53590397fdfad42c7845d9ad91fa1...f0675d7789b69746f7c337c3ec4a778cef932924

>---------------------------------------------------------------

commit 799cafee84731ce257626b3c769412927eafc737
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Thu Mar 31 15:22:23 2016 +0300

    Add testing for LSHSearch and refactors code


>---------------------------------------------------------------

799cafee84731ce257626b3c769412927eafc737
 src/mlpack/tests/lsh_test.cpp | 76 +++++++++++++++++++++++--------------------
 1 file changed, 40 insertions(+), 36 deletions(-)

diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index cd22633..9561dda 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -15,17 +15,21 @@ using namespace std;
 using namespace mlpack;
 using namespace mlpack::neighbor;
 
-double compute_recall(arma::Mat<size_t> LSHneighbors, arma::Mat<size_t> groundTruth){
+double compute_recall(arma::Mat<size_t> LSHneighbors, 
+        arma::Mat<size_t> groundTruth)
+{
   const int n_queries = LSHneighbors.n_cols;
   const int n_neigh = LSHneighbors.n_rows;
 
-  int recall = 0;
-  for (int q = 0; q < n_queries; ++q){
-      for (int n = 0; n < n_neigh; ++n){
-        recall+=(LSHneighbors(n,q)==groundTruth(n,q));
-      }
+  int found_same = 0;
+  for (int q = 0; q < n_queries; ++q)
+  {
+    for (int n = 0; n < n_neigh; ++n)
+    {
+      found_same+=(LSHneighbors(n,q)==groundTruth(n,q));
+    }
   }
-  return static_cast<double>(recall)/
+  return static_cast<double>(found_same)/
       (static_cast<double>(n_queries*n_neigh));
 }
 
@@ -35,7 +39,7 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
 {
 
   math::RandomSeed(time(0));
-  //kNN and LSH parameters
+  //kNN and LSH parameters (use LSH default parameters)
   const int k = 4;
   //const int numTables = 30;
   const int numProj = 10;
@@ -47,13 +51,18 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
   const double epsilon = 0.1;
 
   //read iris training and testing data as reference and query
-  const string iris_train="iris_train.csv";
-  const string iris_test="iris_test.csv";
+  const string data_train="iris_train.csv";
+  const string data_test="iris_test.csv";
   arma::mat rdata;
   arma::mat qdata;
-  data::Load(iris_train, rdata, true);
-  data::Load(iris_test, qdata, true);
+  data::Load(data_train, rdata, true);
+  data::Load(data_test, qdata, true);
 
+  //Run classic knn on reference data
+  AllkNN knn(rdata);
+  arma::Mat<size_t> groundTruth;
+  arma::mat groundDistances;
+  knn.Search(qdata, k, groundTruth, groundDistances);
 
   //Test: Run LSH with varying number of tables, keeping all other parameters 
   //constant. Compute the recall, i.e. the number of reported neighbors that
@@ -62,31 +71,26 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
   //tables will increase recall. Epsilon ensures that if noise lightly affects
   //the projections, the test will not fail.
 
-  //Run classic knn on reference data
-  AllkNN knn(rdata);
-  arma::Mat<size_t> groundTruth;
-  arma::mat groundDistances;
-  knn.Search(qdata, k, groundTruth, groundDistances);
-
   //Run LSH for different number of tables
-  const int L[] = {1, 8, 16, 32, 64, 128}; //number of tables
+  const int L_table[] = {1, 8, 16, 32, 64, 128}; //number of tables
   const int Lsize = 6;
   double recall_L[Lsize] = {0.0}; //recall of each LSH run
 
-  for (int l=0; l < Lsize; ++l){
-      //run LSH with only numTables varying (other values default)
-      LSHSearch<> lsh_test(rdata, numProj, L[l], 
-              hashWidth, secondHashSize, bucketSize);
-      arma::Mat<size_t> LSHneighbors;
-      arma::mat LSHdistances;
-      lsh_test.Search(qdata, k, LSHneighbors, LSHdistances);
+  for (int l=0; l < Lsize; ++l)
+  {
+    //run LSH with only numTables varying (other values default)
+    LSHSearch<> lsh_test1(rdata, numProj, L_table[l], 
+            hashWidth, secondHashSize, bucketSize);
+    arma::Mat<size_t> LSHneighbors;
+    arma::mat LSHdistances;
+    lsh_test1.Search(qdata, k, LSHneighbors, LSHdistances);
+
+    //compute recall for each query
+    recall_L[l] = compute_recall(LSHneighbors, groundTruth);
 
-      //compute recall for each query
-      recall_L[l] = compute_recall(LSHneighbors, groundTruth);
+    if (l > 0)
+        BOOST_CHECK(recall_L[l] >= recall_L[l-1]-epsilon);
     
-      if (l > 0){
-          BOOST_CHECK(recall_L[l] >= recall_L[l-1]-epsilon);
-      }
   }
 
   //Test: Run a very expensive LSH search, with a large number of hash tables
@@ -94,15 +98,15 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
   //the bar very low (recall >= 50%) to make sure that a test fail means bad
   //implementation.
   
-  const int H = 10000; //first-level hash width
-  const int K = 128; //projections per table
-  const int T = 128; //number of tables
+  const int Ht2 = 10000; //first-level hash width
+  const int Kt2 = 128; //projections per table
+  const int Tt2 = 128; //number of tables
   const double recall_thresh_t2 = 0.5;
 
-  LSHSearch<> lsh_test(rdata, K, T, H, secondHashSize, bucketSize);
+  LSHSearch<> lsh_test2(rdata, Kt2, Tt2, Ht2, secondHashSize, bucketSize);
   arma::Mat<size_t> LSHneighbors;
   arma::mat LSHdistances;
-  lsh_test.Search(qdata, k, LSHneighbors, LSHdistances);
+  lsh_test2.Search(qdata, k, LSHneighbors, LSHdistances);
   
   const double recall_t2 = compute_recall(LSHneighbors, groundTruth);
 




More information about the mlpack-git mailing list