[mlpack-git] master: Adds ComputeRecall function to LSHSearch (4d98deb)

gitdub at mlpack.org gitdub at mlpack.org
Wed Jun 8 13:08:10 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/2bd1227d0f41dd61e444f3c84c56eefd946014e2...4129a7c1d7432498b7d2e991fdeb63b3e3c46fe4

>---------------------------------------------------------------

commit 4d98deb4de11bedde7e2d55c41ab46a29fe7152c
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Wed Jun 8 20:08:10 2016 +0300

    Adds ComputeRecall function to LSHSearch


>---------------------------------------------------------------

4d98deb4de11bedde7e2d55c41ab46a29fe7152c
 src/mlpack/methods/lsh/lsh_main.cpp        | 20 ++++++++++++++++++++
 src/mlpack/methods/lsh/lsh_search.hpp      |  8 ++++++++
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 23 +++++++++++++++++++++++
 3 files changed, 51 insertions(+)

diff --git a/src/mlpack/methods/lsh/lsh_main.cpp b/src/mlpack/methods/lsh/lsh_main.cpp
index 2894411..4f4040a 100644
--- a/src/mlpack/methods/lsh/lsh_main.cpp
+++ b/src/mlpack/methods/lsh/lsh_main.cpp
@@ -48,6 +48,7 @@ PARAM_STRING("reference_file", "File containing the reference dataset.", "r",
     "");
 PARAM_STRING("distances_file", "File to output distances into.", "d", "");
 PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
+PARAM_STRING("true_neighbors_file", "File of real neighbors to compute recall (printed with -v).", "t", "");
 
 // We can load or save models.
 PARAM_STRING("input_model_file", "File to load LSH model from.  (Cannot be "
@@ -188,6 +189,25 @@ int main(int argc, char *argv[])
 
   Log::Info << "Neighbors computed." << endl;
 
+  // Compute recall, if desired.
+  if (CLI::HasParam("true_neighbors_file"))
+  {
+    // Load the ground-truth neighbor indices.
+    const string trueNeighborsFile =
+        CLI::GetParam<string>("true_neighbors_file");
+    arma::Mat<size_t> trueNeighbors;
+    data::Load(trueNeighborsFile, trueNeighbors, true);
+    Log::Info << "Loaded true neighbor indices from '" << trueNeighborsFile
+        << "'." << endl;
+
+    // Compute the recall (fraction of true neighbors that were found) and
+    // report it as a percentage.
+    const double recallPercentage = 100.0 *
+        allkann.ComputeRecall(neighbors, trueNeighbors);
+
+    Log::Info << "Recall: " << recallPercentage << "%" << endl;
+  }
+
   // Save output, if desired.
   if (CLI::HasParam("distances_file"))
     data::Save(distancesFile, distances);
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index b42bb7a..fbdf6f3 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -169,6 +169,14 @@ class LSHSearch
               const size_t numTablesToSearch = 0);
 
   /**
+   * Compute the recall (fraction of true neighbors found) of the neighbors
+   * returned by LSHSearch::Search against the true neighbors of the same
+   * queries; both are expected to be k x (num queries). Result is in [0, 1].
+   */
+  double ComputeRecall(const arma::Mat<size_t>& foundNeighbors,
+                       const arma::Mat<size_t>& realNeighbors);
+
+  /**
    * Serialize the LSH model.
    *
    * @param ar Archive to serialize to.
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index ad698e1..92c4773 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -572,6 +572,29 @@ Search(const size_t k,
 }
 
 template<typename SortPolicy>
+double LSHSearch<SortPolicy>::ComputeRecall(
+    const arma::Mat<size_t>& foundNeighbors,
+    const arma::Mat<size_t>& realNeighbors)
+{
+  // Both matrices must be k x (number of queries); the indexing below relies
+  // on it, so reject mismatched shapes instead of reading out of bounds.
+  if (foundNeighbors.n_rows != realNeighbors.n_rows ||
+      foundNeighbors.n_cols != realNeighbors.n_cols)
+    throw std::invalid_argument("LSHSearch::ComputeRecall(): matrices "
+        "provided must have equal size");
+
+  // Recall = |found AND real| / |real|. For each query (column), count the
+  // true neighbors that also appear among the found neighbors.
+  size_t numFound = 0;
+  for (size_t col = 0; col < realNeighbors.n_cols; ++col)
+    for (size_t row = 0; row < realNeighbors.n_rows; ++row)
+      if (arma::any(foundNeighbors.col(col) == realNeighbors(row, col)))
+        ++numFound;
+
+  return static_cast<double>(numFound) / realNeighbors.n_elem;
+}
+
+template<typename SortPolicy>
 template<typename Archive>
 void LSHSearch<SortPolicy>::Serialize(Archive& ar,
                                       const unsigned int version)




More information about the mlpack-git mailing list