[mlpack-git] master: Adds a first multiprobe test (8bc5ced)

gitdub at mlpack.org gitdub at mlpack.org
Thu Jun 30 15:11:52 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eaa7182ebed8cce3fd6191dc1f8170546ea297da...812048c7c6bee0b6c8d936677f23bbb5930c6cfc

>---------------------------------------------------------------

commit 8bc5ced067b7fb43b7a0bf71dcb6c1303f6c4780
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Thu Jun 9 10:09:33 2016 +0300

    Adds a first multiprobe test


>---------------------------------------------------------------

8bc5ced067b7fb43b7a0bf71dcb6c1303f6c4780
 src/mlpack/tests/lsh_test.cpp | 83 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 83 insertions(+)

diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index 85881b7..bbfe52a 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -493,6 +493,89 @@ BOOST_AUTO_TEST_CASE(DeterministicNoMerge)
   }
 
 }
+
+/**
+ * Test: Create an LSHSearch object and use an increasing number of probes to
+ * search for points. Require that recall for the same object doesn't decrease
+ * with increasing number of probes. Also require that, at least once across
+ * all repetitions, recall increases by more than epsilonIncrease.
+ */
+BOOST_AUTO_TEST_CASE(MultiprobeTest)
+{
+  const double epsilonIncrease = 0.05;
+  const size_t repetitions = 5; // train five objects
+
+  const size_t probeTrials = 5;
+  const size_t numProbes[probeTrials] = {0, 1, 2, 3, 4};
+
+  // Parameters for the exact kNN baseline (k) and for the LSHSearch models
+  // trained below.
+  const int k = 4;
+  const int numTables = 16;
+  const int numProj = 3;
+  const double hashWidth = 0;
+  const int secondHashSize = 99901;
+  const int bucketSize = 500;
+
+  const string trainSet = "iris_train.csv";
+  const string testSet = "iris_test.csv";
+  arma::mat rdata;
+  arma::mat qdata;
+  data::Load(trainSet, rdata, true);
+  data::Load(testSet, qdata, true);
+
+  // Run classic knn on reference set to obtain the ground truth neighbors.
+  KNN knn(rdata);
+  arma::Mat<size_t> groundTruth;
+  arma::mat groundDistances;
+  knn.Search(qdata, k, groundTruth, groundDistances);
+
+  bool foundIncrease = false;
+
+  for (size_t rep = 0; rep < repetitions; ++rep)
+  {
+    // train a model
+    LSHSearch<> multiprobeTest(
+        rdata,
+        numProj,
+        numTables,
+        hashWidth,
+        secondHashSize,
+        bucketSize);
+
+    double prevRecall = 0;
+    // search with varying number of probes
+    for (size_t p = 0; p < probeTrials; ++p)
+    {
+      arma::Mat<size_t> lshNeighbors;
+      arma::mat lshDistances; //move outside of loop for speed?
+
+      multiprobeTest.Search(
+          qdata,
+          k,
+          lshNeighbors,
+          lshDistances,
+          0,
+          numProbes[p]);
+      // Compute recall. NOTE(review): prevRecall is only stored when p > 0, so
+      // the p == 1 checks compare against 0 rather than the p == 0 recall.
+      double recall = ComputeRecall(lshNeighbors, groundTruth); //TODO: change to LSHSearch::ComputeRecall??
+      if (p > 0)
+      {
+        // more probes should at the very least not lower recall...
+        BOOST_REQUIRE(recall >= prevRecall);
+
+        // ... and should ideally increase it a bit
+        if (recall > prevRecall + epsilonIncrease)
+          foundIncrease = true;
+        prevRecall = recall;
+      }
+    }
+  }
+
+  BOOST_REQUIRE(foundIncrease);
+}
+
 BOOST_AUTO_TEST_CASE(LSHTrainTest)
 {
   // This is a not very good test that simply checks that the re-trained LSH




More information about the mlpack-git mailing list