[mlpack-git] master: Modify KNN/KFN to include Approximate Neighbor Search. (6db6de5)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Jun 22 14:09:07 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/a9f5622c8a14409111f2d71bf5c0f8aaa8ad4ae1...37fda23945b4f998cd5fa6ec011ae345236c8552
>---------------------------------------------------------------
commit 6db6de598389d780d51fe19a1d8484f9f1071920
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Thu Jun 2 15:53:45 2016 -0300
Modify KNN/KFN to include Approximate Neighbor Search.
>---------------------------------------------------------------
6db6de598389d780d51fe19a1d8484f9f1071920
src/mlpack/methods/neighbor_search/kfn_main.cpp | 12 +++++-
src/mlpack/methods/neighbor_search/knn_main.cpp | 14 +++++--
.../methods/neighbor_search/neighbor_search.hpp | 15 ++++++++
.../neighbor_search/neighbor_search_impl.hpp | 32 +++++++++++----
.../neighbor_search/neighbor_search_rules.hpp | 4 ++
.../neighbor_search/neighbor_search_rules_impl.hpp | 10 ++++-
src/mlpack/methods/neighbor_search/ns_model.hpp | 17 +++++++-
.../methods/neighbor_search/ns_model_impl.hpp | 45 ++++++++++++++++++----
.../sort_policies/furthest_neighbor_sort.hpp | 17 ++++++++
.../sort_policies/nearest_neighbor_sort.hpp | 15 ++++++++
10 files changed, 158 insertions(+), 23 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/kfn_main.cpp b/src/mlpack/methods/neighbor_search/kfn_main.cpp
index d680740..a2fbf2e 100644
--- a/src/mlpack/methods/neighbor_search/kfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/kfn_main.cpp
@@ -72,6 +72,8 @@ PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
"dual-tree search).", "s");
+// Relative error allowed in approximate search; the default of 0 requests
+// exact furthest neighbor search.
+PARAM_DOUBLE("epsilon", "If specified, will do approximate furthest "
+    "neighbor search with given relative error.", "e", 0);
// Convenience typedef.
typedef NSModel<FurthestNeighborSort> KFNModel;
@@ -138,6 +140,12 @@ int main(int argc, char *argv[])
Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater than 0."
<< endl;
+ // Sanity check on epsilon.
+ const double epsilon = CLI::GetParam<double>("epsilon");
+ if (epsilon < 0)
+ Log::Fatal << "Invalid epsilon: " << epsilon << ". Must be non-negative. "
+ << endl;
+
// We either have to load the reference data, or we have to load the model.
NSModel<FurthestNeighborSort> kfn;
const bool naive = CLI::HasParam("naive");
@@ -175,7 +183,8 @@ int main(int argc, char *argv[])
Log::Info << "Loaded reference data from '" << referenceFile << "' ("
<< referenceSet.n_rows << "x" << referenceSet.n_cols << ")." << endl;
- kfn.BuildModel(std::move(referenceSet), size_t(lsInt), naive, singleMode);
+ kfn.BuildModel(std::move(referenceSet), size_t(lsInt), naive, singleMode,
+ epsilon);
}
else
{
@@ -191,6 +200,7 @@ int main(int argc, char *argv[])
kfn.SingleMode() = CLI::HasParam("single_mode");
kfn.Naive() = CLI::HasParam("naive");
kfn.LeafSize() = size_t(lsInt);
+ kfn.Epsilon() = epsilon;
}
// Perform search, if desired.
diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp
index 4957e88..880f5db 100644
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@ -74,6 +74,8 @@ PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
"dual-tree search).", "S");
+// Relative error allowed in approximate search; the default of 0 requests
+// exact nearest neighbor search.
+PARAM_DOUBLE("epsilon", "If specified, will do approximate nearest "
+    "neighbor search with given relative error.", "e", 0);
// Convenience typedef.
typedef NSModel<NearestNeighborSort> KNNModel;
@@ -137,10 +139,14 @@ int main(int argc, char *argv[])
// Sanity check on leaf size.
const int lsInt = CLI::GetParam<int>("leaf_size");
if (lsInt < 1)
- {
Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
"than 0." << endl;
- }
+
+ // Sanity check on epsilon.
+ const double epsilon = CLI::GetParam<double>("epsilon");
+ if (epsilon < 0)
+ Log::Fatal << "Invalid epsilon: " << epsilon << ". Must be non-negative. "
+ << endl;
// We either have to load the reference data, or we have to load the model.
NSModel<NearestNeighborSort> knn;
@@ -180,7 +186,8 @@ int main(int argc, char *argv[])
<< referenceSet.n_rows << " x " << referenceSet.n_cols << ")."
<< endl;
- knn.BuildModel(std::move(referenceSet), size_t(lsInt), naive, singleMode);
+ knn.BuildModel(std::move(referenceSet), size_t(lsInt), naive, singleMode,
+ epsilon);
}
else
{
@@ -196,6 +203,7 @@ int main(int argc, char *argv[])
knn.SingleMode() = CLI::HasParam("single_mode");
knn.Naive() = CLI::HasParam("naive");
knn.LeafSize() = size_t(lsInt);
+ knn.Epsilon() = epsilon;
}
// Perform search, if desired.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 999f261..f1acea4 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -84,11 +84,13 @@ class NeighborSearch
* dual-tree search). This overrides singleMode (if it is set to true).
* @param singleMode If true, single-tree search will be used (as opposed to
* dual-tree search).
+ * @param epsilon Relative approximate error (non-negative).
* @param metric An optional instance of the MetricType class.
*/
NeighborSearch(const MatType& referenceSet,
const bool naive = false,
const bool singleMode = false,
+ const double epsilon = 0,
const MetricType metric = MetricType());
/**
@@ -108,11 +110,13 @@ class NeighborSearch
* dual-tree search). This overrides singleMode (if it is set to true).
* @param singleMode If true, single-tree search will be used (as opposed to
* dual-tree search).
+ * @param epsilon Relative approximate error (non-negative).
* @param metric An optional instance of the MetricType class.
*/
NeighborSearch(MatType&& referenceSet,
const bool naive = false,
const bool singleMode = false,
+ const double epsilon = 0,
const MetricType metric = MetricType());
/**
@@ -138,10 +142,12 @@ class NeighborSearch
* @param referenceSet Set of reference points corresponding to referenceTree.
* @param singleMode Whether single-tree computation should be used (as
* opposed to dual-tree computation).
+ * @param epsilon Relative approximate error (non-negative).
* @param metric Instantiated distance metric.
*/
NeighborSearch(Tree* referenceTree,
const bool singleMode = false,
+ const double epsilon = 0,
const MetricType metric = MetricType());
/**
@@ -152,10 +158,12 @@ class NeighborSearch
* @param naive Whether to use naive search.
* @param singleMode Whether single-tree computation should be used (as
* opposed to dual-tree computation).
+ * @param epsilon Relative approximate error (non-negative).
* @param metric Instantiated metric.
*/
NeighborSearch(const bool naive = false,
const bool singleMode = false,
+ const double epsilon = 0,
const MetricType metric = MetricType());
@@ -270,6 +278,11 @@ class NeighborSearch
//! Modify whether or not search is done in single-tree mode.
bool& SingleMode() { return singleMode; }
+  //! Get the relative error (epsilon) used by approximate search; a value of
+  //! 0 leaves distances unrelaxed, i.e. exact search.
+  double Epsilon() const
+  {
+    return epsilon;
+  }
+
+  //! Get a modifiable reference to the relative error (epsilon).  Note that
+  //! the constructors reject negative values but this accessor does not.
+  double& Epsilon()
+  {
+    return epsilon;
+  }
+
//! Access the reference dataset.
const MatType& ReferenceSet() const { return *referenceSet; }
@@ -294,6 +307,8 @@ class NeighborSearch
bool naive;
//! Indicates if single-tree search is being used (as opposed to dual-tree).
bool singleMode;
+ //! Indicates the relative error to be considered in approximate search.
+ double epsilon;
//! Instantiation of metric.
MetricType metric;
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index d86f514..2d7468b 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -75,6 +75,7 @@ NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
NeighborSearch(const MatType& referenceSetIn,
const bool naive,
const bool singleMode,
+ const double epsilon,
const MetricType metric) :
referenceTree(naive ? NULL :
BuildTree<MatType, Tree>(referenceSetIn, oldFromNewReferences)),
@@ -83,12 +84,14 @@ NeighborSearch(const MatType& referenceSetIn,
setOwner(false),
naive(naive),
singleMode(!naive && singleMode), // No single mode if naive.
+ epsilon(epsilon),
metric(metric),
baseCases(0),
scores(0),
treeNeedsReset(false)
{
- // Nothing to do.
+ if (epsilon < 0)
+ throw std::invalid_argument("epsilon must be non-negative");
}
// Construct the object.
@@ -103,6 +106,7 @@ NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
NeighborSearch(MatType&& referenceSetIn,
const bool naive,
const bool singleMode,
+ const double epsilon,
const MetricType metric) :
referenceTree(naive ? NULL :
BuildTree<MatType, Tree>(std::move(referenceSetIn),
@@ -113,12 +117,14 @@ NeighborSearch(MatType&& referenceSetIn,
setOwner(naive),
naive(naive),
singleMode(!naive && singleMode),
+ epsilon(epsilon),
metric(metric),
baseCases(0),
scores(0),
treeNeedsReset(false)
{
- // Nothing to do.
+ if (epsilon < 0)
+ throw std::invalid_argument("epsilon must be non-negative");
}
// Construct the object.
@@ -132,6 +138,7 @@ template<typename SortPolicy,
NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
NeighborSearch(Tree* referenceTree,
const bool singleMode,
+ const double epsilon,
const MetricType metric) :
referenceTree(referenceTree),
referenceSet(&referenceTree->Dataset()),
@@ -139,12 +146,14 @@ NeighborSearch(Tree* referenceTree,
setOwner(false),
naive(false),
singleMode(singleMode),
+ epsilon(epsilon),
metric(metric),
baseCases(0),
scores(0),
treeNeedsReset(false)
{
- // Nothing else to initialize.
+ if (epsilon < 0)
+ throw std::invalid_argument("epsilon must be non-negative");
}
// Construct the object without a reference dataset.
@@ -158,6 +167,7 @@ template<typename SortPolicy,
NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
NeighborSearch(const bool naive,
const bool singleMode,
+ const double epsilon,
const MetricType metric) :
referenceTree(NULL),
referenceSet(new MatType()), // Empty matrix.
@@ -165,11 +175,14 @@ NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
setOwner(true),
naive(naive),
singleMode(singleMode),
+ epsilon(epsilon),
metric(metric),
baseCases(0),
scores(0),
treeNeedsReset(false)
{
+ if (epsilon < 0)
+ throw std::invalid_argument("epsilon must be non-negative");
// Build the tree on the empty dataset, if necessary.
if (!naive)
{
@@ -364,7 +377,8 @@ Search(const MatType& querySet,
if (naive)
{
// Create the helper object for the tree traversal.
- RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+ RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
+ epsilon);
// The naive brute-force traversal.
for (size_t i = 0; i < querySet.n_cols; ++i)
@@ -376,7 +390,8 @@ Search(const MatType& querySet,
else if (singleMode)
{
// Create the helper object for the tree traversal.
- RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+ RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
+ epsilon);
// Create the traverser.
typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
@@ -402,7 +417,7 @@ Search(const MatType& querySet,
// Create the helper object for the tree traversal.
RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr,
- *distancePtr, metric);
+ *distancePtr, metric, epsilon);
// Create the traverser.
TraversalType<RuleType> traverser(rules);
@@ -527,7 +542,8 @@ Search(Tree* queryTree,
// Create the helper object for the traversal.
typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
- RuleType rules(*referenceSet, querySet, *neighborPtr, distances, metric);
+ RuleType rules(*referenceSet, querySet, *neighborPtr, distances, metric,
+ epsilon);
// Create the traverser.
TraversalType<RuleType> traverser(rules);
@@ -598,7 +614,7 @@ Search(const size_t k,
// Create the helper object for the traversal.
typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
RuleType rules(*referenceSet, *referenceSet, *neighborPtr, *distancePtr,
- metric, true /* don't return the same point as nearest neighbor */);
+ metric, epsilon, true /* don't return the same point as nearest neighbor */);
if (naive)
{
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index 474d22b..47a7933 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -22,6 +22,7 @@ class NeighborSearchRules
arma::Mat<size_t>& neighbors,
arma::mat& distances,
MetricType& metric,
+ const double epsilon = 0,
const bool sameSet = false);
/**
* Get the distance from the query point to the reference point.
@@ -120,6 +121,9 @@ class NeighborSearchRules
//! Denotes whether or not the reference and query sets are the same.
bool sameSet;
+ //! Relative error to be considered in approximate search.
+ const double epsilon;
+
//! The last query point BaseCase() was called with.
size_t lastQueryIndex;
//! The last reference point BaseCase() was called with.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index cc2b957..6edf103 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -20,6 +20,7 @@ NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
arma::Mat<size_t>& neighbors,
arma::mat& distances,
MetricType& metric,
+ const double epsilon,
const bool sameSet) :
referenceSet(referenceSet),
querySet(querySet),
@@ -27,6 +28,7 @@ NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
distances(distances),
metric(metric),
sameSet(sameSet),
+ epsilon(epsilon),
lastQueryIndex(querySet.n_cols),
lastReferenceIndex(referenceSet.n_cols),
baseCases(0),
@@ -112,7 +114,8 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
}
// Compare against the best k'th distance for this query point so far.
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+ double bestDistance = distances(distances.n_rows - 1, queryIndex);
+ bestDistance = SortPolicy::Relax(bestDistance, epsilon);
return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
}
@@ -128,7 +131,8 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
return oldScore;
// Just check the score again against the distances.
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+ double bestDistance = distances(distances.n_rows - 1, queryIndex);
+ bestDistance = SortPolicy::Relax(bestDistance, epsilon);
return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
}
@@ -419,6 +423,8 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
queryNode.Stat().SecondBound() = bestDistance;
queryNode.Stat().AuxBound() = auxDistance;
+ worstDistance = SortPolicy::Relax(worstDistance, epsilon);
+
if (SortPolicy::IsBetter(worstDistance, bestDistance))
return worstDistance;
else
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index d87549e..db3331a 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -178,6 +178,16 @@ class NaiveVisitor : public boost::static_visitor<bool&>
};
/**
+ * EpsilonVisitor gives access to the Epsilon() value of the NSType pointer
+ * it is applied to, returning it by reference so it can also be modified.
+ */
+class EpsilonVisitor : public boost::static_visitor<double&>
+{
+ public:
+  //! Return a reference to the relative error of the given NSType.
+  template<typename NSType>
+  double& operator()(NSType* ns) const;
+};
+
+/**
* ReferenceSetVisitor exposes the referenceSet of the given NSType.
*/
class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
@@ -266,6 +276,10 @@ class NSModel
bool Naive() const;
bool& Naive();
+ //! Expose Epsilon.
+ double Epsilon() const;
+ double& Epsilon();
+
//! Expose leafSize.
size_t LeafSize() const { return leafSize; }
size_t& LeafSize() { return leafSize; }
@@ -282,7 +296,8 @@ class NSModel
void BuildModel(arma::mat&& referenceSet,
const size_t leafSize,
const bool naive,
- const bool singleMode);
+ const bool singleMode,
+ const double epsilon = 0);
//! Perform neighbor search. The query set will be reordered.
void Search(arma::mat&& querySet,
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 5ed9772..bbca3d2 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -185,6 +185,15 @@ bool& NaiveVisitor::operator()(NSType* ns) const
throw std::runtime_error("no neighbor search model initialized");
}
+//! Return a reference to the relative error held by the given NSType.  A null
+//! pointer means no model has been built yet, which is reported by throwing
+//! std::runtime_error.
+template<typename NSType>
+double& EpsilonVisitor::operator()(NSType* ns) const
+{
+  if (!ns)
+    throw std::runtime_error("no neighbor search model initialized");
+
+  return ns->Epsilon();
+}
+
//! Expose the referenceSet of the given NSType.
template<typename NSType>
const arma::mat& ReferenceSetVisitor::operator()(NSType* ns) const
@@ -293,12 +302,25 @@ bool& NSModel<SortPolicy>::Naive()
return boost::apply_visitor(NaiveVisitor(), nSearch);
}
+//! Get the relative error of the wrapped neighbor search object.
+template<typename SortPolicy>
+double NSModel<SortPolicy>::Epsilon() const
+{
+  const double eps = boost::apply_visitor(EpsilonVisitor(), nSearch);
+  return eps;
+}
+
+//! Get a modifiable reference to the relative error of the wrapped neighbor
+//! search object.
+template<typename SortPolicy>
+double& NSModel<SortPolicy>::Epsilon()
+{
+  double& eps = boost::apply_visitor(EpsilonVisitor(), nSearch);
+  return eps;
+}
+
//! Build the reference tree.
template<typename SortPolicy>
void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
const size_t leafSize,
const bool naive,
- const bool singleMode)
+ const bool singleMode,
+ const double epsilon)
{
// Initialize random basis if necessary.
if (randomBasis)
@@ -348,23 +370,26 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
switch (treeType)
{
case KD_TREE:
- nSearch = new NSType<SortPolicy, tree::KDTree>(naive, singleMode);
+ nSearch = new NSType<SortPolicy, tree::KDTree>(naive, singleMode,
+ epsilon);
break;
case COVER_TREE:
nSearch = new NSType<SortPolicy, tree::StandardCoverTree>(naive,
- singleMode);
+ singleMode, epsilon);
break;
case R_TREE:
- nSearch = new NSType<SortPolicy, tree::RTree>(naive, singleMode);
+ nSearch = new NSType<SortPolicy, tree::RTree>(naive, singleMode, epsilon);
break;
case R_STAR_TREE:
- nSearch = new NSType<SortPolicy, tree::RStarTree>(naive, singleMode);
+ nSearch = new NSType<SortPolicy, tree::RStarTree>(naive, singleMode,
+ epsilon);
break;
case BALL_TREE:
- nSearch = new NSType<SortPolicy, tree::BallTree>(naive, singleMode);
+ nSearch = new NSType<SortPolicy, tree::BallTree>(naive, singleMode,
+ epsilon);
break;
case X_TREE:
- nSearch = new NSType<SortPolicy, tree::XTree>(naive, singleMode);
+ nSearch = new NSType<SortPolicy, tree::XTree>(naive, singleMode, epsilon);
break;
}
@@ -389,7 +414,11 @@ void NSModel<SortPolicy>::Search(arma::mat&& querySet,
if (randomBasis)
querySet = q * querySet;
- Log::Info << "Searching for " << k << " nearest neighbors with ";
+ Log::Info << "Searching for " << k;
+ if (Epsilon() != 0)
+ Log::Info << " approximate nearest neighbors (e=" << Epsilon() << ") with ";
+ else
+ Log::Info << " nearest neighbors with ";
if (!Naive() && !SingleMode())
Log::Info << "dual-tree " << TreeName() << " search..." << std::endl;
else if (!Naive())
diff --git a/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp b/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
index 87a7262..a69c167 100644
--- a/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
+++ b/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
@@ -145,6 +145,23 @@ class FurthestNeighborSort
*/
static inline double CombineWorst(const double a, const double b)
{ return std::max(a - b, 0.0); }
+
+  /**
+   * Relax a distance for approximate furthest neighbor search by inflating it
+   * by a factor of 1 / (1 - epsilon).
+   *
+   * Special cases: a value of 0 is returned as 0; a value of DBL_MAX, or any
+   * epsilon >= 1 (where the factor is undefined or negative), yields DBL_MAX.
+   * Because the factor exceeds 1, a large finite value could overflow to
+   * infinity, so the result is clamped to DBL_MAX.
+   *
+   * @param value Value to relax.
+   * @param epsilon Relative error (non-negative).
+   *
+   * @return double Value relaxed.
+   */
+  static inline double Relax(const double value, const double epsilon)
+  {
+    if (value == 0)
+      return 0;
+    if (value == DBL_MAX || epsilon >= 1)
+      return DBL_MAX;
+
+    const double relaxed = (1 / (1 - epsilon)) * value;
+    // Clamp so that overflow returns DBL_MAX rather than +inf.
+    return (relaxed > DBL_MAX) ? DBL_MAX : relaxed;
+  }
};
} // namespace neighbor
diff --git a/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp b/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
index f57635a..42a08b0 100644
--- a/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
+++ b/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
@@ -150,6 +150,21 @@ class NearestNeighborSort
return DBL_MAX;
return a + b;
}
+
+  /**
+   * Relax a distance for approximate nearest neighbor search by shrinking it
+   * by a factor of 1 / (1 + epsilon).  DBL_MAX is passed through unchanged.
+   *
+   * @param value Value to relax.
+   * @param epsilon Relative error (non-negative).
+   *
+   * @return double Value relaxed.
+   */
+  static inline double Relax(const double value, const double epsilon)
+  {
+    return (value == DBL_MAX) ? DBL_MAX : (1 / (1 + epsilon)) * value;
+  }
};
} // namespace neighbor
More information about the mlpack-git mailing list