[mlpack-git] master: Modify KNN/KFN to include Approximate Neighbor Search. (6db6de5)

gitdub at mlpack.org gitdub at mlpack.org
Wed Jun 22 14:09:07 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/a9f5622c8a14409111f2d71bf5c0f8aaa8ad4ae1...37fda23945b4f998cd5fa6ec011ae345236c8552

>---------------------------------------------------------------

commit 6db6de598389d780d51fe19a1d8484f9f1071920
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Thu Jun 2 15:53:45 2016 -0300

    Modify KNN/KFN to include Approximate Neighbor Search.


>---------------------------------------------------------------

6db6de598389d780d51fe19a1d8484f9f1071920
 src/mlpack/methods/neighbor_search/kfn_main.cpp    | 12 +++++-
 src/mlpack/methods/neighbor_search/knn_main.cpp    | 14 +++++--
 .../methods/neighbor_search/neighbor_search.hpp    | 15 ++++++++
 .../neighbor_search/neighbor_search_impl.hpp       | 32 +++++++++++----
 .../neighbor_search/neighbor_search_rules.hpp      |  4 ++
 .../neighbor_search/neighbor_search_rules_impl.hpp | 10 ++++-
 src/mlpack/methods/neighbor_search/ns_model.hpp    | 17 +++++++-
 .../methods/neighbor_search/ns_model_impl.hpp      | 45 ++++++++++++++++++----
 .../sort_policies/furthest_neighbor_sort.hpp       | 17 ++++++++
 .../sort_policies/nearest_neighbor_sort.hpp        | 15 ++++++++
 10 files changed, 158 insertions(+), 23 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/kfn_main.cpp b/src/mlpack/methods/neighbor_search/kfn_main.cpp
index d680740..a2fbf2e 100644
--- a/src/mlpack/methods/neighbor_search/kfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/kfn_main.cpp
@@ -72,6 +72,8 @@ PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
 PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
 PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
     "dual-tree search).", "s");
+PARAM_DOUBLE("epsilon", "If specified, will do approximate furthest neighbor "
+    "search with given relative error.", "e", 0);  // 0 = exact search
 
 // Convenience typedef.
 typedef NSModel<FurthestNeighborSort> KFNModel;
@@ -138,6 +140,12 @@ int main(int argc, char *argv[])
     Log::Fatal << "Invalid leaf size: " << lsInt << ".  Must be greater than 0."
         << endl;
 
+  // Sanity check on epsilon.
+  const double epsilon = CLI::GetParam<double>("epsilon");
+  if (epsilon < 0)
+    Log::Fatal << "Invalid epsilon: " << epsilon << ".  Must be non-negative. "
+        << endl;
+
   // We either have to load the reference data, or we have to load the model.
   NSModel<FurthestNeighborSort> kfn;
   const bool naive = CLI::HasParam("naive");
@@ -175,7 +183,8 @@ int main(int argc, char *argv[])
     Log::Info << "Loaded reference data from '" << referenceFile << "' ("
         << referenceSet.n_rows << "x" << referenceSet.n_cols << ")." << endl;
 
-    kfn.BuildModel(std::move(referenceSet), size_t(lsInt), naive, singleMode);
+    kfn.BuildModel(std::move(referenceSet), size_t(lsInt), naive, singleMode,
+        epsilon);
   }
   else
   {
@@ -191,6 +200,7 @@ int main(int argc, char *argv[])
     kfn.SingleMode() = CLI::HasParam("single_mode");
     kfn.Naive() = CLI::HasParam("naive");
     kfn.LeafSize() = size_t(lsInt);
+    kfn.Epsilon() = epsilon;
   }
 
   // Perform search, if desired.
diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp
index 4957e88..880f5db 100644
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@ -74,6 +74,8 @@ PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
 PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
 PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
     "dual-tree search).", "S");
+PARAM_DOUBLE("epsilon", "If specified, will do approximate nearest neighbor "
+    "search with given relative error.", "e", 0);  // 0 = exact search
 
 // Convenience typedef.
 typedef NSModel<NearestNeighborSort> KNNModel;
@@ -137,10 +139,14 @@ int main(int argc, char *argv[])
   // Sanity check on leaf size.
   const int lsInt = CLI::GetParam<int>("leaf_size");
   if (lsInt < 1)
-  {
     Log::Fatal << "Invalid leaf size: " << lsInt << ".  Must be greater "
         "than 0." << endl;
-  }
+
+  // Sanity check on epsilon.
+  const double epsilon = CLI::GetParam<double>("epsilon");
+  if (epsilon < 0)
+    Log::Fatal << "Invalid epsilon: " << epsilon << ".  Must be non-negative. "
+        << endl;
 
   // We either have to load the reference data, or we have to load the model.
   NSModel<NearestNeighborSort> knn;
@@ -180,7 +186,8 @@ int main(int argc, char *argv[])
         << referenceSet.n_rows << " x " << referenceSet.n_cols << ")."
         << endl;
 
-    knn.BuildModel(std::move(referenceSet), size_t(lsInt), naive, singleMode);
+    knn.BuildModel(std::move(referenceSet), size_t(lsInt), naive, singleMode,
+        epsilon);
   }
   else
   {
@@ -196,6 +203,7 @@ int main(int argc, char *argv[])
     knn.SingleMode() = CLI::HasParam("single_mode");
     knn.Naive() = CLI::HasParam("naive");
     knn.LeafSize() = size_t(lsInt);
+    knn.Epsilon() = epsilon;
   }
 
   // Perform search, if desired.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 999f261..f1acea4 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -84,11 +84,13 @@ class NeighborSearch
    *      dual-tree search).  This overrides singleMode (if it is set to true).
    * @param singleMode If true, single-tree search will be used (as opposed to
    *      dual-tree search).
+   * @param epsilon Relative approximate error (non-negative).
    * @param metric An optional instance of the MetricType class.
    */
   NeighborSearch(const MatType& referenceSet,
                  const bool naive = false,
                  const bool singleMode = false,
+                 const double epsilon = 0,
                  const MetricType metric = MetricType());
 
   /**
@@ -108,11 +110,13 @@ class NeighborSearch
    *      dual-tree search).  This overrides singleMode (if it is set to true).
    * @param singleMode If true, single-tree search will be used (as opposed to
    *      dual-tree search).
+   * @param epsilon Relative approximate error (non-negative).
    * @param metric An optional instance of the MetricType class.
    */
   NeighborSearch(MatType&& referenceSet,
                  const bool naive = false,
                  const bool singleMode = false,
+                 const double epsilon = 0,
                  const MetricType metric = MetricType());
 
   /**
@@ -138,10 +142,12 @@ class NeighborSearch
    * @param referenceSet Set of reference points corresponding to referenceTree.
    * @param singleMode Whether single-tree computation should be used (as
    *      opposed to dual-tree computation).
+   * @param epsilon Relative approximate error (non-negative).
    * @param metric Instantiated distance metric.
    */
   NeighborSearch(Tree* referenceTree,
                  const bool singleMode = false,
+                 const double epsilon = 0,
                  const MetricType metric = MetricType());
 
   /**
@@ -152,10 +158,12 @@ class NeighborSearch
    * @param naive Whether to use naive search.
    * @param singleMode Whether single-tree computation should be used (as
    *      opposed to dual-tree computation).
+   * @param epsilon Relative approximate error (non-negative).
    * @param metric Instantiated metric.
    */
   NeighborSearch(const bool naive = false,
                  const bool singleMode = false,
+                 const double epsilon = 0,
                  const MetricType metric = MetricType());
 
 
@@ -270,6 +278,11 @@ class NeighborSearch
   //! Modify whether or not search is done in single-tree mode.
   bool& SingleMode() { return singleMode; }
 
+  //! Access the relative error (epsilon) used in approximate search.
+  double Epsilon() const { return epsilon; }
+  //! Modify epsilon.  Unlike the constructors, this does not check >= 0.
+  double& Epsilon() { return epsilon; }
+
   //! Access the reference dataset.
   const MatType& ReferenceSet() const { return *referenceSet; }
 
@@ -294,6 +307,8 @@ class NeighborSearch
   bool naive;
   //! Indicates if single-tree search is being used (as opposed to dual-tree).
   bool singleMode;
+  //! Relative error for approximate search (0 = exact); must be >= 0.
+  double epsilon;
 
   //! Instantiation of metric.
   MetricType metric;
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index d86f514..2d7468b 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -75,6 +75,7 @@ NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
 NeighborSearch(const MatType& referenceSetIn,
                const bool naive,
                const bool singleMode,
+               const double epsilon,
                const MetricType metric) :
     referenceTree(naive ? NULL :
         BuildTree<MatType, Tree>(referenceSetIn, oldFromNewReferences)),
@@ -83,12 +84,14 @@ NeighborSearch(const MatType& referenceSetIn,
     setOwner(false),
     naive(naive),
     singleMode(!naive && singleMode), // No single mode if naive.
+    epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
-  // Nothing to do.
+  if (epsilon < 0)
+    throw std::invalid_argument("epsilon must be non-negative");
 }
 
 // Construct the object.
@@ -103,6 +106,7 @@ NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
 NeighborSearch(MatType&& referenceSetIn,
                const bool naive,
                const bool singleMode,
+               const double epsilon,
                const MetricType metric) :
     referenceTree(naive ? NULL :
         BuildTree<MatType, Tree>(std::move(referenceSetIn),
@@ -113,12 +117,14 @@ NeighborSearch(MatType&& referenceSetIn,
     setOwner(naive),
     naive(naive),
     singleMode(!naive && singleMode),
+    epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
-  // Nothing to do.
+  if (epsilon < 0)
+    throw std::invalid_argument("epsilon must be non-negative");
 }
 
 // Construct the object.
@@ -132,6 +138,7 @@ template<typename SortPolicy,
 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
 NeighborSearch(Tree* referenceTree,
                const bool singleMode,
+               const double epsilon,
                const MetricType metric) :
     referenceTree(referenceTree),
     referenceSet(&referenceTree->Dataset()),
@@ -139,12 +146,14 @@ NeighborSearch(Tree* referenceTree,
     setOwner(false),
     naive(false),
     singleMode(singleMode),
+    epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
-  // Nothing else to initialize.
+  if (epsilon < 0)
+    throw std::invalid_argument("epsilon must be non-negative");
 }
 
 // Construct the object without a reference dataset.
@@ -158,6 +167,7 @@ template<typename SortPolicy,
 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
     NeighborSearch(const bool naive,
                    const bool singleMode,
+                   const double epsilon,
                    const MetricType metric) :
     referenceTree(NULL),
     referenceSet(new MatType()), // Empty matrix.
@@ -165,11 +175,14 @@ NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
     setOwner(true),
     naive(naive),
     singleMode(singleMode),
+    epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
+  if (epsilon < 0)
+    throw std::invalid_argument("epsilon must be non-negative");
   // Build the tree on the empty dataset, if necessary.
   if (!naive)
   {
@@ -364,7 +377,8 @@ Search(const MatType& querySet,
   if (naive)
   {
     // Create the helper object for the tree traversal.
-    RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+    RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
+        epsilon);
 
     // The naive brute-force traversal.
     for (size_t i = 0; i < querySet.n_cols; ++i)
@@ -376,7 +390,8 @@ Search(const MatType& querySet,
   else if (singleMode)
   {
     // Create the helper object for the tree traversal.
-    RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+    RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
+        epsilon);
 
     // Create the traverser.
     typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
@@ -402,7 +417,7 @@ Search(const MatType& querySet,
 
     // Create the helper object for the tree traversal.
     RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr,
-        *distancePtr, metric);
+        *distancePtr, metric, epsilon);
 
     // Create the traverser.
     TraversalType<RuleType> traverser(rules);
@@ -527,7 +542,8 @@ Search(Tree* queryTree,
 
   // Create the helper object for the traversal.
   typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
-  RuleType rules(*referenceSet, querySet, *neighborPtr, distances, metric);
+  RuleType rules(*referenceSet, querySet, *neighborPtr, distances, metric,
+      epsilon);
 
   // Create the traverser.
   TraversalType<RuleType> traverser(rules);
@@ -598,7 +614,7 @@ Search(const size_t k,
   // Create the helper object for the traversal.
   typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
   RuleType rules(*referenceSet, *referenceSet, *neighborPtr, *distancePtr,
-      metric, true /* don't return the same point as nearest neighbor */);
+      metric, epsilon, true /* don't return the same point as nearest neighbor */);
 
   if (naive)
   {
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index 474d22b..47a7933 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -22,6 +22,7 @@ class NeighborSearchRules
                       arma::Mat<size_t>& neighbors,
                       arma::mat& distances,
                       MetricType& metric,
+                      const double epsilon = 0,
                       const bool sameSet = false);
   /**
    * Get the distance from the query point to the reference point.
@@ -120,6 +121,9 @@ class NeighborSearchRules
   //! Denotes whether or not the reference and query sets are the same.
   bool sameSet;
 
+  //! Relative error passed to SortPolicy::Relax() to relax pruning bounds.
+  const double epsilon;
+
   //! The last query point BaseCase() was called with.
   size_t lastQueryIndex;
   //! The last reference point BaseCase() was called with.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index cc2b957..6edf103 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -20,6 +20,7 @@ NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
     arma::Mat<size_t>& neighbors,
     arma::mat& distances,
     MetricType& metric,
+    const double epsilon,
     const bool sameSet) :
     referenceSet(referenceSet),
     querySet(querySet),
@@ -27,6 +28,7 @@ NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
     distances(distances),
     metric(metric),
     sameSet(sameSet),
+    epsilon(epsilon),
     lastQueryIndex(querySet.n_cols),
     lastReferenceIndex(referenceSet.n_cols),
     baseCases(0),
@@ -112,7 +114,8 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
   }
 
   // Compare against the best k'th distance for this query point so far.
-  const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+  double bestDistance = distances(distances.n_rows - 1, queryIndex);
+  bestDistance = SortPolicy::Relax(bestDistance, epsilon);
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
 }
@@ -128,7 +131,8 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
     return oldScore;
 
   // Just check the score again against the distances.
-  const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+  double bestDistance = distances(distances.n_rows - 1, queryIndex);
+  bestDistance = SortPolicy::Relax(bestDistance, epsilon);
 
   return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
 }
@@ -419,6 +423,8 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
   queryNode.Stat().SecondBound() = bestDistance;
   queryNode.Stat().AuxBound() = auxDistance;
 
+  worstDistance = SortPolicy::Relax(worstDistance, epsilon);
+
   if (SortPolicy::IsBetter(worstDistance, bestDistance))
     return worstDistance;
   else
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index d87549e..db3331a 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -178,6 +178,16 @@ class NaiveVisitor : public boost::static_visitor<bool&>
 };
 
 /**
+ * EpsilonVisitor exposes the Epsilon() of the given NSType by reference.
+ */
+class EpsilonVisitor : public boost::static_visitor<double&>
+{
+ public:
+  template<typename NSType>
+  double& operator()(NSType *ns) const;  // throws if ns is NULL
+};
+
+/**
  * ReferenceSetVisitor exposes the referenceSet of the given NSType.
  */
 class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
@@ -266,6 +276,10 @@ class NSModel
   bool Naive() const;
   bool& Naive();
 
+  //! Expose Epsilon.
+  double Epsilon() const;
+  double& Epsilon();
+
   //! Expose leafSize.
   size_t LeafSize() const { return leafSize; }
   size_t& LeafSize() { return leafSize; }
@@ -282,7 +296,8 @@ class NSModel
   void BuildModel(arma::mat&& referenceSet,
                   const size_t leafSize,
                   const bool naive,
-                  const bool singleMode);
+                  const bool singleMode,
+                  const double epsilon = 0);
 
   //! Perform neighbor search.  The query set will be reordered.
   void Search(arma::mat&& querySet,
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 5ed9772..bbca3d2 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -185,6 +185,15 @@ bool& NaiveVisitor::operator()(NSType* ns) const
   throw std::runtime_error("no neighbor search model initialized");
 }
 
+//! Return ns->Epsilon() by reference; throws if no model is initialized.
+template<typename NSType>
+double& EpsilonVisitor::operator()(NSType* ns) const
+{
+  if (ns)  // A NULL ns is reported as "no model initialized" below.
+    return ns->Epsilon();
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
 //! Expose the referenceSet of the given NSType.
 template<typename NSType>
 const arma::mat& ReferenceSetVisitor::operator()(NSType* ns) const
@@ -293,12 +302,25 @@ bool& NSModel<SortPolicy>::Naive()
   return boost::apply_visitor(NaiveVisitor(), nSearch);
 }
 
+template<typename SortPolicy>
+double NSModel<SortPolicy>::Epsilon() const  // throws if no model initialized
+{
+  return boost::apply_visitor(EpsilonVisitor(), nSearch);
+}
+
+template<typename SortPolicy>
+double& NSModel<SortPolicy>::Epsilon()  // reference into the model; may throw
+{
+  return boost::apply_visitor(EpsilonVisitor(), nSearch);
+}
+
 //! Build the reference tree.
 template<typename SortPolicy>
 void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
                                      const size_t leafSize,
                                      const bool naive,
-                                     const bool singleMode)
+                                     const bool singleMode,
+                                     const double epsilon)
 {
   // Initialize random basis if necessary.
   if (randomBasis)
@@ -348,23 +370,26 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
   switch (treeType)
   {
     case KD_TREE:
-      nSearch = new NSType<SortPolicy, tree::KDTree>(naive, singleMode);
+      nSearch = new NSType<SortPolicy, tree::KDTree>(naive, singleMode,
+          epsilon);
       break;
     case COVER_TREE:
       nSearch = new NSType<SortPolicy, tree::StandardCoverTree>(naive,
-          singleMode);
+          singleMode, epsilon);
       break;
     case R_TREE:
-      nSearch = new NSType<SortPolicy, tree::RTree>(naive, singleMode);
+      nSearch = new NSType<SortPolicy, tree::RTree>(naive, singleMode, epsilon);
       break;
     case R_STAR_TREE:
-      nSearch = new NSType<SortPolicy, tree::RStarTree>(naive, singleMode);
+      nSearch = new NSType<SortPolicy, tree::RStarTree>(naive, singleMode,
+          epsilon);
       break;
     case BALL_TREE:
-      nSearch = new NSType<SortPolicy, tree::BallTree>(naive, singleMode);
+      nSearch = new NSType<SortPolicy, tree::BallTree>(naive, singleMode,
+          epsilon);
       break;
     case X_TREE:
-      nSearch = new NSType<SortPolicy, tree::XTree>(naive, singleMode);
+      nSearch = new NSType<SortPolicy, tree::XTree>(naive, singleMode, epsilon);
       break;
   }
 
@@ -389,7 +414,11 @@ void NSModel<SortPolicy>::Search(arma::mat&& querySet,
   if (randomBasis)
     querySet = q * querySet;
 
-  Log::Info << "Searching for " << k << " nearest neighbors with ";
+  Log::Info << "Searching for " << k;
+  if (Epsilon() != 0)
+    Log::Info << " approximate nearest neighbors (e=" << Epsilon() << ") with ";
+  else
+    Log::Info << " nearest neighbors with ";
   if (!Naive() && !SingleMode())
     Log::Info << "dual-tree " << TreeName() << " search..." << std::endl;
   else if (!Naive())
diff --git a/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp b/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
index 87a7262..a69c167 100644
--- a/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
+++ b/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
@@ -145,6 +145,23 @@ class FurthestNeighborSort
    */
   static inline double CombineWorst(const double a, const double b)
   { return std::max(a - b, 0.0); }
+
+  /**
+   * Return value scaled by 1 / (1 - epsilon): a relaxed distance bound.
+   *
+   * @param value Value to relax; 0 and DBL_MAX are returned unchanged.
+   * @param epsilon Relative error (non-negative); if epsilon >= 1, every
+   *     non-zero value is relaxed to DBL_MAX.
+   * @return double Value relaxed.
+   */
+  static inline double Relax(const double value, const double epsilon)
+  {
+    if (value == 0) // Checked first, so 0 stays 0 even if epsilon >= 1.
+      return 0;
+    if (value == DBL_MAX || epsilon >= 1) // Or 1 - epsilon <= 0: saturate.
+      return DBL_MAX;
+    return (1 / (1 - epsilon)) * value;
+  }
 };
 
 } // namespace neighbor
diff --git a/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp b/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
index f57635a..42a08b0 100644
--- a/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
+++ b/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
@@ -150,6 +150,21 @@ class NearestNeighborSort
       return DBL_MAX;
     return a + b;
   }
+
+  /**
+   * Return value scaled by 1 / (1 + epsilon): a relaxed distance bound.
+   *
+   * @param value Value to relax; DBL_MAX is returned unchanged.
+   * @param epsilon Relative error (non-negative); 0 leaves value unchanged.
+   *
+   * @return double Value relaxed.
+   */
+  static inline double Relax(const double value, const double epsilon)
+  {
+    if (value == DBL_MAX) // Not scaled, so DBL_MAX remains DBL_MAX.
+      return DBL_MAX;
+    return (1 / (1 + epsilon)) * value;
+  }
 };
 
 } // namespace neighbor




More information about the mlpack-git mailing list