[mlpack-git] master: Add testing for LSHSearch and refactors code (799cafe)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Mar 31 08:22:23 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/5bc514c122d53590397fdfad42c7845d9ad91fa1...f0675d7789b69746f7c337c3ec4a778cef932924
>---------------------------------------------------------------
commit 799cafee84731ce257626b3c769412927eafc737
Author: Yannis Mentekidis <mentekid at gmail.com>
Date: Thu Mar 31 15:22:23 2016 +0300
Add testing for LSHSearch and refactors code
>---------------------------------------------------------------
799cafee84731ce257626b3c769412927eafc737
src/mlpack/tests/lsh_test.cpp | 76 +++++++++++++++++++++++--------------------
1 file changed, 40 insertions(+), 36 deletions(-)
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index cd22633..9561dda 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -15,17 +15,21 @@ using namespace std;
using namespace mlpack;
using namespace mlpack::neighbor;
-double compute_recall(arma::Mat<size_t> LSHneighbors, arma::Mat<size_t> groundTruth){
+double compute_recall(arma::Mat<size_t> LSHneighbors,
+ arma::Mat<size_t> groundTruth)
+{
const int n_queries = LSHneighbors.n_cols;
const int n_neigh = LSHneighbors.n_rows;
- int recall = 0;
- for (int q = 0; q < n_queries; ++q){
- for (int n = 0; n < n_neigh; ++n){
- recall+=(LSHneighbors(n,q)==groundTruth(n,q));
- }
+ int found_same = 0;
+ for (int q = 0; q < n_queries; ++q)
+ {
+ for (int n = 0; n < n_neigh; ++n)
+ {
+ found_same+=(LSHneighbors(n,q)==groundTruth(n,q));
+ }
}
- return static_cast<double>(recall)/
+ return static_cast<double>(found_same)/
(static_cast<double>(n_queries*n_neigh));
}
@@ -35,7 +39,7 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
{
math::RandomSeed(time(0));
- //kNN and LSH parameters
+ //kNN and LSH parameters (use LSH default parameters)
const int k = 4;
//const int numTables = 30;
const int numProj = 10;
@@ -47,13 +51,18 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
const double epsilon = 0.1;
//read iris training and testing data as reference and query
- const string iris_train="iris_train.csv";
- const string iris_test="iris_test.csv";
+ const string data_train="iris_train.csv";
+ const string data_test="iris_test.csv";
arma::mat rdata;
arma::mat qdata;
- data::Load(iris_train, rdata, true);
- data::Load(iris_test, qdata, true);
+ data::Load(data_train, rdata, true);
+ data::Load(data_test, qdata, true);
+ //Run classic knn on reference data
+ AllkNN knn(rdata);
+ arma::Mat<size_t> groundTruth;
+ arma::mat groundDistances;
+ knn.Search(qdata, k, groundTruth, groundDistances);
//Test: Run LSH with varying number of tables, keeping all other parameters
//constant. Compute the recall, i.e. the number of reported neighbors that
@@ -62,31 +71,26 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
//tables will increase recall. Epsilon ensures that if noise lightly affects
//the projections, the test will not fail.
- //Run classic knn on reference data
- AllkNN knn(rdata);
- arma::Mat<size_t> groundTruth;
- arma::mat groundDistances;
- knn.Search(qdata, k, groundTruth, groundDistances);
-
//Run LSH for different number of tables
- const int L[] = {1, 8, 16, 32, 64, 128}; //number of tables
+ const int L_table[] = {1, 8, 16, 32, 64, 128}; //number of tables
const int Lsize = 6;
double recall_L[Lsize] = {0.0}; //recall of each LSH run
- for (int l=0; l < Lsize; ++l){
- //run LSH with only numTables varying (other values default)
- LSHSearch<> lsh_test(rdata, numProj, L[l],
- hashWidth, secondHashSize, bucketSize);
- arma::Mat<size_t> LSHneighbors;
- arma::mat LSHdistances;
- lsh_test.Search(qdata, k, LSHneighbors, LSHdistances);
+ for (int l=0; l < Lsize; ++l)
+ {
+ //run LSH with only numTables varying (other values default)
+ LSHSearch<> lsh_test1(rdata, numProj, L_table[l],
+ hashWidth, secondHashSize, bucketSize);
+ arma::Mat<size_t> LSHneighbors;
+ arma::mat LSHdistances;
+ lsh_test1.Search(qdata, k, LSHneighbors, LSHdistances);
+
+ //compute recall for each query
+ recall_L[l] = compute_recall(LSHneighbors, groundTruth);
- //compute recall for each query
- recall_L[l] = compute_recall(LSHneighbors, groundTruth);
+ if (l > 0)
+ BOOST_CHECK(recall_L[l] >= recall_L[l-1]-epsilon);
- if (l > 0){
- BOOST_CHECK(recall_L[l] >= recall_L[l-1]-epsilon);
- }
}
//Test: Run a very expensive LSH search, with a large number of hash tables
@@ -94,15 +98,15 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
//the bar very low (recall >= 50%) to make sure that a test fail means bad
//implementation.
- const int H = 10000; //first-level hash width
- const int K = 128; //projections per table
- const int T = 128; //number of tables
+ const int Ht2 = 10000; //first-level hash width
+ const int Kt2 = 128; //projections per table
+ const int Tt2 = 128; //number of tables
const double recall_thresh_t2 = 0.5;
- LSHSearch<> lsh_test(rdata, K, T, H, secondHashSize, bucketSize);
+ LSHSearch<> lsh_test2(rdata, Kt2, Tt2, Ht2, secondHashSize, bucketSize);
arma::Mat<size_t> LSHneighbors;
arma::mat LSHdistances;
- lsh_test.Search(qdata, k, LSHneighbors, LSHdistances);
+ lsh_test2.Search(qdata, k, LSHneighbors, LSHdistances);
const double recall_t2 = compute_recall(LSHneighbors, groundTruth);
More information about the mlpack-git mailing list