[mlpack-git] master: Implement defeatist search in the Rescore() method, with a specialization for Spill Trees. (e1c01ca)

gitdub at mlpack.org gitdub at mlpack.org
Thu Aug 18 13:39:46 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit e1c01cacbe532ec3344ea627d3be7d41bfd5e74b
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Mon Jul 18 10:48:59 2016 -0300

    Implement defeatist search in the Rescore() method, with a specialization for Spill Trees.


>---------------------------------------------------------------

e1c01cacbe532ec3344ea627d3be7d41bfd5e74b
 .../neighbor_search/neighbor_search_rules.hpp      | 27 ++++++++++++++++++-
 .../neighbor_search/neighbor_search_rules_impl.hpp | 31 +++++++++++++++++++++-
 2 files changed, 56 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index 0347da5..82f5f57 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -9,6 +9,7 @@
 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
 
 #include <mlpack/core/tree/traversal_info.hpp>
+#include <mlpack/core/tree/spill_tree.hpp>
 
 namespace mlpack {
 namespace neighbor {
@@ -87,11 +88,35 @@ class NeighborSearchRules
    * @param referenceNode Candidate node to be recursed into.
    * @param oldScore Old score produced by Score() (or Rescore()).
    */
+  template<typename Tree>
   double Rescore(const size_t queryIndex,
-                 TreeType& referenceNode,
+                 Tree& referenceNode,
                  const double oldScore) const;
 
   /**
+   * Rescore function specialized for Spill Trees.  This function is used to
+   * update the score value when doing backtracking.  For spill trees, it
+   * implements a Hybrid sp-tree search.  If the parent node is an overlapping
+   * node and we have visited enough points, it decides to prune this node.
+   * If the parent node is a non-overlapping node, proper score is returned,
+   * so the search can continue with backtracking.
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   * @param oldScore Old score produced by Score() (or Rescore()).
+   */
+  template<typename StatisticType,
+           typename MatType,
+           template<typename BoundMetricType, typename...> class BoundType,
+           template<typename SplitBoundType, typename SplitMatType>
+               class SplitType>
+  double Rescore(
+      const size_t queryIndex,
+      tree::SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
+          referenceNode,
+      const double oldScore) const;
+
+  /**
    * Get the score for recursion order.  A low score indicates priority for
    * recursionm while DBL_MAX indicates that the node should not be recursed
    * into at all (it should be pruned).
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index e40d09e..4adcc29 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -144,9 +144,10 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
+template<typename Tree>
 inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
     const size_t queryIndex,
-    TreeType& /* referenceNode */,
+    Tree& /* referenceNode */,
     const double oldScore) const
 {
   // If we are already pruning, still prune.
@@ -161,6 +162,34 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
+template<typename StatisticType,
+         typename MatType,
+         template<typename BoundMetricType, typename...> class BoundType,
+         template<typename SplitBoundType, typename SplitMatType>
+             class SplitType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
+    const size_t queryIndex,
+    tree::SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
+        referenceNode,
+    double oldScore) const
+{
+  // An old score of DBL_MAX means the node was already pruned; keep it so.
+  if (oldScore == DBL_MAX)
+    return oldScore;
+
+  if (referenceNode.Parent() && referenceNode.Parent()->Overlap())
+    // Defeatist: prune if last neighbor != size_t max (presumably "unset").
+    if (neighbors(neighbors.n_rows - 1, queryIndex) != (size_t() - 1))
+      return DBL_MAX;
+
+  // Keep oldScore only if it still beats the relaxed last-row distance.
+  double bestDistance = distances(distances.n_rows - 1, queryIndex);
+  bestDistance = SortPolicy::Relax(bestDistance, epsilon);
+
+  return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
 inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
     TreeType& queryNode,
     TreeType& referenceNode)




More information about the mlpack-git mailing list