[mlpack-git] master: Add NeighborSearchRules specialization for Spill Trees. (a599bf8)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Aug 18 13:39:14 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit a599bf8b24cd11e6d324a70eb5f90805bf22e8d9
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Wed Jul 27 00:02:18 2016 -0300
Add NeighborSearchRules specialization for Spill Trees.
>---------------------------------------------------------------
a599bf8b24cd11e6d324a70eb5f90805bf22e8d9
src/mlpack/methods/neighbor_search/CMakeLists.txt | 2 +
.../neighbor_search/neighbor_search_rules.hpp | 30 +--
.../neighbor_search/neighbor_search_rules_impl.hpp | 31 +--
...h_rules.hpp => neighbor_search_rules_spill.hpp} | 58 ++---
...pl.hpp => neighbor_search_rules_spill_impl.hpp} | 251 ++++++++++-----------
5 files changed, 141 insertions(+), 231 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/CMakeLists.txt b/src/mlpack/methods/neighbor_search/CMakeLists.txt
index 1c51ce4..95fe37b 100644
--- a/src/mlpack/methods/neighbor_search/CMakeLists.txt
+++ b/src/mlpack/methods/neighbor_search/CMakeLists.txt
@@ -5,6 +5,8 @@ set(SOURCES
neighbor_search_impl.hpp
neighbor_search_rules.hpp
neighbor_search_rules_impl.hpp
+ neighbor_search_rules_spill.hpp
+ neighbor_search_rules_spill_impl.hpp
neighbor_search_stat.hpp
ns_model.hpp
ns_model_impl.hpp
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index 82f5f57..8b73b2d 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -9,7 +9,6 @@
#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
#include <mlpack/core/tree/traversal_info.hpp>
-#include <mlpack/core/tree/spill_tree.hpp>
namespace mlpack {
namespace neighbor {
@@ -88,35 +87,11 @@ class NeighborSearchRules
* @param referenceNode Candidate node to be recursed into.
* @param oldScore Old score produced by Score() (or Rescore()).
*/
- template<typename Tree>
double Rescore(const size_t queryIndex,
- Tree& referenceNode,
+ TreeType& referenceNode,
const double oldScore) const;
/**
- * Rescore function specialized for Spill Trees. This function is used to
- * update the score value when doing backtracking. For spill trees, it
- * implements a Hybrid sp-tree search. If the parent node is a overlapping
- * node and we have visited enough points, it decides to prune this node.
- * If the parent node is a non-overlapping node, proper score is returned,
- * so the search can continue with backtracking.
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param oldScore Old score produced by Score() (or Rescore()).
- */
- template<typename StatisticType,
- typename MatType,
- template<typename BoundMetricType, typename...> class BoundType,
- template<typename SplitBoundType, typename SplitMatType>
- class SplitType>
- double Rescore(
- const size_t queryIndex,
- tree::SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
- referenceNode,
- const double oldScore) const;
-
- /**
* Get the score for recursion order. A low score indicates priority for
* recursionm while DBL_MAX indicates that the node should not be recursed
* into at all (it should be pruned).
@@ -235,4 +210,7 @@ class NeighborSearchRules
// Include implementation.
#include "neighbor_search_rules_impl.hpp"
+// Include specialization for Spill Trees.
+#include "neighbor_search_rules_spill.hpp"
+
#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index 4adcc29..e40d09e 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -144,10 +144,9 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
}
template<typename SortPolicy, typename MetricType, typename TreeType>
-template<typename Tree>
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
const size_t queryIndex,
- Tree& /* referenceNode */,
+ TreeType& /* referenceNode */,
const double oldScore) const
{
// If we are already pruning, still prune.
@@ -162,34 +161,6 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
}
template<typename SortPolicy, typename MetricType, typename TreeType>
-template<typename StatisticType,
- typename MatType,
- template<typename BoundMetricType, typename...> class BoundType,
- template<typename SplitBoundType, typename SplitMatType>
- class SplitType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
- const size_t queryIndex,
- tree::SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
- referenceNode,
- double oldScore) const
-{
- // If we are already pruning, still prune.
- if (oldScore == DBL_MAX)
- return oldScore;
-
- if (referenceNode.Parent() && referenceNode.Parent()->Overlap())
- // Defeatist search (If we have enough points, let's prune).
- if (neighbors(neighbors.n_rows - 1, queryIndex) != (size_t() - 1))
- return DBL_MAX;
-
- // Just check the score again against the distances.
- double bestDistance = distances(distances.n_rows - 1, queryIndex);
- bestDistance = SortPolicy::Relax(bestDistance, epsilon);
-
- return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
TreeType& queryNode,
TreeType& referenceNode)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_spill.hpp
similarity index 79%
copy from src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
copy to src/mlpack/methods/neighbor_search/neighbor_search_rules_spill.hpp
index 82f5f57..6f84adf 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_spill.hpp
@@ -1,12 +1,13 @@
/**
- * @file neighbor_search_rules.hpp
+ * @file neighbor_search_rules_spill.hpp
* @author Ryan Curtin
+ * @author Marcos Pividori
*
* Defines the pruning rules and base case rules necessary to perform a
- * tree-based search (with an arbitrary tree) for the NeighborSearch class.
+ * tree-based search with Spill Trees for the NeighborSearch class.
*/
-#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
-#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
+#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_SPILL_HPP
+#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_SPILL_HPP
#include <mlpack/core/tree/traversal_info.hpp>
#include <mlpack/core/tree/spill_tree.hpp>
@@ -15,19 +16,24 @@ namespace mlpack {
namespace neighbor {
/**
- * The NeighborSearchRules class is a template helper class used by
- * NeighborSearch class when performing distance-based neighbor searches. For
- * each point in the query dataset, it keeps track of the k neighbors in the
- * reference dataset which have the 'best' distance according to a given sorting
- * policy.
+ * NeighborSearchRules specialization for Spill Trees.
+ * The main difference from the general implementation is that the Score()
+ * methods consider the special case of an overlapping node.
*
* @tparam SortPolicy The sort policy for distances.
* @tparam MetricType The metric to use for computation.
* @tparam TreeType The tree type to use; must adhere to the TreeType API.
*/
-template<typename SortPolicy, typename MetricType, typename TreeType>
-class NeighborSearchRules
+template<typename StatisticType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType,
+ typename SortPolicy,
+ typename MetricType>
+class NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
+ StatisticType, MatType, SplitType>>
{
+ typedef tree::SpillTree<MetricType, StatisticType, MatType, SplitType>
+ TreeType;
public:
/**
* Construct the NeighborSearchRules object. This is usually done from within
@@ -88,35 +94,11 @@ class NeighborSearchRules
* @param referenceNode Candidate node to be recursed into.
* @param oldScore Old score produced by Score() (or Rescore()).
*/
- template<typename Tree>
double Rescore(const size_t queryIndex,
- Tree& referenceNode,
+ TreeType& referenceNode,
const double oldScore) const;
/**
- * Rescore function specialized for Spill Trees. This function is used to
- * update the score value when doing backtracking. For spill trees, it
- * implements a Hybrid sp-tree search. If the parent node is a overlapping
- * node and we have visited enough points, it decides to prune this node.
- * If the parent node is a non-overlapping node, proper score is returned,
- * so the search can continue with backtracking.
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param oldScore Old score produced by Score() (or Rescore()).
- */
- template<typename StatisticType,
- typename MatType,
- template<typename BoundMetricType, typename...> class BoundType,
- template<typename SplitBoundType, typename SplitMatType>
- class SplitType>
- double Rescore(
- const size_t queryIndex,
- tree::SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
- referenceNode,
- const double oldScore) const;
-
- /**
* Get the score for recursion order. A low score indicates priority for
* recursionm while DBL_MAX indicates that the node should not be recursed
* into at all (it should be pruned).
@@ -233,6 +215,6 @@ class NeighborSearchRules
} // namespace mlpack
// Include implementation.
-#include "neighbor_search_rules_impl.hpp"
+#include "neighbor_search_rules_spill_impl.hpp"
-#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
+#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_SPILL_HPP
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_spill_impl.hpp
similarity index 69%
copy from src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
copy to src/mlpack/methods/neighbor_search/neighbor_search_rules_spill_impl.hpp
index 4adcc29..2ea54ba 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_spill_impl.hpp
@@ -1,20 +1,26 @@
/**
- * @file neighbor_search_rules_impl.hpp
+ * @file neighbor_search_rules_spill_impl.hpp
* @author Ryan Curtin
+ * @author Marcos Pividori
*
- * Implementation of NeighborSearchRules.
+ * Implementation of NeighborSearchRules for Spill Trees.
*/
-#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
-#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
+#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_SPILL_IMPL_HPP
+#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_SPILL_IMPL_HPP
// In case it hasn't been included yet.
-#include "neighbor_search_rules.hpp"
+#include "neighbor_search_rules_spill.hpp"
namespace mlpack {
namespace neighbor {
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
+template<typename StatisticType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType,
+ typename SortPolicy,
+ typename MetricType>
+NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
+ StatisticType, MatType, SplitType>>::NeighborSearchRules(
const typename TreeType::Mat& referenceSet,
const typename TreeType::Mat& querySet,
const size_t k,
@@ -53,8 +59,13 @@ NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
candidates.push_back(pqueue);
}
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearchRules<SortPolicy, MetricType, TreeType>::GetResults(
+template<typename StatisticType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType,
+ typename SortPolicy,
+ typename MetricType>
+void NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
+ StatisticType, MatType, SplitType>>::GetResults(
arma::Mat<size_t>& neighbors,
arma::mat& distances)
{
@@ -73,69 +84,62 @@ void NeighborSearchRules<SortPolicy, MetricType, TreeType>::GetResults(
}
};
-template<typename SortPolicy, typename MetricType, typename TreeType>
+template<typename StatisticType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType,
+ typename SortPolicy,
+ typename MetricType>
inline force_inline // Absolutely MUST be inline so optimizations can happen.
-double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
-BaseCase(const size_t queryIndex, const size_t referenceIndex)
+double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
+ StatisticType, MatType, SplitType>>::BaseCase(
+ const size_t queryIndex,
+ const size_t referenceIndex)
{
// If the datasets are the same, then this search is only using one dataset
// and we should not return identical points.
if (sameSet && (queryIndex == referenceIndex))
return 0.0;
- // If we have already performed this base case, then do not perform it again.
- if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
- return lastBaseCase;
-
double distance = metric.Evaluate(querySet.col(queryIndex),
referenceSet.col(referenceIndex));
++baseCases;
InsertNeighbor(queryIndex, referenceIndex, distance);
- // Cache this information for the next time BaseCase() is called.
- lastQueryIndex = queryIndex;
- lastReferenceIndex = referenceIndex;
- lastBaseCase = distance;
-
return distance;
}
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+template<typename StatisticType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType,
+ typename SortPolicy,
+ typename MetricType>
+inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
+ MetricType, StatisticType, MatType, SplitType>>::Score(
const size_t queryIndex,
TreeType& referenceNode)
{
++scores; // Count number of Score() calls.
- double distance;
- if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
- {
- // The first point in the tree is the centroid. So we can then calculate
- // the base case between that and the query point.
- double baseCase = -1.0;
- if (tree::TreeTraits<TreeType>::HasSelfChildren)
- {
- // If the parent node is the same, then we have already calculated the
- // base case.
- if ((referenceNode.Parent() != NULL) &&
- (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
- baseCase = referenceNode.Parent()->Stat().LastDistance();
- else
- baseCase = BaseCase(queryIndex, referenceNode.Point(0));
-
- // Save this evaluation.
- referenceNode.Stat().LastDistance() = baseCase;
- }
- distance = SortPolicy::CombineBest(baseCase,
- referenceNode.FurthestDescendantDistance());
- }
- else
+ if (!referenceNode.Parent())
+ return 0;
+
+ if (referenceNode.Parent()->Overlap()) // Defeatist search.
{
- distance = SortPolicy::BestPointToNodeDistance(querySet.col(queryIndex),
- &referenceNode);
+ const double value = referenceNode.Parent()->SplitValue();
+ const size_t dim = referenceNode.Parent()->SplitDimension();
+ const bool left = &referenceNode == referenceNode.Parent()->Left();
+
+ if ((left && querySet(dim, queryIndex) <= value) ||
+ (!left && querySet(dim, queryIndex) > value))
+ return 0;
+ else
+ return DBL_MAX;
}
+ double distance = SortPolicy::BestPointToNodeDistance(
+ querySet.col(queryIndex), &referenceNode);
+
// Compare against the best k'th distance for this query point so far.
double bestDistance = candidates[queryIndex].top().first;
bestDistance = SortPolicy::Relax(bestDistance, epsilon);
@@ -143,12 +147,16 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
}
-template<typename SortPolicy, typename MetricType, typename TreeType>
-template<typename Tree>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
+template<typename StatisticType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType,
+ typename SortPolicy,
+ typename MetricType>
+inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
+ MetricType, StatisticType, MatType, SplitType>>::Rescore(
const size_t queryIndex,
- Tree& /* referenceNode */,
- const double oldScore) const
+ TreeType& /* referenceNode */,
+ double oldScore) const
{
// If we are already pruning, still prune.
if (oldScore == DBL_MAX)
@@ -161,40 +169,33 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
}
-template<typename SortPolicy, typename MetricType, typename TreeType>
template<typename StatisticType,
typename MatType,
- template<typename BoundMetricType, typename...> class BoundType,
- template<typename SplitBoundType, typename SplitMatType>
- class SplitType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
- const size_t queryIndex,
- tree::SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
- referenceNode,
- double oldScore) const
+ template<typename SplitBoundT, typename SplitMatT> class SplitType,
+ typename SortPolicy,
+ typename MetricType>
+inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
+ MetricType, StatisticType, MatType, SplitType>>::Score(
+ TreeType& queryNode,
+ TreeType& referenceNode)
{
- // If we are already pruning, still prune.
- if (oldScore == DBL_MAX)
- return oldScore;
+ ++scores; // Count number of Score() calls
- if (referenceNode.Parent() && referenceNode.Parent()->Overlap())
- // Defeatist search (If we have enough points, let's prune).
- if (neighbors(neighbors.n_rows - 1, queryIndex) != (size_t() - 1))
- return DBL_MAX;
+ if (!referenceNode.Parent())
+ return 0;
- // Just check the score again against the distances.
- double bestDistance = distances(distances.n_rows - 1, queryIndex);
- bestDistance = SortPolicy::Relax(bestDistance, epsilon);
-
- return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
-}
+ if (referenceNode.Parent()->Overlap()) // Defeatist search.
+ {
+ const double value = referenceNode.Parent()->SplitValue();
+ const size_t dim = referenceNode.Parent()->SplitDimension();
+ const bool left = &referenceNode == referenceNode.Parent()->Left();
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
- TreeType& queryNode,
- TreeType& referenceNode)
-{
- ++scores; // Count number of Score() calls.
+ if ((left && queryNode.Bound()[dim].Lo() <= value) ||
+ (!left && queryNode.Bound()[dim].Hi() > value))
+ return 0;
+ else
+ return DBL_MAX;
+ }
// Update our bound.
const double bestDistance = CalculateBound(queryNode);
@@ -209,14 +210,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
const double score = traversalInfo.LastScore();
double adjustedScore;
- // We want to set adjustedScore to be the distance between the centroid of the
- // last query node and last reference node. We will do this by adjusting the
- // last score. In some cases, we can just use the last base case.
- if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
- {
- adjustedScore = traversalInfo.LastBaseCase();
- }
- else if (score == 0.0) // Nothing we can do here.
+ if (score == 0.0) // Nothing we can do here.
{
adjustedScore = 0.0;
}
@@ -289,48 +283,14 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
// Can we prune?
if (!SortPolicy::IsBetter(adjustedScore, bestDistance))
{
- if (!(tree::TreeTraits<TreeType>::FirstPointIsCentroid && score == 0.0))
- {
- // There isn't any need to set the traversal information because no
- // descendant combinations will be visited, and those are the only
- // combinations that would depend on the traversal information.
- return DBL_MAX;
- }
+ // There isn't any need to set the traversal information because no
+ // descendant combinations will be visited, and those are the only
+ // combinations that would depend on the traversal information.
+ return DBL_MAX;
}
- double distance;
- if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
- {
- // The first point in the node is the centroid, so we can calculate the
- // distance between the two points using BaseCase() and then find the
- // bounds. This is potentially loose for non-ball bounds.
- double baseCase = -1.0;
- if (tree::TreeTraits<TreeType>::HasSelfChildren &&
- (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) &&
- (traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0)))
- {
- // We already calculated it.
- baseCase = traversalInfo.LastBaseCase();
- }
- else
- {
- baseCase = BaseCase(queryNode.Point(0), referenceNode.Point(0));
- }
-
- distance = SortPolicy::CombineBest(baseCase,
- queryNode.FurthestDescendantDistance() +
- referenceNode.FurthestDescendantDistance());
-
- lastQueryIndex = queryNode.Point(0);
- lastReferenceIndex = referenceNode.Point(0);
- lastBaseCase = baseCase;
-
- traversalInfo.LastBaseCase() = baseCase;
- }
- else
- {
- distance = SortPolicy::BestNodeToNodeDistance(&queryNode, &referenceNode);
- }
+ double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+ &referenceNode);
if (SortPolicy::IsBetter(distance, bestDistance))
{
@@ -350,8 +310,13 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
}
}
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
+template<typename StatisticType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType,
+ typename SortPolicy,
+ typename MetricType>
+inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
+ MetricType, StatisticType, MatType, SplitType>>::Rescore(
TreeType& queryNode,
TreeType& /* referenceNode */,
const double oldScore) const
@@ -359,6 +324,9 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
if (oldScore == DBL_MAX)
return oldScore;
+ if (oldScore == 0)
+ return oldScore;
+
// Update our bound.
const double bestDistance = CalculateBound(queryNode);
@@ -367,8 +335,13 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
// Calculate the bound for a given query node in its current state and update
// it.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+template<typename StatisticType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType,
+ typename SortPolicy,
+ typename MetricType>
+inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
+ MetricType, StatisticType, MatType, SplitType>>::
CalculateBound(TreeType& queryNode) const
{
// This is an adapted form of the B(N_q) function in the paper
@@ -490,9 +463,13 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
* @param neighbor Index of reference point which is being inserted.
* @param distance Distance from query point to reference point.
*/
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline void NeighborSearchRules<SortPolicy, MetricType, TreeType>::
-InsertNeighbor(
+template<typename StatisticType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType,
+ typename SortPolicy,
+ typename MetricType>
+inline void NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
+ MetricType, StatisticType, MatType, SplitType>>::InsertNeighbor(
const size_t queryIndex,
const size_t neighbor,
const double distance)
@@ -510,4 +487,4 @@ InsertNeighbor(
} // namespace neighbor
} // namespace mlpack
-#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
+#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_SPILL_IMPL_HPP
More information about the mlpack-git
mailing list