[mlpack-git] master: Adds two tests for LSHSearch (0031a82)
gitdub at mlpack.org
gitdub at mlpack.org
Sat Mar 26 12:44:04 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/5bc514c122d53590397fdfad42c7845d9ad91fa1...f0675d7789b69746f7c337c3ec4a778cef932924
>---------------------------------------------------------------
commit 0031a82e4025e0f2894d9edd30712d596081dbe1
Author: Yannis Mentekidis <mentekid at gmail.com>
Date: Sat Mar 26 18:44:04 2016 +0200
Adds two tests for LSHSearch
>---------------------------------------------------------------
0031a82e4025e0f2894d9edd30712d596081dbe1
src/mlpack/tests/lsh_test.cpp | 77 ++++++++++++++++++++++++++++++-------------
1 file changed, 55 insertions(+), 22 deletions(-)
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index e55b974..cd22633 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -15,45 +15,63 @@ using namespace std;
using namespace mlpack;
using namespace mlpack::neighbor;
+/**
+ * Compute the recall of an approximate neighbor search against exact results.
+ *
+ * Recall is measured position-wise: it is the fraction of (rank, query)
+ * entries for which LSHneighbors holds the same reference index as
+ * groundTruth at the same position. Column q holds the neighbors of query q
+ * and row n holds its n-th neighbor (0-indexed).
+ *
+ * @param LSHneighbors Neighbor indices found by LSHSearch.
+ * @param groundTruth Exact neighbor indices (e.g. from AllkNN); must have the
+ *     same dimensions as LSHneighbors.
+ * @return Fraction of matching entries in [0, 1], or 0 if there are no
+ *     entries to compare.
+ */
+double compute_recall(const arma::Mat<size_t>& LSHneighbors,
+                      const arma::Mat<size_t>& groundTruth)
+{
+  // Fail loudly on a shape mismatch; otherwise groundTruth(n, q) could read
+  // out of bounds if Armadillo's bounds checking is disabled (ARMA_NO_DEBUG).
+  BOOST_REQUIRE_EQUAL(LSHneighbors.n_rows, groundTruth.n_rows);
+  BOOST_REQUIRE_EQUAL(LSHneighbors.n_cols, groundTruth.n_cols);
+
+  const size_t nQueries = LSHneighbors.n_cols;
+  const size_t nNeighbors = LSHneighbors.n_rows;
+
+  // size_t keeps the product from overflowing and matches arma's size type.
+  const size_t total = nQueries * nNeighbors;
+  if (total == 0)
+    return 0.0; // Nothing to compare; avoid 0 / 0 = NaN.
+
+  size_t matches = 0;
+  for (size_t q = 0; q < nQueries; ++q)
+    for (size_t n = 0; n < nNeighbors; ++n)
+      matches += (LSHneighbors(n, q) == groundTruth(n, q));
+
+  return static_cast<double>(matches) / static_cast<double>(total);
+}
+
BOOST_AUTO_TEST_SUITE(LSHTest);
BOOST_AUTO_TEST_CASE(LSHSearchTest)
{
+
+ math::RandomSeed(time(0));
//kNN and LSH parameters
const int k = 4;
+ //const int numTables = 30;
const int numProj = 10;
const double hashWidth = 0;
const int secondHashSize = 99901;
const int bucketSize = 500;
+ //test parameters
+ const double epsilon = 0.1;
+
//read iris training and testing data as reference and query
- string iris_train="iris_train.csv";
- string iris_test="iris_test.csv";
+ const string iris_train="iris_train.csv";
+ const string iris_test="iris_test.csv";
arma::mat rdata;
arma::mat qdata;
data::Load(iris_train, rdata, true);
data::Load(iris_test, qdata, true);
- const int n_queries = qdata.n_cols;
-
-
- //Run classic knn on reference data
- AllkNN knn(rdata);
- arma::Mat<size_t> groundTruth;
- arma::mat groundDistances;
- knn.Search(qdata, k, groundTruth, groundDistances);
//Test: Run LSH with varying number of tables, keeping all other parameters
//constant. Compute the recall, i.e. the number of reported neighbors that
//are real neighbors of the query.
//LSH's property is that (with high probability), increasing the number of
- //tables will increase recall.
+ //tables will increase recall. Epsilon ensures that if noise lightly affects
+ //the projections, the test will not fail.
+ //Run classic knn on reference data
+ AllkNN knn(rdata);
+ arma::Mat<size_t> groundTruth;
+ arma::mat groundDistances;
+ knn.Search(qdata, k, groundTruth, groundDistances);
- //Run LSH with varying number of tables and compute recall
- const int L[] = {1, 4, 16, 32, 64, 128}; //number of tables
+ //Run LSH for different number of tables
+ const int L[] = {1, 8, 16, 32, 64, 128}; //number of tables
const int Lsize = 6;
- int recall_L[Lsize] = {0}; //recall of each LSH run
+ double recall_L[Lsize] = {0.0}; //recall of each LSH run
for (int l=0; l < Lsize; ++l){
//run LSH with only numTables varying (other values default)
@@ -64,17 +82,32 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
lsh_test.Search(qdata, k, LSHneighbors, LSHdistances);
//compute recall for each query
- for (int q=0; q < n_queries; ++q){
- for (int neigh = 0; neigh < k; ++neigh){
- if (LSHneighbors(neigh, q) == groundTruth(neigh, q))
- recall_L[l]++;
- }
- }
+ recall_L[l] = compute_recall(LSHneighbors, groundTruth);
- if (l > 0)
- BOOST_CHECK(recall_L[l] >= recall_L[l-1]);
+ if (l > 0){
+ BOOST_CHECK(recall_L[l] >= recall_L[l-1]-epsilon);
+ }
}
+ //Test: Run a very expensive LSH search, with a large number of hash tables
+ //and a large hash width. This run should return an acceptable recall. We set
+ //the bar very low (recall >= 50%) to make sure that a test fail means bad
+ //implementation.
+
+ const int H = 10000; //first-level hash width
+ const int K = 128; //projections per table
+ const int T = 128; //number of tables
+ const double recall_thresh_t2 = 0.5;
+
+ LSHSearch<> lsh_test(rdata, K, T, H, secondHashSize, bucketSize);
+ arma::Mat<size_t> LSHneighbors;
+ arma::mat LSHdistances;
+ lsh_test.Search(qdata, k, LSHneighbors, LSHdistances);
+
+ const double recall_t2 = compute_recall(LSHneighbors, groundTruth);
+
+ BOOST_CHECK(recall_t2 >= recall_thresh_t2);
+
}
BOOST_AUTO_TEST_CASE(LSHTrainTest)
More information about the mlpack-git mailing list