[mlpack-git] master: Adds ComputeRecall function to LSHSearch (4d98deb)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Jun 8 13:08:10 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/2bd1227d0f41dd61e444f3c84c56eefd946014e2...4129a7c1d7432498b7d2e991fdeb63b3e3c46fe4
>---------------------------------------------------------------
commit 4d98deb4de11bedde7e2d55c41ab46a29fe7152c
Author: Yannis Mentekidis <mentekid at gmail.com>
Date: Wed Jun 8 20:08:10 2016 +0300
Adds ComputeRecall function to LSHSearch
>---------------------------------------------------------------
4d98deb4de11bedde7e2d55c41ab46a29fe7152c
src/mlpack/methods/lsh/lsh_main.cpp | 20 ++++++++++++++++++++
src/mlpack/methods/lsh/lsh_search.hpp | 8 ++++++++
src/mlpack/methods/lsh/lsh_search_impl.hpp | 23 +++++++++++++++++++++++
3 files changed, 51 insertions(+)
diff --git a/src/mlpack/methods/lsh/lsh_main.cpp b/src/mlpack/methods/lsh/lsh_main.cpp
index 2894411..4f4040a 100644
--- a/src/mlpack/methods/lsh/lsh_main.cpp
+++ b/src/mlpack/methods/lsh/lsh_main.cpp
@@ -48,6 +48,7 @@ PARAM_STRING("reference_file", "File containing the reference dataset.", "r",
"");
PARAM_STRING("distances_file", "File to output distances into.", "d", "");
PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
+PARAM_STRING("true_neighbors_file", "File of real neighbors to compute recall "
+    "(printed with -v).", "t", "");
// We can load or save models.
PARAM_STRING("input_model_file", "File to load LSH model from. (Cannot be "
@@ -188,6 +189,25 @@ int main(int argc, char *argv[])
Log::Info << "Neighbors computed." << endl;
+ // Compute recall, if desired.
+ if (CLI::HasParam("t"))
+ {
+ // read specified filename
+ const string trueNeighborsFile =
+ CLI::GetParam<string>("true_neighbors_file");
+
+ // load the data
+ arma::Mat<size_t> trueNeighbors;
+ data::Load(trueNeighborsFile, trueNeighbors, true);
+ Log::Info << "Loaded true neighbor indices from '"
+ << trueNeighborsFile << "'." << endl;
+
+ // Compute Recall and log
+ double recallPercentage = 100 * allkann.ComputeRecall(neighbors, trueNeighbors);
+
+ Log::Info << "Recall: " << recallPercentage << endl;
+ }
+
// Save output, if desired.
if (CLI::HasParam("distances_file"))
data::Save(distancesFile, distances);
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index b42bb7a..fbdf6f3 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -169,6 +169,14 @@ class LSHSearch
const size_t numTablesToSearch = 0);
/**
+   * Compute the recall (the fraction of true neighbors that were found) given
+   * the neighbors returned by LSHSearch::Search() and a matrix of ground-truth
+   * neighbor indices. The result lies in [0, 1].
+   *
+   * The computation uses no model state, so this is a static function; it can
+   * still be called through an existing LSHSearch object.
+   *
+   * @param foundNeighbors Neighbor indices returned by Search(); each column
+   *     holds the neighbors of one query point.
+   * @param realNeighbors Ground-truth neighbor indices. Must have the same
+   *     dimensions as foundNeighbors.
+   * @return Fraction of the entries of realNeighbors that also appear in the
+   *     corresponding column of foundNeighbors.
+   */
+  static double ComputeRecall(const arma::Mat<size_t>& foundNeighbors,
+                              const arma::Mat<size_t>& realNeighbors);
+
+
+ /**
* Serialize the LSH model.
*
* @param ar Archive to serialize to.
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index ad698e1..92c4773 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -572,6 +572,29 @@ Search(const size_t k,
}
template<typename SortPolicy>
+double LSHSearch<SortPolicy>::ComputeRecall(
+    const arma::Mat<size_t>& foundNeighbors,
+    const arma::Mat<size_t>& realNeighbors)
+{
+  // Recall = |found intersect real| / |real|, computed per query (column): for
+  // every true neighbor of a query, check whether it appears anywhere among
+  // that query's found neighbors. With empty matrices the result is NaN (0/0).
+
+  // Both matrices must describe the same queries with the same k; otherwise
+  // the per-column comparison below would index out of bounds.
+  if (foundNeighbors.n_rows != realNeighbors.n_rows ||
+      foundNeighbors.n_cols != realNeighbors.n_cols)
+  {
+    throw std::invalid_argument("LSHSearch::ComputeRecall(): matrices provided "
+        "must have equal size");
+  }
+
+  const size_t queries = realNeighbors.n_cols;
+
+  size_t found = 0;
+  for (size_t col = 0; col < queries; ++col)
+  {
+    for (size_t row = 0; row < realNeighbors.n_rows; ++row)
+    {
+      // Each matrix is indexed only within its own bounds.
+      for (size_t nei = 0; nei < foundNeighbors.n_rows; ++nei)
+      {
+        if (realNeighbors(row, col) == foundNeighbors(nei, col))
+        {
+          ++found;
+          break; // Count each true neighbor at most once.
+        }
+      }
+    }
+  }
+
+  return static_cast<double>(found) / realNeighbors.n_elem;
+}
+
+template<typename SortPolicy>
template<typename Archive>
void LSHSearch<SortPolicy>::Serialize(Archive& ar,
const unsigned int version)
More information about the mlpack-git
mailing list