[mlpack-git] master: Add NeighborSearchRules specialization for Spill Trees. (a599bf8)

gitdub at mlpack.org gitdub at mlpack.org
Thu Aug 18 13:39:14 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit a599bf8b24cd11e6d324a70eb5f90805bf22e8d9
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Wed Jul 27 00:02:18 2016 -0300

    Add NeighborSearchRules specialization for Spill Trees.


>---------------------------------------------------------------

a599bf8b24cd11e6d324a70eb5f90805bf22e8d9
 src/mlpack/methods/neighbor_search/CMakeLists.txt  |   2 +
 .../neighbor_search/neighbor_search_rules.hpp      |  30 +--
 .../neighbor_search/neighbor_search_rules_impl.hpp |  31 +--
 ...h_rules.hpp => neighbor_search_rules_spill.hpp} |  58 ++---
 ...pl.hpp => neighbor_search_rules_spill_impl.hpp} | 251 ++++++++++-----------
 5 files changed, 141 insertions(+), 231 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/CMakeLists.txt b/src/mlpack/methods/neighbor_search/CMakeLists.txt
index 1c51ce4..95fe37b 100644
--- a/src/mlpack/methods/neighbor_search/CMakeLists.txt
+++ b/src/mlpack/methods/neighbor_search/CMakeLists.txt
@@ -5,6 +5,8 @@ set(SOURCES
   neighbor_search_impl.hpp
   neighbor_search_rules.hpp
   neighbor_search_rules_impl.hpp
+  neighbor_search_rules_spill.hpp
+  neighbor_search_rules_spill_impl.hpp
   neighbor_search_stat.hpp
   ns_model.hpp
   ns_model_impl.hpp
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index 82f5f57..8b73b2d 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -9,7 +9,6 @@
 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
 
 #include <mlpack/core/tree/traversal_info.hpp>
-#include <mlpack/core/tree/spill_tree.hpp>
 
 namespace mlpack {
 namespace neighbor {
@@ -88,35 +87,11 @@ class NeighborSearchRules
    * @param referenceNode Candidate node to be recursed into.
    * @param oldScore Old score produced by Score() (or Rescore()).
    */
-  template<typename Tree>
   double Rescore(const size_t queryIndex,
-                 Tree& referenceNode,
+                 TreeType& referenceNode,
                  const double oldScore) const;
 
   /**
-   * Rescore function specialized for Spill Trees.  This function is used to
-   * update the score value when doing backtracking.  For spill trees, it
-   * implements a Hybrid sp-tree search.  If the parent node is a overlapping
-   * node and we have visited enough points, it decides to prune this node.
-   * If the parent node is a non-overlapping node, proper score is returned,
-   * so the search can continue with backtracking.
-   *
-   * @param queryIndex Index of query point.
-   * @param referenceNode Candidate node to be recursed into.
-   * @param oldScore Old score produced by Score() (or Rescore()).
-   */
-  template<typename StatisticType,
-           typename MatType,
-           template<typename BoundMetricType, typename...> class BoundType,
-           template<typename SplitBoundType, typename SplitMatType>
-               class SplitType>
-  double Rescore(
-      const size_t queryIndex,
-      tree::SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
-          referenceNode,
-      const double oldScore) const;
-
-  /**
    * Get the score for recursion order.  A low score indicates priority for
    * recursionm while DBL_MAX indicates that the node should not be recursed
    * into at all (it should be pruned).
@@ -235,4 +210,7 @@ class NeighborSearchRules
 // Include implementation.
 #include "neighbor_search_rules_impl.hpp"
 
+// Include specialization for Spill Trees.
+#include "neighbor_search_rules_spill.hpp"
+
 #endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index 4adcc29..e40d09e 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -144,10 +144,9 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
-template<typename Tree>
 inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
     const size_t queryIndex,
-    Tree& /* referenceNode */,
+    TreeType& /* referenceNode */,
     const double oldScore) const
 {
   // If we are already pruning, still prune.
@@ -162,34 +161,6 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
-template<typename StatisticType,
-         typename MatType,
-         template<typename BoundMetricType, typename...> class BoundType,
-         template<typename SplitBoundType, typename SplitMatType>
-             class SplitType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
-    const size_t queryIndex,
-    tree::SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
-        referenceNode,
-    double oldScore) const
-{
-  // If we are already pruning, still prune.
-  if (oldScore == DBL_MAX)
-    return oldScore;
-
-  if (referenceNode.Parent() && referenceNode.Parent()->Overlap())
-    // Defeatist search (If we have enough points, let's prune).
-    if (neighbors(neighbors.n_rows - 1, queryIndex) != (size_t() - 1))
-      return DBL_MAX;
-
-  // Just check the score again against the distances.
-  double bestDistance = distances(distances.n_rows - 1, queryIndex);
-  bestDistance = SortPolicy::Relax(bestDistance, epsilon);
-
-  return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
 inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
     TreeType& queryNode,
     TreeType& referenceNode)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_spill.hpp
similarity index 79%
copy from src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
copy to src/mlpack/methods/neighbor_search/neighbor_search_rules_spill.hpp
index 82f5f57..6f84adf 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_spill.hpp
@@ -1,12 +1,13 @@
 /**
- * @file neighbor_search_rules.hpp
+ * @file neighbor_search_rules_spill.hpp
  * @author Ryan Curtin
+ * @author Marcos Pividori
  *
  * Defines the pruning rules and base case rules necessary to perform a
- * tree-based search (with an arbitrary tree) for the NeighborSearch class.
+ * tree-based search with Spill Trees for the NeighborSearch class.
  */
-#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
-#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
+#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_SPILL_HPP
+#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_SPILL_HPP
 
 #include <mlpack/core/tree/traversal_info.hpp>
 #include <mlpack/core/tree/spill_tree.hpp>
@@ -15,19 +16,24 @@ namespace mlpack {
 namespace neighbor {
 
 /**
- * The NeighborSearchRules class is a template helper class used by
- * NeighborSearch class when performing distance-based neighbor searches.  For
- * each point in the query dataset, it keeps track of the k neighbors in the
- * reference dataset which have the 'best' distance according to a given sorting
- * policy.
+ * NeighborSearchRules specialization for Spill Trees.
+ * The main difference from the general implementation is that the Score()
+ * methods handle the special case of an overlapping node (defeatist search).
  *
  * @tparam SortPolicy The sort policy for distances.
  * @tparam MetricType The metric to use for computation.
  * @tparam TreeType The tree type to use; must adhere to the TreeType API.
  */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-class NeighborSearchRules
+template<typename StatisticType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType,
+         typename SortPolicy,
+         typename MetricType>
+class NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
+    StatisticType, MatType, SplitType>>
 {
+  typedef tree::SpillTree<MetricType, StatisticType, MatType, SplitType>
+      TreeType;
  public:
   /**
    * Construct the NeighborSearchRules object.  This is usually done from within
@@ -88,35 +94,11 @@ class NeighborSearchRules
    * @param referenceNode Candidate node to be recursed into.
    * @param oldScore Old score produced by Score() (or Rescore()).
    */
-  template<typename Tree>
   double Rescore(const size_t queryIndex,
-                 Tree& referenceNode,
+                 TreeType& referenceNode,
                  const double oldScore) const;
 
   /**
-   * Rescore function specialized for Spill Trees.  This function is used to
-   * update the score value when doing backtracking.  For spill trees, it
-   * implements a Hybrid sp-tree search.  If the parent node is a overlapping
-   * node and we have visited enough points, it decides to prune this node.
-   * If the parent node is a non-overlapping node, proper score is returned,
-   * so the search can continue with backtracking.
-   *
-   * @param queryIndex Index of query point.
-   * @param referenceNode Candidate node to be recursed into.
-   * @param oldScore Old score produced by Score() (or Rescore()).
-   */
-  template<typename StatisticType,
-           typename MatType,
-           template<typename BoundMetricType, typename...> class BoundType,
-           template<typename SplitBoundType, typename SplitMatType>
-               class SplitType>
-  double Rescore(
-      const size_t queryIndex,
-      tree::SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
-          referenceNode,
-      const double oldScore) const;
-
-  /**
    * Get the score for recursion order.  A low score indicates priority for
    * recursionm while DBL_MAX indicates that the node should not be recursed
    * into at all (it should be pruned).
@@ -233,6 +215,6 @@ class NeighborSearchRules
 } // namespace mlpack
 
 // Include implementation.
-#include "neighbor_search_rules_impl.hpp"
+#include "neighbor_search_rules_spill_impl.hpp"
 
-#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
+#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_SPILL_HPP
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_spill_impl.hpp
similarity index 69%
copy from src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
copy to src/mlpack/methods/neighbor_search/neighbor_search_rules_spill_impl.hpp
index 4adcc29..2ea54ba 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_spill_impl.hpp
@@ -1,20 +1,26 @@
 /**
- * @file neighbor_search_rules_impl.hpp
+ * @file neighbor_search_rules_spill_impl.hpp
  * @author Ryan Curtin
+ * @author Marcos Pividori
  *
- * Implementation of NeighborSearchRules.
+ * Implementation of NeighborSearchRules for Spill Trees.
  */
-#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
-#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
+#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_SPILL_IMPL_HPP
+#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_SPILL_IMPL_HPP
 
 // In case it hasn't been included yet.
-#include "neighbor_search_rules.hpp"
+#include "neighbor_search_rules_spill.hpp"
 
 namespace mlpack {
 namespace neighbor {
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
+template<typename StatisticType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType,
+         typename SortPolicy,
+         typename MetricType>
+NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
+    StatisticType, MatType, SplitType>>::NeighborSearchRules(
     const typename TreeType::Mat& referenceSet,
     const typename TreeType::Mat& querySet,
     const size_t k,
@@ -53,8 +59,13 @@ NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
     candidates.push_back(pqueue);
 }
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearchRules<SortPolicy, MetricType, TreeType>::GetResults(
+template<typename StatisticType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType,
+         typename SortPolicy,
+         typename MetricType>
+void NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
+    StatisticType, MatType, SplitType>>::GetResults(
     arma::Mat<size_t>& neighbors,
     arma::mat& distances)
 {
@@ -73,69 +84,62 @@ void NeighborSearchRules<SortPolicy, MetricType, TreeType>::GetResults(
   }
 };
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
+template<typename StatisticType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType,
+         typename SortPolicy,
+         typename MetricType>
 inline force_inline // Absolutely MUST be inline so optimizations can happen.
-double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
-BaseCase(const size_t queryIndex, const size_t referenceIndex)
+double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
+    StatisticType, MatType, SplitType>>::BaseCase(
+    const size_t queryIndex,
+    const size_t referenceIndex)
 {
   // If the datasets are the same, then this search is only using one dataset
   // and we should not return identical points.
   if (sameSet && (queryIndex == referenceIndex))
     return 0.0;
 
-  // If we have already performed this base case, then do not perform it again.
-  if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
-    return lastBaseCase;
-
   double distance = metric.Evaluate(querySet.col(queryIndex),
                                     referenceSet.col(referenceIndex));
   ++baseCases;
 
   InsertNeighbor(queryIndex, referenceIndex, distance);
 
-  // Cache this information for the next time BaseCase() is called.
-  lastQueryIndex = queryIndex;
-  lastReferenceIndex = referenceIndex;
-  lastBaseCase = distance;
-
   return distance;
 }
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+template<typename StatisticType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType,
+         typename SortPolicy,
+         typename MetricType>
+inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
+    MetricType, StatisticType, MatType, SplitType>>::Score(
     const size_t queryIndex,
     TreeType& referenceNode)
 {
   ++scores; // Count number of Score() calls.
-  double distance;
-  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
-  {
-    // The first point in the tree is the centroid.  So we can then calculate
-    // the base case between that and the query point.
-    double baseCase = -1.0;
-    if (tree::TreeTraits<TreeType>::HasSelfChildren)
-    {
-      // If the parent node is the same, then we have already calculated the
-      // base case.
-      if ((referenceNode.Parent() != NULL) &&
-          (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
-        baseCase = referenceNode.Parent()->Stat().LastDistance();
-      else
-        baseCase = BaseCase(queryIndex, referenceNode.Point(0));
-
-      // Save this evaluation.
-      referenceNode.Stat().LastDistance() = baseCase;
-    }
 
-    distance = SortPolicy::CombineBest(baseCase,
-        referenceNode.FurthestDescendantDistance());
-  }
-  else
+  if (!referenceNode.Parent())
+    return 0;
+
+  if (referenceNode.Parent()->Overlap()) // Defeatist search.
   {
-    distance = SortPolicy::BestPointToNodeDistance(querySet.col(queryIndex),
-        &referenceNode);
+    const double value = referenceNode.Parent()->SplitValue();
+    const size_t dim = referenceNode.Parent()->SplitDimension();
+    const bool left = &referenceNode == referenceNode.Parent()->Left();
+
+    if ((left && querySet(dim, queryIndex) <= value) ||
+        (!left && querySet(dim, queryIndex) > value))
+      return 0;
+    else
+      return DBL_MAX;
   }
 
+  double distance = SortPolicy::BestPointToNodeDistance(
+      querySet.col(queryIndex), &referenceNode);
+
   // Compare against the best k'th distance for this query point so far.
   double bestDistance = candidates[queryIndex].top().first;
   bestDistance = SortPolicy::Relax(bestDistance, epsilon);
@@ -143,12 +147,16 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
 }
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
-template<typename Tree>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
+template<typename StatisticType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType,
+         typename SortPolicy,
+         typename MetricType>
+inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
+    MetricType, StatisticType, MatType, SplitType>>::Rescore(
     const size_t queryIndex,
-    Tree& /* referenceNode */,
-    const double oldScore) const
+    TreeType& /* referenceNode */,
+    double oldScore) const
 {
   // If we are already pruning, still prune.
   if (oldScore == DBL_MAX)
@@ -161,40 +169,33 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
   return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
 }
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
 template<typename StatisticType,
          typename MatType,
-         template<typename BoundMetricType, typename...> class BoundType,
-         template<typename SplitBoundType, typename SplitMatType>
-             class SplitType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
-    const size_t queryIndex,
-    tree::SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
-        referenceNode,
-    double oldScore) const
+         template<typename SplitBoundT, typename SplitMatT> class SplitType,
+         typename SortPolicy,
+         typename MetricType>
+inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
+    MetricType, StatisticType, MatType, SplitType>>::Score(
+    TreeType& queryNode,
+    TreeType& referenceNode)
 {
-  // If we are already pruning, still prune.
-  if (oldScore == DBL_MAX)
-    return oldScore;
+  ++scores; // Count number of Score() calls
 
-  if (referenceNode.Parent() && referenceNode.Parent()->Overlap())
-    // Defeatist search (If we have enough points, let's prune).
-    if (neighbors(neighbors.n_rows - 1, queryIndex) != (size_t() - 1))
-      return DBL_MAX;
+  if (!referenceNode.Parent())
+    return 0;
 
-  // Just check the score again against the distances.
-  double bestDistance = distances(distances.n_rows - 1, queryIndex);
-  bestDistance = SortPolicy::Relax(bestDistance, epsilon);
-
-  return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
-}
+  if (referenceNode.Parent()->Overlap()) // Defeatist search.
+  {
+    const double value = referenceNode.Parent()->SplitValue();
+    const size_t dim = referenceNode.Parent()->SplitDimension();
+    const bool left = &referenceNode == referenceNode.Parent()->Left();
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
-    TreeType& queryNode,
-    TreeType& referenceNode)
-{
-  ++scores; // Count number of Score() calls.
+    if ((left && queryNode.Bound()[dim].Lo() <= value) ||
+        (!left && queryNode.Bound()[dim].Hi() > value))
+      return 0;
+    else
+      return DBL_MAX;
+  }
 
   // Update our bound.
   const double bestDistance = CalculateBound(queryNode);
@@ -209,14 +210,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
   const double score = traversalInfo.LastScore();
   double adjustedScore;
 
-  // We want to set adjustedScore to be the distance between the centroid of the
-  // last query node and last reference node.  We will do this by adjusting the
-  // last score.  In some cases, we can just use the last base case.
-  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
-  {
-    adjustedScore = traversalInfo.LastBaseCase();
-  }
-  else if (score == 0.0) // Nothing we can do here.
+  if (score == 0.0) // Nothing we can do here.
   {
     adjustedScore = 0.0;
   }
@@ -289,48 +283,14 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
   // Can we prune?
   if (!SortPolicy::IsBetter(adjustedScore, bestDistance))
   {
-    if (!(tree::TreeTraits<TreeType>::FirstPointIsCentroid && score == 0.0))
-    {
-      // There isn't any need to set the traversal information because no
-      // descendant combinations will be visited, and those are the only
-      // combinations that would depend on the traversal information.
-      return DBL_MAX;
-    }
+    // There isn't any need to set the traversal information because no
+    // descendant combinations will be visited, and those are the only
+    // combinations that would depend on the traversal information.
+    return DBL_MAX;
   }
 
-  double distance;
-  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
-  {
-    // The first point in the node is the centroid, so we can calculate the
-    // distance between the two points using BaseCase() and then find the
-    // bounds.  This is potentially loose for non-ball bounds.
-    double baseCase = -1.0;
-    if (tree::TreeTraits<TreeType>::HasSelfChildren &&
-       (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) &&
-       (traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0)))
-    {
-      // We already calculated it.
-      baseCase = traversalInfo.LastBaseCase();
-    }
-    else
-    {
-      baseCase = BaseCase(queryNode.Point(0), referenceNode.Point(0));
-    }
-
-    distance = SortPolicy::CombineBest(baseCase,
-        queryNode.FurthestDescendantDistance() +
-        referenceNode.FurthestDescendantDistance());
-
-    lastQueryIndex = queryNode.Point(0);
-    lastReferenceIndex = referenceNode.Point(0);
-    lastBaseCase = baseCase;
-
-    traversalInfo.LastBaseCase() = baseCase;
-  }
-  else
-  {
-    distance = SortPolicy::BestNodeToNodeDistance(&queryNode, &referenceNode);
-  }
+  double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+      &referenceNode);
 
   if (SortPolicy::IsBetter(distance, bestDistance))
   {
@@ -350,8 +310,13 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
   }
 }
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
+template<typename StatisticType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType,
+         typename SortPolicy,
+         typename MetricType>
+inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
+    MetricType, StatisticType, MatType, SplitType>>::Rescore(
     TreeType& queryNode,
     TreeType& /* referenceNode */,
     const double oldScore) const
@@ -359,6 +324,9 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
   if (oldScore == DBL_MAX)
     return oldScore;
 
+  if (oldScore == 0)
+    return oldScore;
+
   // Update our bound.
   const double bestDistance = CalculateBound(queryNode);
 
@@ -367,8 +335,13 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
 
 // Calculate the bound for a given query node in its current state and update
 // it.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+template<typename StatisticType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType,
+         typename SortPolicy,
+         typename MetricType>
+inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
+    MetricType, StatisticType, MatType, SplitType>>::
     CalculateBound(TreeType& queryNode) const
 {
   // This is an adapted form of the B(N_q) function in the paper
@@ -490,9 +463,13 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
  * @param neighbor Index of reference point which is being inserted.
  * @param distance Distance from query point to reference point.
  */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline void NeighborSearchRules<SortPolicy, MetricType, TreeType>::
-InsertNeighbor(
+template<typename StatisticType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType,
+         typename SortPolicy,
+         typename MetricType>
+inline void NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
+    MetricType, StatisticType, MatType, SplitType>>::InsertNeighbor(
     const size_t queryIndex,
     const size_t neighbor,
     const double distance)
@@ -510,4 +487,4 @@ InsertNeighbor(
 } // namespace neighbor
 } // namespace mlpack
 
-#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
+#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_SPILL_IMPL_HPP




More information about the mlpack-git mailing list