[mlpack-svn] r13356 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Aug 7 01:27:05 EDT 2012


Author: rcurtin
Date: 2012-08-07 01:27:05 -0400 (Tue, 07 Aug 2012)
New Revision: 13356

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Log:
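Remove a few unused methods (the LeftFirst() variants).  Add a Score() overload
that accepts a pre-calculated base case result, which is useful for trees with
self-children, such as the cover tree.  BaseCase() now returns its distance
result so that it can be passed to this overload.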


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2012-08-07 05:26:32 UTC (rev 13355)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2012-08-07 05:27:05 UTC (rev 13356)
@@ -21,7 +21,7 @@
                       arma::mat& distances,
                       MetricType& metric);
 
-  void BaseCase(const size_t queryIndex, const size_t referenceIndex);
+  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
 
   // For single-tree traversal.
   bool CanPrune(const size_t queryIndex, TreeType& referenceNode);
@@ -29,10 +29,6 @@
   // For dual-tree traversal.
   bool CanPrune(TreeType& queryNode, TreeType& referenceNode);
 
-  // Get the order of points to recurse to.
-  bool LeftFirst(const size_t queryIndex, TreeType& referenceNode);
-  bool LeftFirst(TreeType& staticNode, TreeType& recurseNode);
-
   // Update bounds.  Needs a better name.
   void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
 
@@ -47,6 +43,20 @@
   double Score(const size_t queryIndex, TreeType& referenceNode) const;
 
   /**
+   * Get the score for recursion order, passing the base case result (in the
+   * situation where it may be needed to calculate the recursion order).  A low
+   * score indicates priority for recursion, while DBL_MAX indicates that the
+   * node should not be recursed into at all (it should be pruned).
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+   */
+  double Score(const size_t queryIndex,
+               TreeType& referenceNode,
+               const double baseCaseResult) const;
+
+  /**
    * Re-evaluate the score for recursion order.  A low score indicates priority
    * for recursion, while DBL_MAX indicates that the node should not be recursed
    * into at all (it should be pruned).  This is used when the score has already

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2012-08-07 05:26:32 UTC (rev 13355)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2012-08-07 05:27:05 UTC (rev 13356)
@@ -28,13 +28,14 @@
 { /* Nothing left to do. */ }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
-inline force_inline void NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+inline force_inline // Absolutely MUST be inline so optimizations can happen.
+double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
 BaseCase(const size_t queryIndex, const size_t referenceIndex)
 {
   // If the datasets are the same, then this search is only using one dataset
   // and we should not return identical points.
   if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
-    return;
+    return 0.0;
 
   double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
                                     referenceSet.unsafe_col(referenceIndex));
@@ -47,6 +48,8 @@
   // SortDistance() returns (size_t() - 1) if we shouldn't add it.
   if (insertPosition != (size_t() - 1))
     InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
+
+  return distance;
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
@@ -62,10 +65,7 @@
 
   // If this is better than the best distance we've seen so far, maybe there
   // will be something down this node.
-  if (SortPolicy::IsBetter(distance, bestDistance))
-    return false; // We cannot prune.
-  else
-    return true; // There cannot be anything better in this node.  So prune it.
+  return !(SortPolicy::IsBetter(distance, bestDistance));
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
@@ -77,43 +77,10 @@
       &queryNode, &referenceNode);
   const double bestDistance = queryNode.Stat().Bound();
 
-  if (SortPolicy::IsBetter(distance, bestDistance))
-    return false; // Can't prune.
-  else
-    return true;
+  return !(SortPolicy::IsBetter(distance, bestDistance));
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
-inline bool NeighborSearchRules<SortPolicy, MetricType, TreeType>::LeftFirst(
-    const size_t queryIndex,
-    TreeType& referenceNode)
-{
-  // This ends up with us calculating this distance twice (it will be done again
-  // in CanPrune()), but because single-neighbors recursion is not the most
-  // important in this method, we can let it slide.
-  const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
-  const double leftDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
-      referenceNode.Left());
-  const double rightDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
-      referenceNode.Right());
-
-  return SortPolicy::IsBetter(leftDistance, rightDistance);
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline bool NeighborSearchRules<SortPolicy, MetricType, TreeType>::LeftFirst(
-    TreeType& staticNode,
-    TreeType& recurseNode)
-{
-  const double leftDistance = SortPolicy::BestNodeToNodeDistance(&staticNode,
-      recurseNode.Left());
-  const double rightDistance = SortPolicy::BestNodeToNodeDistance(&staticNode,
-      recurseNode.Right());
-
-  return SortPolicy::IsBetter(leftDistance, rightDistance);
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
 void NeighborSearchRules<
     SortPolicy,
     MetricType,
@@ -158,6 +125,20 @@
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+    const size_t queryIndex,
+    TreeType& referenceNode,
+    const double baseCaseResult) const
+{
+  const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+  const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
+      &referenceNode, baseCaseResult);
+  const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+
+  return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
 inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
     const size_t queryIndex,
     TreeType& /* referenceNode */,




More information about the mlpack-svn mailing list