[mlpack-git] master: Remove specialization of NeighborSearchRules for SpillTrees. (e4ce9be)

gitdub at mlpack.org gitdub at mlpack.org
Wed Aug 17 02:33:04 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit e4ce9be6795817f7ceebafc26dd2b24c38a356f9
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Wed Aug 17 03:33:04 2016 -0300

    Remove specialization of NeighborSearchRules for SpillTrees.


>---------------------------------------------------------------

e4ce9be6795817f7ceebafc26dd2b24c38a356f9
 .../spill_tree/spill_dual_tree_traverser_impl.hpp  | 349 ++++++++--------
 .../spill_single_tree_traverser_impl.hpp           |  22 +-
 src/mlpack/methods/neighbor_search/CMakeLists.txt  |   2 -
 .../neighbor_search/neighbor_search_rules.hpp      |   6 +-
 .../methods/neighbor_search/spill_search_rules.hpp | 229 -----------
 .../neighbor_search/spill_search_rules_impl.hpp    | 440 ---------------------
 6 files changed, 195 insertions(+), 853 deletions(-)

diff --git a/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
index c6deca2..6853b0d 100644
--- a/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
@@ -102,80 +102,91 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
   }
   else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
   {
-    // We have to recurse down the reference node.  In this case the recursion
-    // order does matter.  Before recursing, though, we have to set the
-    // traversal information correctly.
-    double leftScore = rule.Score(queryNode, *referenceNode.Left());
-    typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
-    rule.TraversalInfo() = traversalInfo;
-    double rightScore = rule.Score(queryNode, *referenceNode.Right());
-    numScores += 2;
-
-    if (leftScore < rightScore)
+    if (Defeatist && referenceNode.Overlap())
     {
-      // Recurse to the left.  Restore the left traversal info.  Store the right
-      // traversal info.
-      traversalInfo = rule.TraversalInfo();
-      rule.TraversalInfo() = leftInfo;
-      Traverse(queryNode, *referenceNode.Left());
+      // If referenceNode is an overlapping node, let's do defeatist search.
+      bool traverseLeft = referenceNode.Left()->HalfSpaceIntersects(queryNode);
+      bool traverseRight = referenceNode.Right()->HalfSpaceIntersects(
+          queryNode);
+      if (traverseLeft && !traverseRight)
+        Traverse(queryNode, *referenceNode.Left());
+      else if (!traverseLeft && traverseRight)
+        Traverse(queryNode, *referenceNode.Right());
+      else
+      {
+        // If we can't decide which child node to traverse, this means that
+        // queryNode is on both sides of the splitting hyperplane. So, as
+        // queryNode is a leaf node, all we can do is single-tree search for
+        // each point in the query node.
+        const size_t queryEnd = queryNode.NumPoints();
+        DefeatistSingleTreeTraverser<RuleType> st(rule);
+        // Loop through each of the points in query node.
+        for (size_t query = 0; query < queryEnd; ++query)
+        {
+          const size_t queryIndex = queryNode.Point(query);
+          // See if we need to investigate this point.
+          const double childScore = rule.Score(queryIndex, referenceNode);
 
-      // Is it still valid to recurse to the right?
-      rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
+          if (childScore == DBL_MAX)
+            continue; // We can't improve this particular point.
 
-      if (rightScore != DBL_MAX)
-      {
-        // Restore the right traversal info.
-        rule.TraversalInfo() = traversalInfo;
-        Traverse(queryNode, *referenceNode.Right());
+          st.Traverse(queryIndex, referenceNode);
+        }
       }
-      else
-        ++numPrunes;
     }
-    else if (rightScore < leftScore)
+    else
     {
-      // Recurse to the right.
-      Traverse(queryNode, *referenceNode.Right());
-
-      // Is it still valid to recurse to the left?
-      leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
+      // We have to recurse down the reference node.  In this case the recursion
+      // order does matter.  Before recursing, though, we have to set the
+      // traversal information correctly.
+      double leftScore = rule.Score(queryNode, *referenceNode.Left());
+      typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
+      rule.TraversalInfo() = traversalInfo;
+      double rightScore = rule.Score(queryNode, *referenceNode.Right());
+      numScores += 2;
 
-      if (leftScore != DBL_MAX)
+      if (leftScore < rightScore)
       {
-        // Restore the left traversal info.
+        // Recurse to the left.  Restore the left traversal info.  Store the right
+        // traversal info.
+        traversalInfo = rule.TraversalInfo();
         rule.TraversalInfo() = leftInfo;
         Traverse(queryNode, *referenceNode.Left());
+
+        // Is it still valid to recurse to the right?
+        rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
+
+        if (rightScore != DBL_MAX)
+        {
+          // Restore the right traversal info.
+          rule.TraversalInfo() = traversalInfo;
+          Traverse(queryNode, *referenceNode.Right());
+        }
+        else
+          ++numPrunes;
       }
-      else
-        ++numPrunes;
-    }
-    else // leftScore is equal to rightScore.
-    {
-      if (leftScore == DBL_MAX)
+      else if (rightScore < leftScore)
       {
-        numPrunes += 2;
+        // Recurse to the right.
+        Traverse(queryNode, *referenceNode.Right());
+
+        // Is it still valid to recurse to the left?
+        leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
+
+        if (leftScore != DBL_MAX)
+        {
+          // Restore the left traversal info.
+          rule.TraversalInfo() = leftInfo;
+          Traverse(queryNode, *referenceNode.Left());
+        }
+        else
+          ++numPrunes;
       }
-      else
+      else // leftScore is equal to rightScore.
       {
-        if (Defeatist && referenceNode.Overlap())
+        if (leftScore == DBL_MAX)
         {
-          // If referenceNode is a overlapping node and we can't decide which
-          // child node to traverse, this means that queryNode is at both sides
-          // of the splitting hyperplane. So, as queryNode is a leafNode, all we
-          // can do is single tree search for each point in the query node.
-          const size_t queryEnd = queryNode.NumPoints();
-          SingleTreeTraverser<RuleType> st(rule);
-          // Loop through each of the points in query node.
-          for (size_t query = 0; query < queryEnd; ++query)
-          {
-            const size_t queryIndex = queryNode.Point(query);
-            // See if we need to investigate this point.
-            const double childScore = rule.Score(queryIndex, referenceNode);
-
-            if (childScore == DBL_MAX)
-              continue; // We can't improve this particular point.
-
-            st.Traverse(queryIndex, referenceNode);
-          }
+          numPrunes += 2;
         }
         else
         {
@@ -202,71 +213,98 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
   }
   else
   {
-    // We have to recurse down both query and reference nodes.  Because the
-    // query descent order does not matter, we will go to the left query child
-    // first.  Before recursing, we have to set the traversal information
-    // correctly.
-    double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
-    typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
-    rule.TraversalInfo() = traversalInfo;
-    double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
-    typename RuleType::TraversalInfoType rightInfo;
-    numScores += 2;
-
-    if (leftScore < rightScore)
+    if (Defeatist && referenceNode.Overlap())
     {
-      // Recurse to the left.  Restore the left traversal info.  Store the right
-      // traversal info.
-      rightInfo = rule.TraversalInfo();
-      rule.TraversalInfo() = leftInfo;
-      Traverse(*queryNode.Left(), *referenceNode.Left());
-
-      // Is it still valid to recurse to the right?
-      rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
-          rightScore);
-
-      if (rightScore != DBL_MAX)
-      {
-        // Restore the right traversal info.
-        rule.TraversalInfo() = rightInfo;
+      // If referenceNode is an overlapping node, let's do defeatist search.
+      bool traverseLeft = referenceNode.Left()->HalfSpaceIntersects(
+          *queryNode.Left());
+      bool traverseRight = referenceNode.Right()->HalfSpaceIntersects(
+          *queryNode.Left());
+      if (traverseLeft && !traverseRight)
+        Traverse(*queryNode.Left(), *referenceNode.Left());
+      else if (!traverseLeft && traverseRight)
         Traverse(*queryNode.Left(), *referenceNode.Right());
+      else
+      {
+        // If we can't decide which child node to traverse, this means that
+        // queryNode.Left() is on both sides of the splitting hyperplane. So,
+        // let's recurse down only the query node.
+        Traverse(*queryNode.Left(), referenceNode);
       }
+
+      traverseLeft = referenceNode.Left()->HalfSpaceIntersects(
+          *queryNode.Right());
+      traverseRight = referenceNode.Right()->HalfSpaceIntersects(
+          *queryNode.Right());
+      if (traverseLeft && !traverseRight)
+        Traverse(*queryNode.Right(), *referenceNode.Left());
+      else if (!traverseLeft && traverseRight)
+        Traverse(*queryNode.Right(), *referenceNode.Right());
       else
-        ++numPrunes;
+      {
+        // If we can't decide which child node to traverse, this means that
+        // queryNode.Right() is on both sides of the splitting hyperplane. So,
+        // let's recurse down only the query node.
+        Traverse(*queryNode.Right(), referenceNode);
+      }
     }
-    else if (rightScore < leftScore)
+    else
     {
-      // Recurse to the right.
-      Traverse(*queryNode.Left(), *referenceNode.Right());
-
-      // Is it still valid to recurse to the left?
-      leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
-          leftScore);
+      // We have to recurse down both query and reference nodes.  Because the
+      // query descent order does not matter, we will go to the left query child
+      // first.  Before recursing, we have to set the traversal information
+      // correctly.
+      double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
+      typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
+      rule.TraversalInfo() = traversalInfo;
+      double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
+      typename RuleType::TraversalInfoType rightInfo;
+      numScores += 2;
 
-      if (leftScore != DBL_MAX)
+      if (leftScore < rightScore)
       {
-        // Restore the left traversal info.
+        // Recurse to the left.  Restore the left traversal info.  Store the right
+        // traversal info.
+        rightInfo = rule.TraversalInfo();
         rule.TraversalInfo() = leftInfo;
         Traverse(*queryNode.Left(), *referenceNode.Left());
+
+        // Is it still valid to recurse to the right?
+        rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
+            rightScore);
+
+        if (rightScore != DBL_MAX)
+        {
+          // Restore the right traversal info.
+          rule.TraversalInfo() = rightInfo;
+          Traverse(*queryNode.Left(), *referenceNode.Right());
+        }
+        else
+          ++numPrunes;
       }
-      else
-        ++numPrunes;
-    }
-    else
-    {
-      if (leftScore == DBL_MAX)
+      else if (rightScore < leftScore)
       {
-        numPrunes += 2;
+        // Recurse to the right.
+        Traverse(*queryNode.Left(), *referenceNode.Right());
+
+        // Is it still valid to recurse to the left?
+        leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
+            leftScore);
+
+        if (leftScore != DBL_MAX)
+        {
+          // Restore the left traversal info.
+          rule.TraversalInfo() = leftInfo;
+          Traverse(*queryNode.Left(), *referenceNode.Left());
+        }
+        else
+          ++numPrunes;
       }
       else
       {
-        if (Defeatist && referenceNode.Overlap())
+        if (leftScore == DBL_MAX)
         {
-          // If referenceNode is a overlapping node and we can't decide which
-          // child node to traverse, this means that queryNode.Left() is at both
-          // sides of the splitting hyperplane. So, let's recurse down only the
-          // query node.
-          Traverse(*queryNode.Left(), referenceNode);
+          numPrunes += 2;
         }
         else
         {
@@ -290,72 +328,61 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
             ++numPrunes;
         }
       }
-    }
 
-    // Restore the main traversal information.
-    rule.TraversalInfo() = traversalInfo;
+      // Restore the main traversal information.
+      rule.TraversalInfo() = traversalInfo;
 
-    // Now recurse down the right query node.
-    leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
-    leftInfo = rule.TraversalInfo();
-    rule.TraversalInfo() = traversalInfo;
-    rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
-    numScores += 2;
+      // Now recurse down the right query node.
+      leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
+      leftInfo = rule.TraversalInfo();
+      rule.TraversalInfo() = traversalInfo;
+      rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
+      numScores += 2;
 
-    if (leftScore < rightScore)
-    {
-      // Recurse to the left.  Restore the left traversal info.  Store the right
-      // traversal info.
-      rightInfo = rule.TraversalInfo();
-      rule.TraversalInfo() = leftInfo;
-      Traverse(*queryNode.Right(), *referenceNode.Left());
+      if (leftScore < rightScore)
+      {
+        // Recurse to the left.  Restore the left traversal info.  Store the right
+        // traversal info.
+        rightInfo = rule.TraversalInfo();
+        rule.TraversalInfo() = leftInfo;
+        Traverse(*queryNode.Right(), *referenceNode.Left());
 
-      // Is it still valid to recurse to the right?
-      rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
-          rightScore);
+        // Is it still valid to recurse to the right?
+        rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
+            rightScore);
 
-      if (rightScore != DBL_MAX)
+        if (rightScore != DBL_MAX)
+        {
+          // Restore the right traversal info.
+          rule.TraversalInfo() = rightInfo;
+          Traverse(*queryNode.Right(), *referenceNode.Right());
+        }
+        else
+          ++numPrunes;
+      }
+      else if (rightScore < leftScore)
       {
-        // Restore the right traversal info.
-        rule.TraversalInfo() = rightInfo;
+        // Recurse to the right.
         Traverse(*queryNode.Right(), *referenceNode.Right());
-      }
-      else
-        ++numPrunes;
-    }
-    else if (rightScore < leftScore)
-    {
-      // Recurse to the right.
-      Traverse(*queryNode.Right(), *referenceNode.Right());
 
-      // Is it still valid to recurse to the left?
-      leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
-          leftScore);
+        // Is it still valid to recurse to the left?
+        leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
+            leftScore);
 
-      if (leftScore != DBL_MAX)
-      {
-        // Restore the left traversal info.
-        rule.TraversalInfo() = leftInfo;
-        Traverse(*queryNode.Right(), *referenceNode.Left());
-      }
-      else
-        ++numPrunes;
-    }
-    else
-    {
-      if (leftScore == DBL_MAX)
-      {
-        numPrunes += 2;
+        if (leftScore != DBL_MAX)
+        {
+          // Restore the left traversal info.
+          rule.TraversalInfo() = leftInfo;
+          Traverse(*queryNode.Right(), *referenceNode.Left());
+        }
+        else
+          ++numPrunes;
       }
       else
       {
-        if (Defeatist && referenceNode.Overlap())
+        if (leftScore == DBL_MAX)
         {
-          // If referenceNode is a overlapping node and we can't decide which
-          // child node to traverse, this means that queryNode.Right() is at
-          // both sides of the splitting hyperplane. So, let's recurse down only
-          // the query node.
-          Traverse(*queryNode.Right(), referenceNode);
+          numPrunes += 2;
         }
         else
         {
diff --git a/src/mlpack/core/tree/spill_tree/spill_single_tree_traverser_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_single_tree_traverser_impl.hpp
index 4ce5a22..0d28969 100644
--- a/src/mlpack/core/tree/spill_tree/spill_single_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_single_tree_traverser_impl.hpp
@@ -53,27 +53,13 @@ SpillSingleTreeTraverser<RuleType, Defeatist>::Traverse(
   {
     if (Defeatist && referenceNode.Overlap())
     {
-      // If referenceNode is a overlapping node we do defeatist search. In this
-      // case, it is enough to calculate the score of only one child node. As we
-      // know that the query point can't be at both sides of the splitting
-      // hyperplane, the possible scores for the references child nodes are:
-      // 0 or DBL_MAX.
-      double leftScore = rule.Score(queryIndex, *referenceNode.Left());
-
-      if (leftScore == 0)
-      {
-        // Recurse to the left.
+      // If referenceNode is an overlapping node, we do defeatist search.
+      if (referenceNode.Left()->HalfSpaceContains(
+          rule.QuerySet().col(queryIndex)))
         Traverse(queryIndex, *referenceNode.Left());
-        // Prune the right node.
-        ++numPrunes;
-      }
       else
-      {
-        // Recurse to the right.
         Traverse(queryIndex, *referenceNode.Right());
-        // Prune the left node.
-        ++numPrunes;
-      }
+      ++numPrunes;
     }
     else
     {
diff --git a/src/mlpack/methods/neighbor_search/CMakeLists.txt b/src/mlpack/methods/neighbor_search/CMakeLists.txt
index 0903508..e66b5a7 100644
--- a/src/mlpack/methods/neighbor_search/CMakeLists.txt
+++ b/src/mlpack/methods/neighbor_search/CMakeLists.txt
@@ -14,8 +14,6 @@ set(SOURCES
   sort_policies/furthest_neighbor_sort_impl.hpp
   spill_search.hpp
   spill_search_impl.hpp
-  spill_search_rules.hpp
-  spill_search_rules_impl.hpp
   typedef.hpp
   unmap.hpp
   unmap.cpp
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index b16a9ea..e7a7ce1 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -134,6 +134,9 @@ class NeighborSearchRules
   //! Modify the traversal info.
   TraversalInfoType& TraversalInfo() { return traversalInfo; }
 
+  //! Access the query set.
+  const typename TreeType::Mat& QuerySet() { return querySet; }
+
  protected:
   //! The reference set.
   const typename TreeType::Mat& referenceSet;
@@ -210,7 +213,4 @@ class NeighborSearchRules
 // Include implementation.
 #include "neighbor_search_rules_impl.hpp"
 
-// Include specialization for Spill Trees.
-#include "spill_search_rules.hpp"
-
 #endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
diff --git a/src/mlpack/methods/neighbor_search/spill_search_rules.hpp b/src/mlpack/methods/neighbor_search/spill_search_rules.hpp
deleted file mode 100644
index a87578c..0000000
--- a/src/mlpack/methods/neighbor_search/spill_search_rules.hpp
+++ /dev/null
@@ -1,229 +0,0 @@
-/**
- * @file spill_search_rules.hpp
- * @author Ryan Curtin
- * @author Marcos Pividori
- *
- * Defines the pruning rules and base case rules necessary to perform a
- * tree-based search with Spill Trees for the NeighborSearch class.
- */
-#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_RULES_HPP
-#define MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_RULES_HPP
-
-#include <mlpack/core/tree/traversal_info.hpp>
-#include <mlpack/core/tree/spill_tree.hpp>
-
-namespace mlpack {
-namespace neighbor {
-
-/**
- * NeighborSearchRules specialization for Spill Trees.
- * The main difference with the general implementation is that Score() methods
- * consider the special case of a overlapping node.
- * Also, CalculateBound() only considers B_1 bound, because we can not use B_2
- * with spill trees.
- *
- * @tparam SortPolicy The sort policy for distances.
- * @tparam MetricType The metric to use for computation.
- * @tparam TreeType The tree type to use; must adhere to the TreeType API.
- */
-template<typename StatisticType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType,
-         typename SortPolicy,
-         typename MetricType>
-class NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
-    StatisticType, MatType, HyperplaneType, SplitType>>
-{
-  typedef tree::SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
-      SplitType> TreeType;
- public:
-  /**
-   * Construct the NeighborSearchRules object.  This is usually done from within
-   * the NeighborSearch class at search time.
-   *
-   * @param referenceSet Set of reference data.
-   * @param querySet Set of query data.
-   * @param k Number of neighbors to search for.
-   * @param metric Instantiated metric.
-   * @param epsilon Relative approximate error.
-   * @param sameSet If true, the query and reference set are taken to be the
-   *      same, and a query point will not return itself in the results.
-   */
-  NeighborSearchRules(const typename TreeType::Mat& referenceSet,
-                      const typename TreeType::Mat& querySet,
-                      const size_t k,
-                      MetricType& metric,
-                      const double epsilon = 0,
-                      const bool sameSet = false);
-
-  /**
-   * Store the list of candidates for each query point in the given matrices.
-   *
-   * @param neighbors Matrix storing lists of neighbors for each query point.
-   * @param distances Matrix storing distances of neighbors for each query
-   *     point.
-   */
-  void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
-
-  /**
-   * Get the distance from the query point to the reference point.
-   * This will update the list of candidates with the new point if appropriate
-   * and will track the number of base cases (number of points evaluated).
-   *
-   * @param queryIndex Index of query point.
-   * @param referenceIndex Index of reference point.
-   */
-  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
-
-  /**
-   * Get the score for recursion order.  It implements a Hybrid sp-tree
-   * search.  If referenceNode's parent is a overlapping node, the score is
-   * calculated based on the splitting hyperplane: if query point is on the same
-   * side, returns 0, else DBL_MAX.
-   * If referenceNode's parent is a non-overlapping node, proper score is
-   * calculated, similar to the general Score() method.
-   *
-   * @param queryIndex Index of query point.
-   * @param referenceNode Candidate node to be recursed into.
-   */
-  double Score(const size_t queryIndex, TreeType& referenceNode);
-
-  /**
-   * Re-evaluate the score for recursion order.  A low score indicates priority
-   * for recursion, while DBL_MAX indicates that the node should not be recursed
-   * into at all (it should be pruned).  This is used when the score has already
-   * been calculated, but another recursion may have modified the bounds for
-   * pruning.  So the old score is checked against the new pruning bound.
-   *
-   * @param queryIndex Index of query point.
-   * @param referenceNode Candidate node to be recursed into.
-   * @param oldScore Old score produced by Score() (or Rescore()).
-   */
-  double Rescore(const size_t queryIndex,
-                 TreeType& referenceNode,
-                 const double oldScore) const;
-
-  /**
-   * Get the score for recursion order.  It implements a Hybrid sp-tree
-   * search.  If referenceNode's parent is a overlapping node, the score is
-   * calculated based on the splitting hyperplane: if queryNode's bound
-   * intersects the referenceNode's half space, returns 0, else DBL_MAX.
-   * If referenceNode's parent is a non-overlapping node, proper score is
-   * calculated, similar to the general Score() method.
-   *
-   * @param queryNode Candidate query node to recurse into.
-   * @param referenceNode Candidate reference node to recurse into.
-   */
-  double Score(TreeType& queryNode, TreeType& referenceNode);
-
-  /**
-   * Re-evaluate the score for recursion order.  A low score indicates priority
-   * for recursion, while DBL_MAX indicates that the node should not be recursed
-   * into at all (it should be pruned).  This is used when the score has already
-   * been calculated, but another recursion may have modified the bounds for
-   * pruning.  So the old score is checked against the new pruning bound.
-   *
-   * @param queryNode Candidate query node to recurse into.
-   * @param referenceNode Candidate reference node to recurse into.
-   * @param oldScore Old score produced by Socre() (or Rescore()).
-   */
-  double Rescore(TreeType& queryNode,
-                 TreeType& referenceNode,
-                 const double oldScore) const;
-
-  //! Get the number of base cases that have been performed.
-  size_t BaseCases() const { return baseCases; }
-  //! Modify the number of base cases that have been performed.
-  size_t& BaseCases() { return baseCases; }
-
-  //! Get the number of scores that have been performed.
-  size_t Scores() const { return scores; }
-  //! Modify the number of scores that have been performed.
-  size_t& Scores() { return scores; }
-
-  //! Convenience typedef.
-  typedef typename tree::TraversalInfo<TreeType> TraversalInfoType;
-
-  //! Get the traversal info.
-  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
-  //! Modify the traversal info.
-  TraversalInfoType& TraversalInfo() { return traversalInfo; }
-
- protected:
-  //! The reference set.
-  const typename TreeType::Mat& referenceSet;
-
-  //! The query set.
-  const typename TreeType::Mat& querySet;
-
-  //! Candidate represents a possible candidate neighbor (distance, index).
-  typedef std::pair<double, size_t> Candidate;
-
-  //! Compare two candidates based on the distance.
-  struct CandidateCmp {
-    bool operator()(const Candidate& c1, const Candidate& c2)
-    {
-      return !SortPolicy::IsBetter(c2.first, c1.first);
-    };
-  };
-
-  //! Use a priority queue to represent the list of candidate neighbors.
-  typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
-      CandidateList;
-
-  //! Set of candidate neighbors for each point.
-  std::vector<CandidateList> candidates;
-
-  //! Number of neighbors to search for.
-  const size_t k;
-
-  //! The instantiated metric.
-  MetricType& metric;
-
-  //! Denotes whether or not the reference and query sets are the same.
-  bool sameSet;
-
-  //! Relative error to be considered in approximate search.
-  const double epsilon;
-
-  //! The last query point BaseCase() was called with.
-  size_t lastQueryIndex;
-  //! The last reference point BaseCase() was called with.
-  size_t lastReferenceIndex;
-  //! The last base case result.
-  double lastBaseCase;
-
-  //! The number of base cases that have been performed.
-  size_t baseCases;
-  //! The number of scores that have been performed.
-  size_t scores;
-
-  //! Traversal info for the parent combination; this is updated by the
-  //! traversal before each call to Score().
-  TraversalInfoType traversalInfo;
-
-  /**
-   * Recalculate the bound for a given query node.
-   */
-  double CalculateBound(TreeType& queryNode) const;
-
-  /**
-   * Helper function to insert a point into the list of candidate points.
-   *
-   * @param queryIndex Index of point whose neighbors we are inserting into.
-   * @param neighbor Index of reference point which is being inserted.
-   * @param distance Distance from query point to reference point.
-   */
-  void InsertNeighbor(const size_t queryIndex,
-                      const size_t neighbor,
-                      const double distance);
-};
-
-} // namespace neighbor
-} // namespace mlpack
-
-// Include implementation.
-#include "spill_search_rules_impl.hpp"
-
-#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_RULES_HPP
diff --git a/src/mlpack/methods/neighbor_search/spill_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/spill_search_rules_impl.hpp
deleted file mode 100644
index 50a61ac..0000000
--- a/src/mlpack/methods/neighbor_search/spill_search_rules_impl.hpp
+++ /dev/null
@@ -1,440 +0,0 @@
-/**
- * @file spill_search_rules_impl.hpp
- * @author Ryan Curtin
- * @author Marcos Pividori
- *
- * Implementation of NeighborSearchRules for Spill Trees.
- */
-#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_RULES_IMPL_HPP
-#define MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_RULES_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "spill_search_rules.hpp"
-
-namespace mlpack {
-namespace neighbor {
-
-template<typename StatisticType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType,
-         typename SortPolicy,
-         typename MetricType>
-NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
-    StatisticType, MatType, HyperplaneType, SplitType>>::NeighborSearchRules(
-    const typename TreeType::Mat& referenceSet,
-    const typename TreeType::Mat& querySet,
-    const size_t k,
-    MetricType& metric,
-    const double epsilon,
-    const bool sameSet) :
-    referenceSet(referenceSet),
-    querySet(querySet),
-    k(k),
-    metric(metric),
-    sameSet(sameSet),
-    epsilon(epsilon),
-    lastQueryIndex(querySet.n_cols),
-    lastReferenceIndex(referenceSet.n_cols),
-    baseCases(0),
-    scores(0)
-{
-  // We must set the traversal info last query and reference node pointers to
-  // something that is both invalid (i.e. not a tree node) and not NULL.  We'll
-  // use the this pointer.
-  traversalInfo.LastQueryNode() = (TreeType*) this;
-  traversalInfo.LastReferenceNode() = (TreeType*) this;
-
-  // Let's build the list of candidate neighbors for each query point.
-  // It will be initialized with k candidates: (WorstDistance, size_t() - 1)
-  // The list of candidates will be updated when visiting new points with the
-  // BaseCase() method.
-  const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
-      size_t() - 1);
-
-  std::vector<Candidate> vect(k, def);
-  CandidateList pqueue(CandidateCmp(), std::move(vect));
-
-  candidates.reserve(querySet.n_cols);
-  for (size_t i = 0; i < querySet.n_cols; i++)
-    candidates.push_back(pqueue);
-}
-
-template<typename StatisticType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType,
-         typename SortPolicy,
-         typename MetricType>
-void NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
-    StatisticType, MatType, HyperplaneType, SplitType>>::GetResults(
-    arma::Mat<size_t>& neighbors,
-    arma::mat& distances)
-{
-  neighbors.set_size(k, querySet.n_cols);
-  distances.set_size(k, querySet.n_cols);
-
-  for (size_t i = 0; i < querySet.n_cols; i++)
-  {
-    CandidateList& pqueue = candidates[i];
-    for (size_t j = 1; j <= k; j++)
-    {
-      neighbors(k - j, i) = pqueue.top().second;
-      distances(k - j, i) = pqueue.top().first;
-      pqueue.pop();
-    }
-  }
-};
-
-template<typename StatisticType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType,
-         typename SortPolicy,
-         typename MetricType>
-inline force_inline // Absolutely MUST be inline so optimizations can happen.
-double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
-    StatisticType, MatType, HyperplaneType, SplitType>>::BaseCase(
-    const size_t queryIndex,
-    const size_t referenceIndex)
-{
-  // If the datasets are the same, then this search is only using one dataset
-  // and we should not return identical points.
-  if (sameSet && (queryIndex == referenceIndex))
-    return 0.0;
-
-  double distance = metric.Evaluate(querySet.col(queryIndex),
-                                    referenceSet.col(referenceIndex));
-  ++baseCases;
-
-  InsertNeighbor(queryIndex, referenceIndex, distance);
-
-  return distance;
-}
-
-template<typename StatisticType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType,
-         typename SortPolicy,
-         typename MetricType>
-inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
-    MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Score(
-    const size_t queryIndex,
-    TreeType& referenceNode)
-{
-  ++scores; // Count number of Score() calls.
-
-  if (!referenceNode.Parent())
-    return 0;
-
-  if (referenceNode.Parent()->Overlap()) // Defeatist search.
-  {
-    if (referenceNode.HalfSpaceContains(querySet.col(queryIndex)))
-      return 0;
-    else
-      return DBL_MAX;
-  }
-
-  double distance = SortPolicy::BestPointToNodeDistance(
-      querySet.col(queryIndex), &referenceNode);
-
-  // Compare against the best k'th distance for this query point so far.
-  double bestDistance = candidates[queryIndex].top().first;
-  bestDistance = SortPolicy::Relax(bestDistance, epsilon);
-
-  return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
-
-template<typename StatisticType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType,
-         typename SortPolicy,
-         typename MetricType>
-inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
-    MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Rescore(
-    const size_t queryIndex,
-    TreeType& /* referenceNode */,
-    double oldScore) const
-{
-  // If we are already pruning, still prune.
-  if (oldScore == DBL_MAX)
-    return oldScore;
-
-  // Just check the score again against the distances.
-  double bestDistance = candidates[queryIndex].top().first;
-  bestDistance = SortPolicy::Relax(bestDistance, epsilon);
-
-  return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
-}
-
-template<typename StatisticType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType,
-         typename SortPolicy,
-         typename MetricType>
-inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
-    MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Score(
-    TreeType& queryNode,
-    TreeType& referenceNode)
-{
-  ++scores; // Count number of Score() calls
-
-  if (!referenceNode.Parent())
-    return 0;
-
-  if (referenceNode.Parent()->Overlap()) // Defeatist search.
-  {
-    if (referenceNode.HalfSpaceIntersects(queryNode))
-      return 0;
-    else
-      return DBL_MAX;
-  }
-
-  // Update our bound.
-  const double bestDistance = CalculateBound(queryNode);
-
-  // Use the traversal info to see if a parent-child or parent-parent prune is
-  // possible.  This is a looser bound than we could make, but it might be
-  // sufficient.
-  const double queryParentDist = queryNode.ParentDistance();
-  const double queryDescDist = queryNode.FurthestDescendantDistance();
-  const double refParentDist = referenceNode.ParentDistance();
-  const double refDescDist = referenceNode.FurthestDescendantDistance();
-  const double score = traversalInfo.LastScore();
-  double adjustedScore;
-
-  if (score == 0.0) // Nothing we can do here.
-  {
-    adjustedScore = 0.0;
-  }
-  else
-  {
-    // The last score is equal to the distance between the centroids minus the
-    // radii of the query and reference bounds along the axis of the line
-    // between the two centroids.  In the best case, these radii are the
-    // furthest descendant distances, but that is not always true.  It would
-    // take too long to calculate the exact radii, so we are forced to use
-    // MinimumBoundDistance() as a lower-bound approximation.
-    const double lastQueryDescDist =
-        traversalInfo.LastQueryNode()->MinimumBoundDistance();
-    const double lastRefDescDist =
-        traversalInfo.LastReferenceNode()->MinimumBoundDistance();
-    adjustedScore = SortPolicy::CombineWorst(score, lastQueryDescDist);
-    adjustedScore = SortPolicy::CombineWorst(adjustedScore, lastRefDescDist);
-  }
-
-  // Assemble an adjusted score.  For nearest neighbor search, this adjusted
-  // score is a lower bound on MinDistance(queryNode, referenceNode) that is
-  // assembled without actually calculating MinDistance().  For furthest
-  // neighbor search, it is an upper bound on
-  // MaxDistance(queryNode, referenceNode).  If the traversalInfo isn't usable
-  // then the node should not be pruned by this.
-  if (traversalInfo.LastQueryNode() == queryNode.Parent())
-  {
-    const double queryAdjust = queryParentDist + queryDescDist;
-    adjustedScore = SortPolicy::CombineBest(adjustedScore, queryAdjust);
-  }
-  else if (traversalInfo.LastQueryNode() == &queryNode)
-  {
-    adjustedScore = SortPolicy::CombineBest(adjustedScore, queryDescDist);
-  }
-  else
-  {
-    // The last query node wasn't this query node or its parent.  So we force
-    // the adjustedScore to be such that this combination can't be pruned here,
-    // because we don't really know anything about it.
-
-    // It would be possible to modify this section to try and make a prune based
-    // on the query descendant distance and the distance between the query node
-    // and last traversal query node, but this case doesn't actually happen for
-    // kd-trees or cover trees.
-    adjustedScore = SortPolicy::BestDistance();
-  }
-
-  if (traversalInfo.LastReferenceNode() == referenceNode.Parent())
-  {
-    const double refAdjust = refParentDist + refDescDist;
-    adjustedScore = SortPolicy::CombineBest(adjustedScore, refAdjust);
-  }
-  else if (traversalInfo.LastReferenceNode() == &referenceNode)
-  {
-    adjustedScore = SortPolicy::CombineBest(adjustedScore, refDescDist);
-  }
-  else
-  {
-    // The last reference node wasn't this reference node or its parent.  So we
-    // force the adjustedScore to be such that this combination can't be pruned
-    // here, because we don't really know anything about it.
-
-    // It would be possible to modify this section to try and make a prune based
-    // on the reference descendant distance and the distance between the
-    // reference node and last traversal reference node, but this case doesn't
-    // actually happen for kd-trees or cover trees.
-    adjustedScore = SortPolicy::BestDistance();
-  }
-
-  // Can we prune?
-  if (!SortPolicy::IsBetter(adjustedScore, bestDistance))
-  {
-    // There isn't any need to set the traversal information because no
-    // descendant combinations will be visited, and those are the only
-    // combinations that would depend on the traversal information.
-    return DBL_MAX;
-  }
-
-  double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
-      &referenceNode);
-
-  if (SortPolicy::IsBetter(distance, bestDistance))
-  {
-    // Set traversal information.
-    traversalInfo.LastQueryNode() = &queryNode;
-    traversalInfo.LastReferenceNode() = &referenceNode;
-    traversalInfo.LastScore() = distance;
-
-    return distance;
-  }
-  else
-  {
-    // There isn't any need to set the traversal information because no
-    // descendant combinations will be visited, and those are the only
-    // combinations that would depend on the traversal information.
-    return DBL_MAX;
-  }
-}
-
-template<typename StatisticType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType,
-         typename SortPolicy,
-         typename MetricType>
-inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
-    MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Rescore(
-    TreeType& queryNode,
-    TreeType& /* referenceNode */,
-    const double oldScore) const
-{
-  if (oldScore == DBL_MAX)
-    return oldScore;
-
-  if (oldScore == SortPolicy::BestDistance())
-    return oldScore;
-
-  // Update our bound.
-  const double bestDistance = CalculateBound(queryNode);
-
-  return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
-}
-
-// Calculate the bound for a given query node in its current state and update
-// it.
-template<typename StatisticType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType,
-         typename SortPolicy,
-         typename MetricType>
-inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
-    MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::
-    CalculateBound(TreeType& queryNode) const
-{
-  // This is an adapted form of the B(N_q) function in the paper
-  // ``Tree-Independent Dual-Tree Algorithms'' by Curtin et. al.; the goal is to
-  // place a bound on the worst possible distance a point combination could have
-  // to improve any of the current neighbor estimates.  If the best possible
-  // distance between two nodes is greater than this bound, then the node
-  // combination can be pruned (see Score()).
-
-  // There are a couple ways we can assemble a bound.  For simplicity, this is
-  // described for nearest neighbor search (SortPolicy = NearestNeighborSort),
-  // but the code that is written is adapted for whichever SortPolicy.
-
-  // First, we can consider the current worst neighbor candidate distance of any
-  // descendant point.  This is assembled with 'worstDistance' by looping
-  // through the points held by the query node, and then by taking the cached
-  // worst distance from any child nodes (Stat().FirstBound()).  This
-  // corresponds roughly to B_1(N_q) in the paper.
-
-  double worstDistance = SortPolicy::BestDistance();
-
-  // Loop over points held in the node.
-  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
-  {
-    const double distance = candidates[queryNode.Point(i)].top().first;
-    if (SortPolicy::IsBetter(worstDistance, distance))
-      worstDistance = distance;
-  }
-
-  // Loop over children of the node, and use their cached information to
-  // assemble bounds.
-  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
-  {
-    const double firstBound = queryNode.Child(i).Stat().FirstBound();
-
-    if (SortPolicy::IsBetter(worstDistance, firstBound))
-      worstDistance = firstBound;
-  }
-
-  // At this point, worstDistance holds the value of B_1(N_q).
-
-  // Now consider the parent bounds.
-  if (queryNode.Parent() != NULL)
-  {
-    // The parent's worst distance bound implies that the bound for this node
-    // must be at least as good.  Thus, if the parent worst distance bound is
-    // better, then take it.
-    if (SortPolicy::IsBetter(queryNode.Parent()->Stat().FirstBound(),
-        worstDistance))
-      worstDistance = queryNode.Parent()->Stat().FirstBound();
-  }
-
-  // Could the existing bounds be better?
-  if (SortPolicy::IsBetter(queryNode.Stat().FirstBound(), worstDistance))
-    worstDistance = queryNode.Stat().FirstBound();
-
-  // Cache bounds for later.
-  queryNode.Stat().FirstBound() = worstDistance;
-
-  worstDistance = SortPolicy::Relax(worstDistance, epsilon);
-
-  return worstDistance;
-}
-
-/**
- * Helper function to insert a point into the list of candidate points.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
-template<typename StatisticType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType,
-         typename SortPolicy,
-         typename MetricType>
-inline void NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
-    MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::
-    InsertNeighbor(
-    const size_t queryIndex,
-    const size_t neighbor,
-    const double distance)
-{
-  CandidateList& pqueue = candidates[queryIndex];
-  Candidate c = std::make_pair(distance, neighbor);
-
-  if (CandidateCmp()(c, pqueue.top()))
-  {
-    pqueue.pop();
-    pqueue.push(c);
-  }
-}
-
-} // namespace neighbor
-} // namespace mlpack
-
-#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_RULES_IMPL_HPP




More information about the mlpack-git mailing list