[mlpack-svn] r13339 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Aug 6 14:17:56 EDT 2012


Author: rcurtin
Date: 2012-08-06 14:17:55 -0400 (Mon, 06 Aug 2012)
New Revision: 13339

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Log:
Add Score() and Rescore(), which will be used for a "new" type of tree traversal
which should be faster.


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2012-08-05 21:47:28 UTC (rev 13338)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2012-08-06 18:17:55 UTC (rev 13339)
@@ -36,6 +36,56 @@
   // Update bounds.  Needs a better name.
   void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
 
+  /**
+   * Get the score for recursion order.  A low score indicates priority for
+   * recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   * @return The best-case distance from the query point to the reference node
+   *     (SortPolicy::BestPointToNodeDistance()) if that distance is better than
+   *     the query point's current pruning bound; otherwise DBL_MAX.
+   */
+  double Score(const size_t queryIndex, TreeType& referenceNode) const;
+
+  /**
+   * Re-evaluate the score for recursion order.  A low score indicates priority
+   * for recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).  This is used when the score has already
+   * been calculated, but another recursion may have modified the bounds for
+   * pruning.  So the old score is checked against the new pruning bound.
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into (not re-examined;
+   *     only oldScore is checked).
+   * @param oldScore Old score produced by Score() (or Rescore()).
+   * @return oldScore if it is still better than the query point's current
+   *     pruning bound; otherwise DBL_MAX.  An oldScore of DBL_MAX is returned
+   *     unchanged.
+   */
+  double Rescore(const size_t queryIndex,
+                 TreeType& referenceNode,
+                 const double oldScore) const;
+
+  /**
+   * Get the score for recursion order.  A low score indicates priority for
+   * recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).
+   *
+   * @param queryNode Candidate query node to recurse into.
+   * @param referenceNode Candidate reference node to recurse into.
+   * @return The best-case distance between the two nodes
+   *     (SortPolicy::BestNodeToNodeDistance()) if that distance is better than
+   *     the query node's current bound (queryNode.Stat().Bound()); otherwise
+   *     DBL_MAX.
+   */
+  double Score(TreeType& queryNode, TreeType& referenceNode) const;
+
+  /**
+   * Re-evaluate the score for recursion order.  A low score indicates priority
+   * for recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).  This is used when the score has already
+   * been calculated, but another recursion may have modified the bounds for
+   * pruning.  So the old score is checked against the new pruning bound.
+   *
+   * @param queryNode Candidate query node to recurse into.
+   * @param referenceNode Candidate reference node to recurse into (not
+   *     re-examined; only oldScore is checked).
+   * @param oldScore Old score produced by Score() (or Rescore()).
+   * @return oldScore if it is still better than the query node's current bound
+   *     (queryNode.Stat().Bound()); otherwise DBL_MAX.
+   */
+  double Rescore(TreeType& queryNode,
+                 TreeType& referenceNode,
+                 const double oldScore) const;
+
  private:
   //! The reference set.
   const arma::mat& referenceSet;

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2012-08-05 21:47:28 UTC (rev 13338)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2012-08-06 18:17:55 UTC (rev 13339)
@@ -145,6 +145,58 @@
   queryNode.Stat().Bound() = worstDistance;
 }
 
+/**
+ * Score a (query point, reference node) pair for recursion ordering.
+ *
+ * The score is the best-case distance from the query point to the reference
+ * node, as reported by SortPolicy::BestPointToNodeDistance().  If that distance
+ * is not better (per SortPolicy::IsBetter()) than the pruning bound stored in
+ * the last row of `distances` for this query, DBL_MAX is returned to signal
+ * that the node should be pruned.
+ *
+ * NOTE(review): callers treat low scores as high priority; for a SortPolicy
+ * under which larger distances are better, the raw distance would invert that
+ * ordering -- confirm this is intended.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+    const size_t queryIndex,
+    TreeType& referenceNode) const
+{
+  const arma::vec point = querySet.unsafe_col(queryIndex);
+  const double nodeDistance = SortPolicy::BestPointToNodeDistance(point,
+      &referenceNode);
+
+  // Current pruning bound for this query point: the candidate distance kept in
+  // the last row of its column of `distances`.
+  const double bound = distances(distances.n_rows - 1, queryIndex);
+
+  // A node that cannot beat the bound cannot improve the results; prune it.
+  if (!SortPolicy::IsBetter(nodeDistance, bound))
+    return DBL_MAX;
+
+  return nodeDistance;
+}
+
+/**
+ * Re-check a previously computed (query point, reference node) score.
+ *
+ * Other recursions may have tightened the query point's pruning bound (the
+ * last row of `distances` for this query) since oldScore was computed.
+ * oldScore is returned if it is still better than that bound (per
+ * SortPolicy::IsBetter()); otherwise DBL_MAX.  The reference node is not
+ * inspected.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
+    const size_t queryIndex,
+    TreeType& /* referenceNode */,
+    const double oldScore) const
+{
+  // Something already pruned stays pruned; no need to look at the bound.
+  if (oldScore == DBL_MAX)
+    return oldScore;
+
+  const double bound = distances(distances.n_rows - 1, queryIndex);
+
+  if (SortPolicy::IsBetter(oldScore, bound))
+    return oldScore;
+
+  return DBL_MAX;
+}
+
+/**
+ * Score a (query node, reference node) pair for recursion ordering.
+ *
+ * The score is the best-case distance between the two nodes, as reported by
+ * SortPolicy::BestNodeToNodeDistance().  If that distance is not better (per
+ * SortPolicy::IsBetter()) than the bound cached in queryNode.Stat().Bound(),
+ * DBL_MAX is returned to signal that the pair should be pruned.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+    TreeType& queryNode,
+    TreeType& referenceNode) const
+{
+  const double nodeDistance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+      &referenceNode);
+
+  // Pruning bound cached in the query node's statistic.
+  const double bound = queryNode.Stat().Bound();
+
+  if (!SortPolicy::IsBetter(nodeDistance, bound))
+    return DBL_MAX;
+
+  return nodeDistance;
+}
+
+/**
+ * Re-check a previously computed (query node, reference node) score.
+ *
+ * Other recursions may have tightened the bound cached in
+ * queryNode.Stat().Bound() since oldScore was computed.  oldScore is returned
+ * if it is still better than that bound (per SortPolicy::IsBetter());
+ * otherwise DBL_MAX.  The reference node is not inspected.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
+    TreeType& queryNode,
+    TreeType& /* referenceNode */,
+    const double oldScore) const
+{
+  // If we are already pruning, still prune.  This mirrors the single-point
+  // Rescore() overload and avoids consulting the bound for pruned pairs.
+  if (oldScore == DBL_MAX)
+    return oldScore;
+
+  const double bestDistance = queryNode.Stat().Bound();
+
+  return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
+}
+
 /**
  * Helper function to insert a point into the neighbors and distances matrices.
  *




More information about the mlpack-svn mailing list