[mlpack-git] master: Add Recall() to NeighborSearch class, and a proper command line option to mlpack_knn and mlpack_kfn. (3e79d27)

gitdub at mlpack.org gitdub at mlpack.org
Tue Aug 16 19:08:10 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit 3e79d273545cb5d5694f6c8e7370a0dbbd95e5e6
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Tue Aug 16 20:08:10 2016 -0300

    Add Recall() to NeighborSearch class, and a proper command line option to mlpack_knn and mlpack_kfn.


>---------------------------------------------------------------

3e79d273545cb5d5694f6c8e7370a0dbbd95e5e6
 src/mlpack/methods/neighbor_search/kfn_main.cpp    | 20 ++++++++++++++-
 src/mlpack/methods/neighbor_search/knn_main.cpp    | 18 ++++++++++++++
 .../methods/neighbor_search/neighbor_search.hpp    | 14 +++++++++++
 .../neighbor_search/neighbor_search_impl.hpp       | 29 ++++++++++++++++++++++
 4 files changed, 80 insertions(+), 1 deletion(-)

diff --git a/src/mlpack/methods/neighbor_search/kfn_main.cpp b/src/mlpack/methods/neighbor_search/kfn_main.cpp
index 9a4659e..f56b890 100644
--- a/src/mlpack/methods/neighbor_search/kfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/kfn_main.cpp
@@ -49,6 +49,8 @@ PARAM_STRING_OUT("neighbors_file", "File to output neighbors into.", "n");
 PARAM_STRING_IN("true_distances_file", "File of true distances to compute "
     "the effective error (average relative error) (it is printed when -v is "
     "specified).", "D", "");
+PARAM_STRING_IN("true_neighbors_file", "File of true neighbors to compute the "
+    "recall (it is printed when -v is specified).", "T", "");
 
 // The option exists to load or save models.
 PARAM_STRING_IN("input_model_file", "File containing pre-trained kFN model.",
@@ -288,11 +290,27 @@ int main(int argc, char *argv[])
       if (trueDistances.n_rows != distances.n_rows ||
           trueDistances.n_cols != distances.n_cols)
         Log::Fatal << "The true distances file must have the same number of "
-          << "values than the set of distances being queried!" << endl;
+            << "values than the set of distances being queried!" << endl;
 
       Log::Info << "Effective error: " << KFN::EffectiveError(distances,
           trueDistances) << endl;
     }
+
+    // Calculate the recall, if desired.
+    if (CLI::HasParam("true_neighbors_file"))
+    {
+      const string trueNeighborsFile = CLI::GetParam<string>(
+          "true_neighbors_file");
+      arma::Mat<size_t> trueNeighbors;
+      data::Load(trueNeighborsFile, trueNeighbors, true);
+
+      if (trueNeighbors.n_rows != neighbors.n_rows ||
+          trueNeighbors.n_cols != neighbors.n_cols)
+        Log::Fatal << "The true neighbors file must have the same number of "
+            << "values than the set of neighbors being queried!" << endl;
+
+      Log::Info << "Recall: " << KFN::Recall(neighbors, trueNeighbors) << endl;
+    }
   }
 
   if (CLI::HasParam("output_model_file"))
diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp
index 8dd8d53..974865f 100644
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@ -51,6 +51,8 @@ PARAM_STRING_OUT("neighbors_file", "File to output neighbors into.", "n");
 PARAM_STRING_IN("true_distances_file", "File of true distances to compute "
     "the effective error (average relative error) (it is printed when -v is "
     "specified).", "D", "");
+PARAM_STRING_IN("true_neighbors_file", "File of true neighbors to compute the "
+    "recall (it is printed when -v is specified).", "T", "");
 
 // The option exists to load or save models.
 PARAM_STRING_IN("input_model_file", "File containing pre-trained kNN model.",
@@ -309,6 +311,22 @@ int main(int argc, char *argv[])
       Log::Info << "Effective error: " << KNN::EffectiveError(distances,
           trueDistances) << endl;
     }
+
+    // Calculate the recall, if desired.
+    if (CLI::HasParam("true_neighbors_file"))
+    {
+      const string trueNeighborsFile = CLI::GetParam<string>(
+          "true_neighbors_file");
+      arma::Mat<size_t> trueNeighbors;
+      data::Load(trueNeighborsFile, trueNeighbors, true);
+
+      if (trueNeighbors.n_rows != neighbors.n_rows ||
+          trueNeighbors.n_cols != neighbors.n_cols)
+        Log::Fatal << "The true neighbors file must have the same number of "
+            << "values than the set of neighbors being queried!" << endl;
+
+      Log::Info << "Recall: " << KNN::Recall(neighbors, trueNeighbors) << endl;
+    }
   }
 
   if (CLI::HasParam("output_model_file"))
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 7db9ee6..be16b05 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -289,6 +289,20 @@ class NeighborSearch
   static double EffectiveError(arma::mat& foundDistances,
                                arma::mat& realDistances);
 
+  /**
+   * Calculate the recall: the fraction of found neighbors that also appear in
+   * the true neighbor list in the same column.  The recall returned will be in
+   * the range [0, 1].  Both matrices must have the same dimensions.
+   *
+   * @param foundNeighbors Matrix storing lists of calculated neighbors for each
+   *     query point.
+   * @param realNeighbors Matrix storing lists of true best neighbors for each
+   *     query point.
+   * @return Recall (number of matches divided by realNeighbors.n_elem).
+   */
+  static double Recall(arma::Mat<size_t>& foundNeighbors,
+                       arma::Mat<size_t>& realNeighbors);
+
   //! Return the total number of base case evaluations performed during the last
   //! search.
   size_t BaseCases() const { return baseCases; }
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 30e7716..238680e 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -752,6 +752,35 @@ TraversalType>::EffectiveError(arma::mat& foundDistances,
   return effectiveError;
 }
 
+//! Calculate the recall (fraction of found neighbors that are true neighbors).
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class TraversalType>
+double NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+TraversalType>::Recall(arma::Mat<size_t>& foundNeighbors,
+                       arma::Mat<size_t>& realNeighbors)
+{
+  if (foundNeighbors.n_rows != realNeighbors.n_rows ||
+      foundNeighbors.n_cols != realNeighbors.n_cols)
+    throw std::invalid_argument("matrices provided must have equal size");
+
+  size_t found = 0;  // Found neighbors also present in the true list.
+  for (size_t col = 0; col < foundNeighbors.n_cols; ++col)
+    for (size_t row = 0; row < foundNeighbors.n_rows; ++row)
+      for (size_t nei = 0; nei < realNeighbors.n_rows; ++nei)
+        if (foundNeighbors(row, col) == realNeighbors(nei, col))
+        {
+          found++;
+          break;  // Count each found neighbor at most once.
+        }
+
+  return ((double) found) / realNeighbors.n_elem;
+}
+
+
 //! Serialize the NeighborSearch model.
 template<typename SortPolicy,
          typename MetricType,




More information about the mlpack-git mailing list