[mlpack-git] master: Mark ComputeRecall() as static. (bee9567)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Jun 21 20:27:32 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/4129a7c1d7432498b7d2e991fdeb63b3e3c46fe4...bee9567e5c4c8ea688243744eb6aaaef1e850654
>---------------------------------------------------------------
commit bee9567e5c4c8ea688243744eb6aaaef1e850654
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Jun 21 17:27:32 2016 -0700
Mark ComputeRecall() as static.
>---------------------------------------------------------------
bee9567e5c4c8ea688243744eb6aaaef1e850654
src/mlpack/methods/lsh/lsh_search.hpp | 12 ++++++----
src/mlpack/methods/lsh/lsh_search_impl.hpp | 9 ++++---
src/mlpack/tests/lsh_test.cpp | 38 +++++++++---------------------
3 files changed, 23 insertions(+), 36 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 022f6e3..a4e6ca9 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -191,11 +191,15 @@ class LSHSearch
/**
* Compute the recall (% of neighbors found) given the neighbors returned by
- * LSHSearch::Search and a "ground truth" file. Recall in [0, 1]
+ * LSHSearch::Search and a "ground truth" set of neighbors. The recall
+ * returned will be in the range [0, 1].
+ *
+ * @param foundNeighbors Set of neighbors to compute recall of.
+ * @param realNeighbors Set of "ground truth" neighbors to compute recall
+ * against.
*/
- double ComputeRecall(const arma::Mat<size_t>& foundNeighbors,
- const arma::Mat<size_t>& realNeighbors);
-
+ static double ComputeRecall(const arma::Mat<size_t>& foundNeighbors,
+ const arma::Mat<size_t>& realNeighbors);
/**
* Serialize the LSH model.
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index b3ad46d..149beab 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -563,10 +563,10 @@ double LSHSearch<SortPolicy>::ComputeRecall(
" must have equal size");
const size_t queries = foundNeighbors.n_cols;
- const size_t neighbors= foundNeighbors.n_rows; //k
+ const size_t neighbors = foundNeighbors.n_rows; // Should be equal to k.
- // recall is set intersection of found and real neighbors
- double found = 0;
+ // The recall is the set intersection of found and real neighbors.
+ size_t found = 0;
for (size_t col = 0; col < queries; ++col)
for (size_t row = 0; row < neighbors; ++row)
for (size_t nei = 0; nei < realNeighbors.n_rows; ++nei)
@@ -576,8 +576,7 @@ double LSHSearch<SortPolicy>::ComputeRecall(
break;
}
- return found/realNeighbors.n_elem;
-
+ return ((double) found) / realNeighbors.n_elem;
}
template<typename SortPolicy>
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index c8c4ba5..65d1d78 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -16,20 +16,6 @@ using namespace mlpack;
using namespace mlpack::neighbor;
/**
- * Computes Recall (percent of neighbors found correctly).
- */
-double ComputeRecall(
- const arma::Mat<size_t>& lshNeighbors,
- const arma::Mat<size_t>& groundTruth)
-{
- const size_t queries = lshNeighbors.n_cols;
- const size_t neigh = lshNeighbors.n_rows;
-
- const double same = arma::accu(lshNeighbors == groundTruth);
- return same / (static_cast<double>(queries * neigh));
-}
-
-/**
* Generates a point set of four clusters around (0.5, 0.5),
* (3.5, 0.5), (0.5, 3.5), (3.5, 3.5).
*/
@@ -149,7 +135,7 @@ BOOST_AUTO_TEST_CASE(NumTablesTest)
lshTest.Search(qdata, k, lshNeighbors, lshDistances);
// Compute recall for each query.
- lValueRecall[l] = ComputeRecall(lshNeighbors, groundTruth);
+ lValueRecall[l] = LSHSearch<>::ComputeRecall(lshNeighbors, groundTruth);
if (l > 0)
{
@@ -221,7 +207,7 @@ BOOST_AUTO_TEST_CASE(HashWidthTest)
lshTest.Search(qdata, k, lshNeighbors, lshDistances);
// Compute recall for each query.
- hValueRecall[h] = ComputeRecall(lshNeighbors, groundTruth);
+ hValueRecall[h] = LSHSearch<>::ComputeRecall(lshNeighbors, groundTruth);
if (h > 0)
BOOST_REQUIRE_GE(hValueRecall[h], hValueRecall[h - 1] - epsilon);
@@ -283,7 +269,7 @@ BOOST_AUTO_TEST_CASE(NumProjTest)
lshTest.Search(qdata, k, lshNeighbors, lshDistances);
// Compute recall for each query.
- pValueRecall[p] = ComputeRecall(lshNeighbors, groundTruth);
+ pValueRecall[p] = LSHSearch<>::ComputeRecall(lshNeighbors, groundTruth);
// Don't check the first run; only check that increasing P decreases recall.
if (p > 0)
@@ -339,7 +325,7 @@ BOOST_AUTO_TEST_CASE(RecallTest)
arma::mat lshDistancesExp;
lshTestExp.Search(qdata, k, lshNeighborsExp, lshDistancesExp);
- const double recallExp = ComputeRecall(lshNeighborsExp, groundTruth);
+ const double recallExp = LSHSearch<>::ComputeRecall(lshNeighborsExp, groundTruth);
// This run should have recall higher than the threshold.
BOOST_REQUIRE_GE(recallExp, recallThreshExp);
@@ -361,7 +347,8 @@ BOOST_AUTO_TEST_CASE(RecallTest)
arma::mat lshDistancesChp;
lshTestChp.Search(qdata, k, lshNeighborsChp, lshDistancesChp);
- const double recallChp = ComputeRecall(lshNeighborsChp, groundTruth);
+ const double recallChp = LSHSearch<>::ComputeRecall(lshNeighborsChp,
+ groundTruth);
// This run should have recall lower than the threshold.
BOOST_REQUIRE_LE(recallChp, recallThreshChp);
@@ -522,8 +509,7 @@ BOOST_AUTO_TEST_CASE(RecallTestIdentical)
q1.set_size(k, numQueries);
q1.col(0) = arma::linspace< arma::Col<size_t> >(1, k, k);
- LSHSearch<> lsh;
- BOOST_REQUIRE_EQUAL(lsh.ComputeRecall(base, q1), 1);
+ BOOST_REQUIRE_EQUAL(LSHSearch<>::ComputeRecall(base, q1), 1);
}
/**
@@ -554,8 +540,7 @@ BOOST_AUTO_TEST_CASE(RecallTestPartiallyCorrect)
6 << arma::endr <<
7 << arma::endr;
- LSHSearch<> lsh;
- BOOST_REQUIRE_CLOSE(lsh.ComputeRecall(base, q2), 0.6, 0.0001);
+ BOOST_REQUIRE_CLOSE(LSHSearch<>::ComputeRecall(base, q2), 0.6, 0.0001);
}
/**
@@ -575,8 +560,7 @@ BOOST_AUTO_TEST_CASE(RecallTestIncorrect)
q3.set_size(k, numQueries);
q3.col(0) = arma::linspace< arma::Col<size_t> >(k + 1, 2 * k, k);
- LSHSearch<> lsh;
- BOOST_REQUIRE_EQUAL(lsh.ComputeRecall(base, q3), 0);
+ BOOST_REQUIRE_EQUAL(LSHSearch<>::ComputeRecall(base, q3), 0);
}
/**
@@ -596,8 +580,8 @@ BOOST_AUTO_TEST_CASE(RecallTestException)
arma::Mat<size_t> q4;
q4.set_size(2 * k, numQueries);
- LSHSearch<> lsh;
- BOOST_REQUIRE_THROW(lsh.ComputeRecall(base, q4), std::invalid_argument);
+ BOOST_REQUIRE_THROW(LSHSearch<>::ComputeRecall(base, q4),
+ std::invalid_argument);
}
More information about the mlpack-git mailing list