[mlpack-svn] r13975 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Dec 4 18:22:29 EST 2012


Author: rcurtin
Date: 2012-12-04 18:22:28 -0500 (Tue, 04 Dec 2012)
New Revision: 13975

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Log:
Add a Prescore() function, which can be used to prune based on parent
distances.


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2012-12-04 23:22:13 UTC (rev 13974)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2012-12-04 23:22:28 UTC (rev 13975)
@@ -24,6 +24,26 @@
   double BaseCase(const size_t queryIndex, const size_t referenceIndex);
 
   /**
+   * Get the score for recursion order of a child of the reference node,
+   * computed before the base case for that child.  This is useful for cover
+   * trees or other trees that can cache some statistic that could be used to
+   * prune a child before its base case is computed.  As a side effect, the
+   * cached bound of queryNode is recalculated.
+   *
+   * @param queryNode Query node.
+   * @param referenceNode Reference node.
+   * @param referenceChildNode Child of the reference node to be scored.
+   * @param baseCaseResult Base case result, passed to
+   *     SortPolicy::BestNodeToNodeDistance().
+   * @return The best-case node-to-node distance given by the sort policy, or
+   *     DBL_MAX if referenceChildNode should be pruned.
+   */
+  double Prescore(TreeType& queryNode,
+                  TreeType& referenceNode,
+                  TreeType& referenceChildNode,
+                  const double baseCaseResult) const;
+
+  /**
    * Get the score for recursion order.  A low score indicates priority for
    * recursion, while DBL_MAX indicates that the node should not be recursed
    * into at all (it should be pruned).

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2012-12-04 23:22:13 UTC (rev 13974)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2012-12-04 23:22:28 UTC (rev 13975)
@@ -53,6 +53,60 @@
 }

 template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Prescore(
+    TreeType& queryNode,
+    TreeType& referenceNode,
+    TreeType& referenceChildNode,
+    const double baseCaseResult) const
+{
+  // Score referenceChildNode against queryNode (passing baseCaseResult on to
+  // the sort policy) and refresh queryNode's cached bound.  Returns DBL_MAX
+  // when referenceChildNode can be pruned.
+  const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+      &referenceNode, &referenceChildNode, baseCaseResult);
+
+  // Calculate the bound on the fly.  It is the minimum of two bounds:
+  //  - bestPointBound: the smallest k-th candidate distance of a point held
+  //    in this node, plus the furthest descendant distance of this node;
+  //  - worstBound: the largest of the k-th candidate distances of the points
+  //    held in this node and of the bounds stored in the children.  A child's
+  //    bound only holds for that child's descendants, so the worst (largest)
+  //    one is needed to cover every descendant; using the smallest child
+  //    bound could prune reference nodes that still hold neighbors for
+  //    points under the other children.
+  // NOTE(review): min/max/+ assume smaller distances are better (nearest
+  // neighbor search); confirm before using this with other SortPolicy types.
+  double bestPointBound = DBL_MAX;
+  double worstBound = 0.0;
+  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
+  // Bounds from the points contained in this node.
+  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+  {
+    // Current k-th best candidate distance for this point.
+    const double pointDistance = distances(distances.n_rows - 1,
+        queryNode.Point(i));
+    bestPointBound = std::min(bestPointBound,
+        pointDistance + maxDescendantDistance);
+    worstBound = std::max(worstBound, pointDistance);
+  }
+
+  // Bounds stored in the children's statistics.
+  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+    worstBound = std::max(worstBound, queryNode.Child(i).Stat().Bound());
+
+  // A node with no points and no children gives no information to bound on.
+  if (queryNode.NumPoints() == 0 && queryNode.NumChildren() == 0)
+    worstBound = DBL_MAX;
+
+  // Update our bound.
+  queryNode.Stat().Bound() = std::min(bestPointBound, worstBound);
+  const double bestDistance = queryNode.Stat().Bound();
+
+  return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
 inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
     const size_t queryIndex,
     TreeType& referenceNode) const




More information about the mlpack-svn mailing list