[mlpack-git] master: Adds test for ComputeRecall and makes ComputeRecall throw exceptions for invalid arguments (ba02a65)

gitdub at mlpack.org gitdub at mlpack.org
Thu Jun 9 03:58:20 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/2bd1227d0f41dd61e444f3c84c56eefd946014e2...4129a7c1d7432498b7d2e991fdeb63b3e3c46fe4

>---------------------------------------------------------------

commit ba02a65f79d2821fa5e1acc0f0a14942d9096ced
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Thu Jun 9 10:58:20 2016 +0300

    Adds test for ComputeRecall and makes ComputeRecall throw exceptions for invalid arguments


>---------------------------------------------------------------

ba02a65f79d2821fa5e1acc0f0a14942d9096ced
 src/mlpack/methods/lsh/lsh_search_impl.hpp |  5 +++
 src/mlpack/tests/lsh_test.cpp              | 49 ++++++++++++++++++++++++++++++
 2 files changed, 54 insertions(+)

diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index 92c4773..14617c0 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -576,6 +576,11 @@ double LSHSearch<SortPolicy>::ComputeRecall(
     const arma::Mat<size_t>& foundNeighbors,
     const arma::Mat<size_t>& realNeighbors)
 {
+  if (foundNeighbors.n_rows != realNeighbors.n_rows ||
+      foundNeighbors.n_cols != realNeighbors.n_cols)
+    throw std::invalid_argument("LSHSearch::ComputeRecall(): matrices provided"
+        " must have equal size");
+
   const size_t queries = foundNeighbors.n_cols;
   const size_t neighbors=  foundNeighbors.n_rows; //k
 
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index d425666..7c597c5 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -326,6 +326,55 @@ BOOST_AUTO_TEST_CASE(LSHTrainTest)
   BOOST_REQUIRE_EQUAL(distances.n_rows, 3);
 }
 
+/**
+ * Test: this verifies ComputeRecall works correctly. It inputs a few "base" and
+ * "found" vectors and checks that recall is what is expected.
+ */
+BOOST_AUTO_TEST_CASE(ComputeRecallTest)
+{
+  const size_t k = 5; // 5 nearest neighbors
+  const size_t numQueries = 1;
+
+  // base = [1; 2; 3; 4; 5]
+  arma::Mat<size_t> base;
+  base.set_size(k, numQueries);
+  base.col(0) = arma::linspace< arma::Col<size_t> >(1, k, k);
+
+  // q1 = [1; 2; 3; 4; 5]. Expect recall = 1
+  arma::Mat<size_t> q1;
+  q1.set_size(k, numQueries);
+  q1.col(0) = arma::linspace< arma::Col<size_t> >(1, k, k);
+
+  LSHSearch<> lsh;
+  BOOST_REQUIRE_EQUAL(lsh.ComputeRecall(base, q1), 1);
+
+  // q2 = [2; 3; 4; 6; 7]. Expect recall = 0.6. This is important because this
+  // is a good example of how recall and accuracy differ. Accuracy here would
+  // be 0 but recall should not be.
+  arma::Mat<size_t> q2;
+  q2.set_size(k, numQueries);
+  q2 <<
+    2 << arma::endr <<
+    3 << arma::endr <<
+    4 << arma::endr <<
+    6 << arma::endr <<
+    7 << arma::endr;
+
+  BOOST_REQUIRE_CLOSE(lsh.ComputeRecall(base, q2), 0.6, 0.0001);
+
+  // q3 = [6; 7; 8; 9; 10]. Expected recall = 0
+  arma::Mat<size_t> q3;
+  q3.set_size(k, numQueries);
+  q3.col(0) = arma::linspace< arma::Col<size_t> >(k + 1, 2 * k, k);
+  BOOST_REQUIRE_EQUAL(lsh.ComputeRecall(base, q3), 0);
+
+  // Verify that nonsense arguments throw an exception: q4 has 2k rows while
+  // base has k, so the size check must reject the pair.
+  arma::Mat<size_t> q4;
+  q4.set_size(2 * k, numQueries);
+  BOOST_REQUIRE_THROW(lsh.ComputeRecall(base, q4), std::invalid_argument);
+}
+
 BOOST_AUTO_TEST_CASE(EmptyConstructorTest)
 {
   // If we create an empty LSH model and then call Search(), it should throw an




More information about the mlpack-git mailing list