[mlpack-git] master: Add Recall() to NeighborSearch class, and a proper command line option to mlpack_knn and mlpack_kfn. (3e79d27)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Aug 16 19:08:10 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit 3e79d273545cb5d5694f6c8e7370a0dbbd95e5e6
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Tue Aug 16 20:08:10 2016 -0300
Add Recall() to the NeighborSearch class, and a corresponding --true_neighbors_file (-T) command-line option to mlpack_knn and mlpack_kfn.
>---------------------------------------------------------------
3e79d273545cb5d5694f6c8e7370a0dbbd95e5e6
src/mlpack/methods/neighbor_search/kfn_main.cpp | 20 ++++++++++++++-
src/mlpack/methods/neighbor_search/knn_main.cpp | 18 ++++++++++++++
.../methods/neighbor_search/neighbor_search.hpp | 14 +++++++++++
.../neighbor_search/neighbor_search_impl.hpp | 29 ++++++++++++++++++++++
4 files changed, 80 insertions(+), 1 deletion(-)
diff --git a/src/mlpack/methods/neighbor_search/kfn_main.cpp b/src/mlpack/methods/neighbor_search/kfn_main.cpp
index 9a4659e..f56b890 100644
--- a/src/mlpack/methods/neighbor_search/kfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/kfn_main.cpp
@@ -49,6 +49,8 @@ PARAM_STRING_OUT("neighbors_file", "File to output neighbors into.", "n");
PARAM_STRING_IN("true_distances_file", "File of true distances to compute "
"the effective error (average relative error) (it is printed when -v is "
"specified).", "D", "");
+PARAM_STRING_IN("true_neighbors_file", "File of true neighbors to compute the "
+ "recall (it is printed when -v is specified).", "T", "");
// The option exists to load or save models.
PARAM_STRING_IN("input_model_file", "File containing pre-trained kFN model.",
@@ -288,11 +290,27 @@ int main(int argc, char *argv[])
if (trueDistances.n_rows != distances.n_rows ||
trueDistances.n_cols != distances.n_cols)
Log::Fatal << "The true distances file must have the same number of "
- << "values than the set of distances being queried!" << endl;
+ << "values than the set of distances being queried!" << endl;
Log::Info << "Effective error: " << KFN::EffectiveError(distances,
trueDistances) << endl;
}
+
+ // Calculate the recall, if desired.
+ if (CLI::HasParam("true_neighbors_file"))
+ {
+ const string trueNeighborsFile = CLI::GetParam<string>(
+ "true_neighbors_file");
+ arma::Mat<size_t> trueNeighbors;
+ data::Load(trueNeighborsFile, trueNeighbors, true);
+
+ if (trueNeighbors.n_rows != neighbors.n_rows ||
+ trueNeighbors.n_cols != neighbors.n_cols)
+ Log::Fatal << "The true neighbors file must have the same number of "
+ << "values than the set of neighbors being queried!" << endl;
+
+ Log::Info << "Recall: " << KFN::Recall(neighbors, trueNeighbors) << endl;
+ }
}
if (CLI::HasParam("output_model_file"))
diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp
index 8dd8d53..974865f 100644
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@ -51,6 +51,8 @@ PARAM_STRING_OUT("neighbors_file", "File to output neighbors into.", "n");
PARAM_STRING_IN("true_distances_file", "File of true distances to compute "
"the effective error (average relative error) (it is printed when -v is "
"specified).", "D", "");
+PARAM_STRING_IN("true_neighbors_file", "File of true neighbors to compute the "
+ "recall (it is printed when -v is specified).", "T", "");
// The option exists to load or save models.
PARAM_STRING_IN("input_model_file", "File containing pre-trained kNN model.",
@@ -309,6 +311,22 @@ int main(int argc, char *argv[])
Log::Info << "Effective error: " << KNN::EffectiveError(distances,
trueDistances) << endl;
}
+
+ // Calculate the recall, if desired.
+ if (CLI::HasParam("true_neighbors_file"))
+ {
+ const string trueNeighborsFile = CLI::GetParam<string>(
+ "true_neighbors_file");
+ arma::Mat<size_t> trueNeighbors;
+ data::Load(trueNeighborsFile, trueNeighbors, true);
+
+ if (trueNeighbors.n_rows != neighbors.n_rows ||
+ trueNeighbors.n_cols != neighbors.n_cols)
+ Log::Fatal << "The true neighbors file must have the same number of "
+ << "values than the set of neighbors being queried!" << endl;
+
+ Log::Info << "Recall: " << KNN::Recall(neighbors, trueNeighbors) << endl;
+ }
}
if (CLI::HasParam("output_model_file"))
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 7db9ee6..be16b05 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -289,6 +289,20 @@ class NeighborSearch
static double EffectiveError(arma::mat& foundDistances,
arma::mat& realDistances);
+ /**
+ * Calculate the recall (fraction of true neighbors found) given the list of
+ * found neighbors and the true set of neighbors. The recall returned will be
+ * in [0, 1]; std::invalid_argument is thrown if the matrix sizes differ.
+ *
+ * @param foundNeighbors Matrix storing lists of calculated neighbors for each
+ * query point (one column per query point).
+ * @param realNeighbors Matrix storing lists of true best neighbors for each
+ * query point (one column per query point).
+ * @return Recall (NaN if both matrices are empty).
+ */
+ static double Recall(arma::Mat<size_t>& foundNeighbors,
+ arma::Mat<size_t>& realNeighbors);
+
//! Return the total number of base case evaluations performed during the last
//! search.
size_t BaseCases() const { return baseCases; }
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 30e7716..238680e 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -752,6 +752,35 @@ TraversalType>::EffectiveError(arma::mat& foundDistances,
return effectiveError;
}
+//! Calculate the recall (fraction of true neighbors that were found).
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType,
+ template<typename> class TraversalType>
+double NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+TraversalType>::Recall(arma::Mat<size_t>& foundNeighbors,
+ arma::Mat<size_t>& realNeighbors)
+{
+ if (foundNeighbors.n_rows != realNeighbors.n_rows ||
+ foundNeighbors.n_cols != realNeighbors.n_cols)
+ throw std::invalid_argument("matrices provided must have equal size");
+ // For each query (column), count found ids that occur in the true column.
+ size_t found = 0;
+ for (size_t col = 0; col < foundNeighbors.n_cols; ++col)
+ for (size_t row = 0; row < foundNeighbors.n_rows; ++row)
+ for (size_t nei = 0; nei < realNeighbors.n_rows; ++nei)
+ if (foundNeighbors(row, col) == realNeighbors(nei, col))
+ {
+ found++;
+ break; // One match is enough; duplicate found ids are each counted.
+ }
+ // NOTE(review): if both matrices are empty this is 0.0 / 0, i.e. NaN.
+ return ((double) found) / realNeighbors.n_elem;
+}
+
//! Serialize the NeighborSearch model.
template<typename SortPolicy,
typename MetricType,
More information about the mlpack-git
mailing list