[mlpack-git] master: Simplify spill dual tree traverser to use GetBestChild. (bfd66ec)

gitdub at mlpack.org gitdub at mlpack.org
Sat Aug 20 14:56:07 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/3274b05fcc545c3b36f783316fea2e22f79c3d03...1c77230c7d3b9c45fb102cd3c632d9c7248e085e

>---------------------------------------------------------------

commit bfd66ec8198dedaa20f0a51cd4b7e5b36bd0f8c5
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Fri Aug 19 14:47:45 2016 -0300

    Simplify spill dual tree traverser to use GetBestChild.


>---------------------------------------------------------------

bfd66ec8198dedaa20f0a51cd4b7e5b36bd0f8c5
 .../spill_tree/spill_dual_tree_traverser_impl.hpp  | 57 ++++++++++------------
 src/mlpack/core/tree/spill_tree/spill_tree.hpp     |  9 ----
 .../core/tree/spill_tree/spill_tree_impl.hpp       | 41 ----------------
 .../neighbor_search/neighbor_search_rules.hpp      | 11 +++--
 .../neighbor_search/neighbor_search_rules_impl.hpp |  8 +++
 .../sort_policies/furthest_neighbor_sort.hpp       | 11 +++++
 .../sort_policies/nearest_neighbor_sort.hpp        | 11 +++++
 7 files changed, 65 insertions(+), 83 deletions(-)

diff --git a/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
index 6853b0d..c331451 100644
--- a/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
@@ -105,13 +105,12 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
     if (Defeatist && referenceNode.Overlap())
     {
       // If referenceNode is a overlapping node let's do defeatist search.
-      bool traverseLeft = referenceNode.Left()->HalfSpaceIntersects(queryNode);
-      bool traverseRight = referenceNode.Right()->HalfSpaceIntersects(
-          queryNode);
-      if (traverseLeft && !traverseRight)
-        Traverse(queryNode, *referenceNode.Left());
-      else if (!traverseLeft && traverseRight)
-        Traverse(queryNode, *referenceNode.Right());
+      SpillTree* bestChild = rule.GetBestChild(queryNode, referenceNode);
+      if (bestChild)
+      {
+        Traverse(queryNode, *bestChild);
+        ++numPrunes;
+      }
       else
       {
         // If we can't decide which child node to traverse, this means that
@@ -147,14 +146,15 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
 
       if (leftScore < rightScore)
       {
-        // Recurse to the left.  Restore the left traversal info.  Store the right
-        // traversal info.
+        // Recurse to the left.  Restore the left traversal info.  Store the
+        // right traversal info.
         traversalInfo = rule.TraversalInfo();
         rule.TraversalInfo() = leftInfo;
         Traverse(queryNode, *referenceNode.Left());
 
         // Is it still valid to recurse to the right?
-        rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
+        rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
+            rightScore);
 
         if (rightScore != DBL_MAX)
         {
@@ -216,14 +216,13 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
     if (Defeatist && referenceNode.Overlap())
     {
       // If referenceNode is a overlapping node let's do defeatist search.
-      bool traverseLeft = referenceNode.Left()->HalfSpaceIntersects(
-          *queryNode.Left());
-      bool traverseRight = referenceNode.Right()->HalfSpaceIntersects(
-          *queryNode.Left());
-      if (traverseLeft && !traverseRight)
-        Traverse(*queryNode.Left(), *referenceNode.Left());
-      else if (!traverseLeft && traverseRight)
-        Traverse(*queryNode.Left(), *referenceNode.Right());
+      SpillTree* bestChild = rule.GetBestChild(*queryNode.Left(),
+          referenceNode);
+      if (bestChild)
+      {
+        Traverse(*queryNode.Left(), *bestChild);
+        ++numPrunes;
+      }
       else
       {
         // If we can't decide which child node to traverse, this means that
@@ -232,14 +231,12 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
         Traverse(*queryNode.Left(), referenceNode);
       }
 
-      traverseLeft = referenceNode.Left()->HalfSpaceIntersects(
-          *queryNode.Right());
-      traverseRight = referenceNode.Right()->HalfSpaceIntersects(
-          *queryNode.Right());
-      if (traverseLeft && !traverseRight)
-        Traverse(*queryNode.Right(), *referenceNode.Left());
-      else if (!traverseLeft && traverseRight)
-        Traverse(*queryNode.Right(), *referenceNode.Right());
+      bestChild = rule.GetBestChild(*queryNode.Right(), referenceNode);
+      if (bestChild)
+      {
+        Traverse(*queryNode.Right(), *bestChild);
+        ++numPrunes;
+      }
       else
       {
         // If we can't decide which child node to traverse, this means that
@@ -263,8 +260,8 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
 
       if (leftScore < rightScore)
       {
-        // Recurse to the left.  Restore the left traversal info.  Store the right
-        // traversal info.
+        // Recurse to the left.  Restore the left traversal info.  Store the
+        // right traversal info.
         rightInfo = rule.TraversalInfo();
         rule.TraversalInfo() = leftInfo;
         Traverse(*queryNode.Left(), *referenceNode.Left());
@@ -341,8 +338,8 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
 
       if (leftScore < rightScore)
       {
-        // Recurse to the left.  Restore the left traversal info.  Store the right
-        // traversal info.
+        // Recurse to the left.  Restore the left traversal info.  Store the
+        // right traversal info.
         rightInfo = rule.TraversalInfo();
         rule.TraversalInfo() = leftInfo;
         Traverse(*queryNode.Right(), *referenceNode.Left());
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree.hpp b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
index 42cd002..452ebfb 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
@@ -359,15 +359,6 @@ class SpillTree
    */
   size_t Point(const size_t index) const;
 
-  //! Determines if the node's half space intersects the given node.
-  bool HalfSpaceIntersects(const SpillTree& other) const;
-
-  //! Determines if the node's half space contains the given point.
-  template<typename VecType>
-  bool HalfSpaceContains(
-      const VecType& point,
-      typename boost::enable_if<IsVector<VecType> >::type* = 0) const;
-
   //! Return the minimum distance to another node.
   ElemType MinDistance(const SpillTree* other) const
   {
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
index 42ba1ba..b60e84e 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
@@ -601,47 +601,6 @@ template<typename MetricType,
          template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitMetricType, typename SplitMatType>
              class SplitType>
-inline bool SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
-    SplitType>::HalfSpaceIntersects(const SpillTree& other) const
-{
-  if (!Parent())
-    return true;
-
-  const bool left = this == Parent()->Left();
-
-  if (left)
-    return !Parent()->Hyperplane().Right(other.Bound());
-  else
-    return !Parent()->Hyperplane().Left(other.Bound());
-}
-
-template<typename MetricType,
-         typename StatisticType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitMetricType, typename SplitMatType>
-             class SplitType>
-template<typename VecType>
-inline bool SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
-    SplitType>::HalfSpaceContains(
-    const VecType& point,
-    typename boost::enable_if<IsVector<VecType> >::type*) const
-{
-  if (!Parent())
-    return true;
-
-  const bool left = this == Parent()->Left();
-  const bool toTheLeft = Parent()->Hyperplane().Left(point);
-
-  return left == toTheLeft;
-}
-
-template<typename MetricType,
-         typename StatisticType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitMetricType, typename SplitMatType>
-             class SplitType>
 void SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
     SplitNode(arma::Col<size_t>& points,
               const size_t maxLeafSize,
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index 55d2163..ec61fd1 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -85,6 +85,14 @@ class NeighborSearchRules
   TreeType& GetBestChild(const size_t queryIndex, TreeType& referenceNode);
 
   /**
+   * Get the child node with the best score.
+   *
+   * @param queryNode Node to be considered.
+   * @param referenceNode Candidate node to be recursed into.
+   */
+  TreeType* GetBestChild(const TreeType& queryNode, TreeType& referenceNode);
+
+  /**
    * Re-evaluate the score for recursion order.  A low score indicates priority
    * for recursion, while DBL_MAX indicates that the node should not be recursed
    * into at all (it should be pruned).  This is used when the score has already
@@ -142,9 +150,6 @@ class NeighborSearchRules
   //! Modify the traversal info.
   TraversalInfoType& TraversalInfo() { return traversalInfo; }
 
-  //! Access the query set.
-  const typename TreeType::Mat& QuerySet() { return querySet; }
-
  protected:
   //! The reference set.
   const typename TreeType::Mat& referenceSet;
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index e06683a..6bf011b 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -154,6 +154,14 @@ GetBestChild(const size_t queryIndex, TreeType& referenceNode)
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
+inline TreeType* NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+GetBestChild(const TreeType& queryNode, TreeType& referenceNode)
+{
+  ++scores;
+  return SortPolicy::GetBestChild(queryNode, referenceNode);
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
 inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
     const size_t queryIndex,
     TreeType& /* referenceNode */,
diff --git a/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp b/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
index 2c55a75..4e855a7 100644
--- a/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
+++ b/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
@@ -106,6 +106,17 @@ class FurthestNeighborSort
   };
 
   /**
+   * Return the best child according to this sort policy. In this case it will
+   * return the one with the maximum distance.
+   */
+  template<typename TreeType>
+  static TreeType* GetBestChild(const TreeType& queryNode,
+                                TreeType& referenceNode)
+  {
+    return referenceNode.GetFurthestChild(queryNode);
+  };
+
+  /**
    * Return what should represent the worst possible distance with this
    * particular sort policy.  In our case, this should be the minimum possible
    * distance, 0.
diff --git a/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp b/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
index 90837ce..7a0ac57 100644
--- a/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
+++ b/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
@@ -110,6 +110,17 @@ class NearestNeighborSort
   };
 
   /**
+   * Return the best child according to this sort policy. In this case it will
+   * return the one with the minimum distance.
+   */
+  template<typename TreeType>
+  static TreeType* GetBestChild(const TreeType& queryNode,
+                                TreeType& referenceNode)
+  {
+    return referenceNode.GetNearestChild(queryNode);
+  };
+
+  /**
    * Return what should represent the worst possible distance with this
    * particular sort policy.  In our case, this should be the maximum possible
    * distance, DBL_MAX.




More information about the mlpack-git mailing list