[mlpack-git] master: Add EffectiveError() to NeighborSearch class, and a proper command line option to mlpack_knn and mlpack_kfn. (9215515)

gitdub at mlpack.org gitdub at mlpack.org
Tue Aug 16 19:05:32 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit 92155154bc8721055702ae3d22c33542f2f07a61
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Tue Aug 16 20:05:32 2016 -0300

    Add EffectiveError() to NeighborSearch class, and a proper command line option to mlpack_knn and mlpack_kfn.


>---------------------------------------------------------------

92155154bc8721055702ae3d22c33542f2f07a61
 src/mlpack/methods/neighbor_search/kfn_main.cpp    | 20 ++++++++++
 src/mlpack/methods/neighbor_search/knn_main.cpp    | 43 ++++++++--------------
 .../methods/neighbor_search/neighbor_search.hpp    | 18 +++++++++
 .../neighbor_search/neighbor_search_impl.hpp       | 36 ++++++++++++++++++
 4 files changed, 90 insertions(+), 27 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/kfn_main.cpp b/src/mlpack/methods/neighbor_search/kfn_main.cpp
index b223d7c..9a4659e 100644
--- a/src/mlpack/methods/neighbor_search/kfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/kfn_main.cpp
@@ -46,6 +46,9 @@ PARAM_STRING_IN("reference_file", "File containing the reference dataset.", "r",
     "");
 PARAM_STRING_OUT("distances_file", "File to output distances into.", "d");
 PARAM_STRING_OUT("neighbors_file", "File to output neighbors into.", "n");
+PARAM_STRING_IN("true_distances_file", "File of true distances to compute "
+    "the effective error (average relative error) (it is printed when -v is "
+    "specified).", "D", "");
 
 // The option exists to load or save models.
 PARAM_STRING_IN("input_model_file", "File containing pre-trained kFN model.",
@@ -273,6 +276,23 @@ int main(int argc, char *argv[])
       data::Save(CLI::GetParam<string>("neighbors_file"), neighbors);
     if (CLI::HasParam("distances_file"))
       data::Save(CLI::GetParam<string>("distances_file"), distances);
+
+    // Calculate the effective error against the given true distances.
+    if (CLI::HasParam("true_distances_file"))
+    {
+      const string trueDistancesFile = CLI::GetParam<string>(
+          "true_distances_file");
+      arma::mat trueDistances;
+      data::Load(trueDistancesFile, trueDistances, true);
+
+      if (trueDistances.n_rows != distances.n_rows ||
+          trueDistances.n_cols != distances.n_cols)
+        Log::Fatal << "The true distances file must have the same dimensions "
+            << "as the set of distances being queried!" << endl;
+
+      Log::Info << "Effective error: " << KFN::EffectiveError(distances,
+          trueDistances) << endl;
+    }
   }
 
   if (CLI::HasParam("output_model_file"))
diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp
index 391a303..8dd8d53 100644
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@ -48,6 +48,9 @@ PARAM_STRING_IN("reference_file", "File containing the reference dataset.", "r",
     "");
 PARAM_STRING_OUT("distances_file", "File to output distances into.", "d");
 PARAM_STRING_OUT("neighbors_file", "File to output neighbors into.", "n");
+PARAM_STRING_IN("true_distances_file", "File of true distances to compute "
+    "the effective error (average relative error) (it is printed when -v is "
+    "specified).", "D", "");
 
 // The option exists to load or save models.
 PARAM_STRING_IN("input_model_file", "File containing pre-trained kNN model.",
@@ -82,9 +85,6 @@ PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
     "dual-tree search).", "S");
 PARAM_DOUBLE_IN("epsilon", "If specified, will do approximate nearest neighbor "
     "search with given relative error.", "e", 0);
-PARAM_STRING_IN("effective_error", "If specified, will compare the results "
-    "against the provided distances file, and will print the average relative "
-    "error.", "E", "");
 
 // Convenience typedef.
 typedef NSModel<NearestNeighborSort> KNNModel;
@@ -294,31 +294,20 @@ int main(int argc, char *argv[])
       data::Save(CLI::GetParam<string>("distances_file"), distances);
 
     // Calculate the effective error, if desired.
-    if (CLI::HasParam("effective_error"))
+    if (CLI::HasParam("true_distances_file"))
     {
-      const string exactFile = CLI::GetParam<string>("effective_error");
-      arma::mat distancesExact;
-      data::Load(exactFile, distancesExact, true);
-
-      if (distancesExact.n_elem != distances.n_elem)
-        Log::Fatal << "The effective error file must have the same number of "
-          << "values than the set of distances being queried!" << endl;
-
-      double effectiveError = 0;
-      size_t cases = 0;
-      for (size_t i = 0; i < distances.n_elem; i++)
-      {
-        if (distancesExact(i) != 0 && distances(i) != DBL_MAX)
-        {
-          effectiveError += (distances(i) - distancesExact(i)) /
-              distancesExact(i);
-          cases++;
-        }
-      }
-      if (cases)
-        effectiveError /= cases;
-
-      Log::Info << "Effective error: " << effectiveError << endl;
+      const string trueDistancesFile = CLI::GetParam<string>(
+          "true_distances_file");
+      arma::mat trueDistances;
+      data::Load(trueDistancesFile, trueDistances, true);
+
+      if (trueDistances.n_rows != distances.n_rows ||
+          trueDistances.n_cols != distances.n_cols)
+        Log::Fatal << "The true distances file must have the same dimensions "
+            << "as the set of distances being queried!" << endl;
+
+      Log::Info << "Effective error: " << KNN::EffectiveError(distances,
+          trueDistances) << endl;
     }
   }
 
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 75e6a55..7db9ee6 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -271,6 +271,24 @@ class NeighborSearch
              arma::Mat<size_t>& neighbors,
              arma::mat& distances);
 
+  /**
+   * Calculate the average relative error (effective error) between the
+   * distances calculated and the true distances provided.  The input matrices
+   * must have the same size; otherwise std::invalid_argument is thrown.
+   *
+   * Cases where the true distance is zero (the same point) or the calculated
+   * distance is SortPolicy::WorstDistance() (didn't find enough points) will be
+   * ignored.  If every case is ignored, 0 is returned.
+   *
+   * @param foundDistances Matrix storing lists of calculated distances for each
+   *     query point.
+   * @param realDistances Matrix storing lists of true best distances for each
+   *     query point.
+   * @return Average relative error.
+   */
+  static double EffectiveError(const arma::mat& foundDistances,
+                               const arma::mat& realDistances);
+
   //! Return the total number of base case evaluations performed during the last
   //! search.
   size_t BaseCases() const { return baseCases; }
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index cbc89ff..30e7716 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -716,6 +716,42 @@ Search(const size_t k,
   }
 }
 
+//! Calculate the mean of |found - true| / true over comparable entries.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class TraversalType>
+double NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+TraversalType>::EffectiveError(const arma::mat& foundDistances,
+                               const arma::mat& realDistances)
+{
+  if (foundDistances.n_rows != realDistances.n_rows ||
+      foundDistances.n_cols != realDistances.n_cols)
+    throw std::invalid_argument("matrices provided must have equal size");
+
+  double effectiveError = 0;
+  size_t numCases = 0;
+
+  for (size_t i = 0; i < foundDistances.n_elem; i++)
+  {
+    if (realDistances(i) != 0 &&
+        foundDistances(i) != SortPolicy::WorstDistance())
+    {
+      effectiveError += std::abs(foundDistances(i) - realDistances(i)) /
+          realDistances(i);
+      numCases++;
+    }
+  }
+
+  if (numCases)
+    effectiveError /= numCases;
+
+  return effectiveError;
+}
+
 //! Serialize the NeighborSearch model.
 template<typename SortPolicy,
          typename MetricType,




More information about the mlpack-git mailing list