[mlpack-git] master: Mark ComputeRecall() as static. (bee9567)

gitdub at mlpack.org gitdub at mlpack.org
Tue Jun 21 20:27:32 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/4129a7c1d7432498b7d2e991fdeb63b3e3c46fe4...bee9567e5c4c8ea688243744eb6aaaef1e850654

>---------------------------------------------------------------

commit bee9567e5c4c8ea688243744eb6aaaef1e850654
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Jun 21 17:27:32 2016 -0700

    Mark ComputeRecall() as static.


>---------------------------------------------------------------

bee9567e5c4c8ea688243744eb6aaaef1e850654
 src/mlpack/methods/lsh/lsh_search.hpp      | 12 ++++++----
 src/mlpack/methods/lsh/lsh_search_impl.hpp |  9 ++++---
 src/mlpack/tests/lsh_test.cpp              | 38 +++++++++---------------------
 3 files changed, 23 insertions(+), 36 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 022f6e3..a4e6ca9 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -191,11 +191,15 @@ class LSHSearch
 
   /**
    * Compute the recall (% of neighbors found) given the neighbors returned by
-   * LSHSearch::Search and a "ground truth" file. Recall in [0, 1]
+   * LSHSearch::Search and a "ground truth" set of neighbors.  The recall
+   * returned will be in the range [0, 1].
+   *
+   * @param foundNeighbors Set of neighbors to compute recall of.
+   * @param realNeighbors Set of "ground truth" neighbors to compute recall
+   *     against.
    */
-  double ComputeRecall(const arma::Mat<size_t>& foundNeighbors,
-                       const arma::Mat<size_t>& realNeighbors);
-
+  static double ComputeRecall(const arma::Mat<size_t>& foundNeighbors,
+                              const arma::Mat<size_t>& realNeighbors);
 
   /**
    * Serialize the LSH model.
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index b3ad46d..149beab 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -563,10 +563,10 @@ double LSHSearch<SortPolicy>::ComputeRecall(
         " must have equal size");
 
   const size_t queries = foundNeighbors.n_cols;
-  const size_t neighbors=  foundNeighbors.n_rows; //k
+  const size_t neighbors = foundNeighbors.n_rows; // Should be equal to k.
 
-  // recall is set intersection of found and real neighbors
-  double found = 0;
+  // The recall is the set intersection of found and real neighbors.
+  size_t found = 0;
   for (size_t col = 0; col < queries; ++col)
     for (size_t row = 0; row < neighbors; ++row)
       for (size_t nei = 0; nei < realNeighbors.n_rows; ++nei)
@@ -576,8 +576,7 @@ double LSHSearch<SortPolicy>::ComputeRecall(
           break;
         }
 
-  return found/realNeighbors.n_elem;
-
+  return ((double) found) / realNeighbors.n_elem;
 }
 
 template<typename SortPolicy>
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index c8c4ba5..65d1d78 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -16,20 +16,6 @@ using namespace mlpack;
 using namespace mlpack::neighbor;
 
 /**
- * Computes Recall (percent of neighbors found correctly).
- */
-double ComputeRecall(
-    const arma::Mat<size_t>& lshNeighbors,
-    const arma::Mat<size_t>& groundTruth)
-{
-  const size_t queries = lshNeighbors.n_cols;
-  const size_t neigh = lshNeighbors.n_rows;
-
-  const double same = arma::accu(lshNeighbors == groundTruth);
-  return same / (static_cast<double>(queries * neigh));
-}
-
-/**
  * Generates a point set of four clusters around (0.5, 0.5),
  * (3.5, 0.5), (0.5, 3.5), (3.5, 3.5).
  */
@@ -149,7 +135,7 @@ BOOST_AUTO_TEST_CASE(NumTablesTest)
       lshTest.Search(qdata, k, lshNeighbors, lshDistances);
 
       // Compute recall for each query.
-      lValueRecall[l] = ComputeRecall(lshNeighbors, groundTruth);
+      lValueRecall[l] = LSHSearch<>::ComputeRecall(lshNeighbors, groundTruth);
 
       if (l > 0)
       {
@@ -221,7 +207,7 @@ BOOST_AUTO_TEST_CASE(HashWidthTest)
     lshTest.Search(qdata, k, lshNeighbors, lshDistances);
 
     // Compute recall for each query.
-    hValueRecall[h] = ComputeRecall(lshNeighbors, groundTruth);
+    hValueRecall[h] = LSHSearch<>::ComputeRecall(lshNeighbors, groundTruth);
 
     if (h > 0)
       BOOST_REQUIRE_GE(hValueRecall[h], hValueRecall[h - 1] - epsilon);
@@ -283,7 +269,7 @@ BOOST_AUTO_TEST_CASE(NumProjTest)
     lshTest.Search(qdata, k, lshNeighbors, lshDistances);
 
     // Compute recall for each query.
-    pValueRecall[p] = ComputeRecall(lshNeighbors, groundTruth);
+    pValueRecall[p] = LSHSearch<>::ComputeRecall(lshNeighbors, groundTruth);
 
     // Don't check the first run; only check that increasing P decreases recall.
     if (p > 0)
@@ -339,7 +325,7 @@ BOOST_AUTO_TEST_CASE(RecallTest)
   arma::mat lshDistancesExp;
   lshTestExp.Search(qdata, k, lshNeighborsExp, lshDistancesExp);
 
-  const double recallExp = ComputeRecall(lshNeighborsExp, groundTruth);
+  const double recallExp = LSHSearch<>::ComputeRecall(lshNeighborsExp, groundTruth);
 
   // This run should have recall higher than the threshold.
   BOOST_REQUIRE_GE(recallExp, recallThreshExp);
@@ -361,7 +347,8 @@ BOOST_AUTO_TEST_CASE(RecallTest)
   arma::mat lshDistancesChp;
   lshTestChp.Search(qdata, k, lshNeighborsChp, lshDistancesChp);
 
-  const double recallChp = ComputeRecall(lshNeighborsChp, groundTruth);
+  const double recallChp = LSHSearch<>::ComputeRecall(lshNeighborsChp,
+      groundTruth);
 
   // This run should have recall lower than the threshold.
   BOOST_REQUIRE_LE(recallChp, recallThreshChp);
@@ -522,8 +509,7 @@ BOOST_AUTO_TEST_CASE(RecallTestIdentical)
   q1.set_size(k, numQueries);
   q1.col(0) = arma::linspace< arma::Col<size_t> >(1, k, k);
   
-  LSHSearch<> lsh;
-  BOOST_REQUIRE_EQUAL(lsh.ComputeRecall(base, q1), 1);
+  BOOST_REQUIRE_EQUAL(LSHSearch<>::ComputeRecall(base, q1), 1);
 }
 
 /**
@@ -554,8 +540,7 @@ BOOST_AUTO_TEST_CASE(RecallTestPartiallyCorrect)
     6 << arma::endr << 
     7 << arma::endr;
 
-  LSHSearch<> lsh;
-  BOOST_REQUIRE_CLOSE(lsh.ComputeRecall(base, q2), 0.6, 0.0001);
+  BOOST_REQUIRE_CLOSE(LSHSearch<>::ComputeRecall(base, q2), 0.6, 0.0001);
 }
 
 /**
@@ -575,8 +560,7 @@ BOOST_AUTO_TEST_CASE(RecallTestIncorrect)
   q3.set_size(k, numQueries);
   q3.col(0) = arma::linspace< arma::Col<size_t> >(k + 1, 2 * k, k);
 
-  LSHSearch<> lsh;
-  BOOST_REQUIRE_EQUAL(lsh.ComputeRecall(base, q3), 0);
+  BOOST_REQUIRE_EQUAL(LSHSearch<>::ComputeRecall(base, q3), 0);
 }
 
 /**
@@ -596,8 +580,8 @@ BOOST_AUTO_TEST_CASE(RecallTestException)
   arma::Mat<size_t> q4;
   q4.set_size(2 * k, numQueries);
 
-  LSHSearch<> lsh;
-  BOOST_REQUIRE_THROW(lsh.ComputeRecall(base, q4), std::invalid_argument);
+  BOOST_REQUIRE_THROW(LSHSearch<>::ComputeRecall(base, q4),
+      std::invalid_argument);
 
 }
 




More information about the mlpack-git mailing list