[mlpack-git] master: lsh_test tests if recall increases with L (3253027)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Apr 6 17:57:53 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/5bc514c122d53590397fdfad42c7845d9ad91fa1...f0675d7789b69746f7c337c3ec4a778cef932924
>---------------------------------------------------------------
commit 3253027d7245585e95109a5fbed9c7efdd4160b4
Author: Yannis Mentekidis <mentekid at gmail.com>
Date: Tue Mar 22 18:26:56 2016 +0200
lsh_test tests if recall increases with L
>---------------------------------------------------------------
3253027d7245585e95109a5fbed9c7efdd4160b4
src/mlpack/tests/lsh_test.cpp | 143 +++++++++++++++++-------------------------
1 file changed, 56 insertions(+), 87 deletions(-)
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index 70c132d..e55b974 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -9,6 +9,7 @@
#include "old_boost_test_definitions.hpp"
#include <mlpack/methods/lsh/lsh_search.hpp>
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
using namespace std;
using namespace mlpack;
@@ -18,94 +19,62 @@ BOOST_AUTO_TEST_SUITE(LSHTest);
BOOST_AUTO_TEST_CASE(LSHSearchTest)
{
- // Force to specific random seed for these results.
- math::RandomSeed(0);
-
- // Precomputed hash width value.
- const double hashWidth = 4.24777;
-
- arma::mat rdata(2, 10);
- rdata << 3 << 2 << 4 << 3 << 5 << 6 << 0 << 8 << 3 << 1 << arma::endr <<
- 0 << 3 << 4 << 7 << 8 << 4 << 1 << 0 << 4 << 3 << arma::endr;
-
- arma::mat qdata(2, 3);
- qdata << 3 << 2 << 0 << arma::endr << 5 << 3 << 4 << arma::endr;
-
- // INPUT TO LSH:
- // Number of points: 10
- // Number of dimensions: 2
- // Number of projections per table: 'numProj' = 3
- // Number of hash tables: 'numTables' = 2
- // hashWidth (computed): 'hashWidth' = 4.24777
- // Second hash size: 'secondHashSize' = 11
- // Size of the bucket: 'bucketSize' = 3
-
- // Things obtained by random sampling listed in the sequences
- // as they will be obtained in the 'LSHSearch::BuildHash()' private function
- // in 'LSHSearch' class.
- //
- // 1. The weights of the second hash obtained as:
- // secondHashWeights = arma::floor(arma::randu(3) * 11.0);
- // COR.SOL.: secondHashWeights = [9, 4, 8];
- //
- // 2. The offsets for all the 3 projections in each of the 2 tables:
- // offsets.randu(3, 2)
- // COR.SOL.: [0.7984 0.3352; 0.9116 0.7682; 0.1976 0.2778]
- // offsets *= hashWidth
- // COR.SOL.: [3.3916 1.4240; 3.8725 3.2633; 0.8392 1.1799]
- //
- // 3. The (2 x 3) projection matrices for the 2 tables:
- // projMat.randn(2, 3)
- // COR.SOL.: Proj. Mat 1: [2.7020 0.0187 0.4355; 1.3692 0.6933 0.0416]
- // COR.SOL.: Proj. Mat 2: [-0.3961 -0.2666 1.1001; 0.3895 -1.5118 -1.3964]
- LSHSearch<> lsh_test(rdata, 3, 2, hashWidth, 11, 3);
-// LSHSearch<> lsh_test(rdata, qdata, 3, 2, 0.0, 11, 3);
-
- // Given this, the 'LSHSearch::bucketRowInHashTable' should be:
- // COR.SOL.: [2 11 4 7 6 3 11 0 5 1 8]
- //
- // The 'LSHSearch::bucketContentSize' should be:
- // COR.SOL.: [2 0 1 1 3 1 0 3 3 3 1]
- //
- // The final hash table 'LSHSearch::secondHashTable' should be
- // of size (3 x 9) with the following content:
- // COR.SOL.:
- // [0 2 4; 1 7 8; 3 9 10; 5 10 10; 6 10 10; 0 5 6; 1 2 8; 3 10 10; 4 10 10]
-
- arma::Mat<size_t> neighbors;
- arma::mat distances;
-
- lsh_test.Search(qdata, 2, neighbors, distances);
-
- // The private function 'LSHSearch::ReturnIndicesFromTable(0, refInds)'
- // should hash the query 0 into the following buckets:
- // COR.SOL.: Table 1 Bucket 7, Table 2 Bucket 0, refInds = [0 2 3 4 9]
- //
- // The private function 'LSHSearch::ReturnIndicesFromTable(1, refInds)'
- // should hash the query 1 into the following buckets:
- // COR.SOL.: Table 1 Bucket 9, Table 2 Bucket 4, refInds = [1 2 7 8]
- //
- // The private function 'LSHSearch::ReturnIndicesFromTable(2, refInds)'
- // should hash the query 2 into the following buckets:
- // COR.SOL.: Table 1 Bucket 0, Table 2 Bucket 7, refInds = [0 2 3 4 9]
-
- // After search
- // COR.SOL.: 'neighbors' = [2 1 9; 3 8 2]
- // COR.SOL.: 'distances' = [2 0 2; 4 2 16]
-
- arma::Mat<size_t> true_neighbors(2, 3);
- true_neighbors << 2 << 1 << 9 << arma::endr << 3 << 8 << 2 << arma::endr;
- arma::mat true_distances(2, 3);
- true_distances << 2 << 0 << 2 << arma::endr << 4 << 2 << 16 << arma::endr;
-
- for (size_t i = 0; i < 3; i++)
- {
- for (size_t j = 0; j < 2; j++)
- {
-// BOOST_REQUIRE_EQUAL(neighbors(j, i), true_neighbors(j, i));
-// BOOST_REQUIRE_CLOSE(distances(j, i), true_distances(j, i), 1e-5);
- }
+ // kNN and LSH parameters; everything except the number of tables stays fixed.
+ const int k = 4;
+ const int numProj = 10;
+ const double hashWidth = 0; // NOTE(review): presumably 0 lets LSHSearch compute the width itself -- confirm.
+ const int secondHashSize = 99901;
+ const int bucketSize = 500;
+
+ // Load the iris training and testing sets as the reference and query sets.
+ string iris_train="iris_train.csv";
+ string iris_test="iris_test.csv";
+ arma::mat rdata;
+ arma::mat qdata;
+ data::Load(iris_train, rdata, true);
+ data::Load(iris_test, qdata, true);
+ const int n_queries = qdata.n_cols; // one query per column of qdata
+
+
+ // Run kNN (AllkNN) on the reference set; its results are used as the ground truth.
+ AllkNN knn(rdata);
+ arma::Mat<size_t> groundTruth;
+ arma::mat groundDistances;
+ knn.Search(qdata, k, groundTruth, groundDistances);
+
+
+ // Test: run LSH with an increasing number of tables, keeping all other
+ // parameters constant, and score each run against the kNN ground truth.
+ // The score counts (query, rank) pairs where LSH's neighbor at that rank
+ // equals kNN's neighbor at the same rank. The expectation (with high
+ // probability) is that adding tables never lowers this score.
+
+
+ // Run LSH once per entry of L[] and record its score in recall_L.
+ const int L[] = {1, 4, 16, 32, 64, 128}; // numbers of tables to try, in increasing order
+ const int Lsize = 6; // must equal the number of entries in L[]
+ int recall_L[Lsize] = {0}; // matching-neighbor count of each LSH run
+
+ for (int l=0; l < Lsize; ++l){
+ // Build a fresh LSH index; only numTables (L[l]) changes between runs.
+ LSHSearch<> lsh_test(rdata, numProj, L[l],
+ hashWidth, secondHashSize, bucketSize);
+ arma::Mat<size_t> LSHneighbors;
+ arma::mat LSHdistances;
+ lsh_test.Search(qdata, k, LSHneighbors, LSHdistances);
+
+ // Count matching neighbors over all queries and all k ranks.
+ for (int q=0; q < n_queries; ++q){
+ for (int neigh = 0; neigh < k; ++neigh){
+ if (LSHneighbors(neigh, q) == groundTruth(neigh, q)) // NOTE(review): rank-by-rank match; a true neighbor reported at another rank is not counted
+ recall_L[l]++;
+ }
+ }
+
+ if (l > 0)
+ BOOST_CHECK(recall_L[l] >= recall_L[l-1]); // NOTE(review): no random seed is set (the old RandomSeed(0) call was removed), so this check may be nondeterministic
}
+
}
BOOST_AUTO_TEST_CASE(LSHTrainTest)
More information about the mlpack-git mailing list