[mlpack-git] master: Add EffectiveError() to NeighborSearch class, and a proper command line option to mlpack_knn and mlpack_kfn. (9215515)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Aug 16 19:05:32 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit 92155154bc8721055702ae3d22c33542f2f07a61
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Tue Aug 16 20:05:32 2016 -0300
Add EffectiveError() to the NeighborSearch class, and add a proper command-line option (--true_distances_file, -D) to mlpack_knn and mlpack_kfn, replacing mlpack_knn's --effective_error (-E) option.
>---------------------------------------------------------------
92155154bc8721055702ae3d22c33542f2f07a61
src/mlpack/methods/neighbor_search/kfn_main.cpp | 20 ++++++++++
src/mlpack/methods/neighbor_search/knn_main.cpp | 43 ++++++++--------------
.../methods/neighbor_search/neighbor_search.hpp | 18 +++++++++
.../neighbor_search/neighbor_search_impl.hpp | 36 ++++++++++++++++++
4 files changed, 90 insertions(+), 27 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/kfn_main.cpp b/src/mlpack/methods/neighbor_search/kfn_main.cpp
index b223d7c..9a4659e 100644
--- a/src/mlpack/methods/neighbor_search/kfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/kfn_main.cpp
@@ -46,6 +46,9 @@ PARAM_STRING_IN("reference_file", "File containing the reference dataset.", "r",
"");
PARAM_STRING_OUT("distances_file", "File to output distances into.", "d");
PARAM_STRING_OUT("neighbors_file", "File to output neighbors into.", "n");
+PARAM_STRING_IN("true_distances_file", "File of true distances to compute "
+ "the effective error (average relative error) (it is printed when -v is "
+ "specified).", "D", "");
// The option exists to load or save models.
PARAM_STRING_IN("input_model_file", "File containing pre-trained kFN model.",
@@ -273,6 +276,23 @@ int main(int argc, char *argv[])
data::Save(CLI::GetParam<string>("neighbors_file"), neighbors);
if (CLI::HasParam("distances_file"))
data::Save(CLI::GetParam<string>("distances_file"), distances);
+
+ // Calculate the effective error, if desired.
+ if (CLI::HasParam("true_distances_file"))
+ {
+ const string trueDistancesFile = CLI::GetParam<string>(
+ "true_distances_file");
+ arma::mat trueDistances;
+ data::Load(trueDistancesFile, trueDistances, true);
+
+ if (trueDistances.n_rows != distances.n_rows ||
+ trueDistances.n_cols != distances.n_cols)
+ Log::Fatal << "The true distances file must have the same number of "
+ << "values than the set of distances being queried!" << endl;
+
+ Log::Info << "Effective error: " << KFN::EffectiveError(distances,
+ trueDistances) << endl;
+ }
}
if (CLI::HasParam("output_model_file"))
diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp
index 391a303..8dd8d53 100644
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@ -48,6 +48,9 @@ PARAM_STRING_IN("reference_file", "File containing the reference dataset.", "r",
"");
PARAM_STRING_OUT("distances_file", "File to output distances into.", "d");
PARAM_STRING_OUT("neighbors_file", "File to output neighbors into.", "n");
+PARAM_STRING_IN("true_distances_file", "File of true distances to compute "
+ "the effective error (average relative error) (it is printed when -v is "
+ "specified).", "D", "");
// The option exists to load or save models.
PARAM_STRING_IN("input_model_file", "File containing pre-trained kNN model.",
@@ -82,9 +85,6 @@ PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
"dual-tree search).", "S");
PARAM_DOUBLE_IN("epsilon", "If specified, will do approximate nearest neighbor "
"search with given relative error.", "e", 0);
-PARAM_STRING_IN("effective_error", "If specified, will compare the results "
- "against the provided distances file, and will print the average relative "
- "error.", "E", "");
// Convenience typedef.
typedef NSModel<NearestNeighborSort> KNNModel;
@@ -294,31 +294,20 @@ int main(int argc, char *argv[])
data::Save(CLI::GetParam<string>("distances_file"), distances);
// Calculate the effective error, if desired.
- if (CLI::HasParam("effective_error"))
+ if (CLI::HasParam("true_distances_file"))
{
- const string exactFile = CLI::GetParam<string>("effective_error");
- arma::mat distancesExact;
- data::Load(exactFile, distancesExact, true);
-
- if (distancesExact.n_elem != distances.n_elem)
- Log::Fatal << "The effective error file must have the same number of "
- << "values than the set of distances being queried!" << endl;
-
- double effectiveError = 0;
- size_t cases = 0;
- for (size_t i = 0; i < distances.n_elem; i++)
- {
- if (distancesExact(i) != 0 && distances(i) != DBL_MAX)
- {
- effectiveError += (distances(i) - distancesExact(i)) /
- distancesExact(i);
- cases++;
- }
- }
- if (cases)
- effectiveError /= cases;
-
- Log::Info << "Effective error: " << effectiveError << endl;
+ const string trueDistancesFile = CLI::GetParam<string>(
+ "true_distances_file");
+ arma::mat trueDistances;
+ data::Load(trueDistancesFile, trueDistances, true);
+
+ if (trueDistances.n_rows != distances.n_rows ||
+ trueDistances.n_cols != distances.n_cols)
+ Log::Fatal << "The true distances file must have the same number of "
+ << "values than the set of distances being queried!" << endl;
+
+ Log::Info << "Effective error: " << KNN::EffectiveError(distances,
+ trueDistances) << endl;
}
}
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 75e6a55..7db9ee6 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -271,6 +271,24 @@ class NeighborSearch
arma::Mat<size_t>& neighbors,
arma::mat& distances);
+ /**
+ * Calculate the average relative error (effective error) between the
+ * distances calculated and the true distances provided. The input matrices
+ * must have the same size.
+ *
+ * Cases where the true distance is zero (the same point) or the calculated
+ * distance is SortPolicy::WorstDistance() (didn't find enough points) will be
+ * ignored.
+ *
+ * @param foundDistances Matrix storing lists of calculated distances for each
+ * query point.
+ * @param realDistances Matrix storing lists of true best distances for each
+ * query point.
+ * @return Average relative error.
+ */
+ static double EffectiveError(arma::mat& foundDistances,
+ arma::mat& realDistances);
+
//! Return the total number of base case evaluations performed during the last
//! search.
size_t BaseCases() const { return baseCases; }
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index cbc89ff..30e7716 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -716,6 +716,42 @@ Search(const size_t k,
}
}
+//! Calculate the average relative error.
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType,
+ template<typename> class TraversalType>
+double NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+TraversalType>::EffectiveError(arma::mat& foundDistances,
+ arma::mat& realDistances)
+{
+ if (foundDistances.n_rows != realDistances.n_rows ||
+ foundDistances.n_cols != realDistances.n_cols)
+ throw std::invalid_argument("matrices provided must have equal size");
+
+ double effectiveError = 0;
+ size_t numCases = 0;
+
+ for (size_t i = 0; i < foundDistances.n_elem; i++)
+ {
+ if (realDistances(i) != 0 &&
+ foundDistances(i) != SortPolicy::WorstDistance())
+ {
+ effectiveError += fabs(foundDistances(i) - realDistances(i)) /
+ realDistances(i);
+ numCases++;
+ }
+ }
+
+ if (numCases)
+ effectiveError /= numCases;
+
+ return effectiveError;
+}
+
//! Serialize the NeighborSearch model.
template<typename SortPolicy,
typename MetricType,
More information about the mlpack-git mailing list