[mlpack-svn] r13958 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Nov 29 19:12:46 EST 2012


Author: rcurtin
Date: 2012-11-29 19:12:45 -0500 (Thu, 29 Nov 2012)
New Revision: 13958

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Log:
Calculate queryNode.Stat().Bound() on the fly instead of only once in
UpdateAfterRecursion(), which is now removed.  This slows things down, but can
probably be made faster (maybe the recalculation is not necessary in Rescore()?).


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2012-11-30 00:10:12 UTC (rev 13957)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2012-11-30 00:12:45 UTC (rev 13958)
@@ -23,9 +23,6 @@
 
   double BaseCase(const size_t queryIndex, const size_t referenceIndex);
 
-  // Update bounds.  Needs a better name.
-  void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
-
   /**
    * Get the score for recursion order.  A low score indicates priority for
    * recursion, while DBL_MAX indicates that the node should not be recursed

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2012-11-30 00:10:12 UTC (rev 13957)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2012-11-30 00:12:45 UTC (rev 13958)
@@ -53,37 +53,6 @@
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearchRules<
-    SortPolicy,
-    MetricType,
-    TreeType>::
-UpdateAfterRecursion(TreeType& queryNode, TreeType& /* referenceNode */)
-{
-  // Find the worst distance that the children found (including any points), and
-  // update the bound accordingly.
-  double worstDistance = SortPolicy::BestDistance();
-
-  // First look through children nodes.
-  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
-  {
-    if (SortPolicy::IsBetter(worstDistance, queryNode.Child(i).Stat().Bound()))
-      worstDistance = queryNode.Child(i).Stat().Bound();
-  }
-
-  // Now look through children points.
-  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
-  {
-    if (SortPolicy::IsBetter(worstDistance,
-        distances(distances.n_rows - 1, queryNode.Point(i))))
-      worstDistance = distances(distances.n_rows - 1, queryNode.Point(i));
-  }
-
-  // Take the worst distance from all of these, and update our bound to reflect
-  // that.
-  queryNode.Stat().Bound() = worstDistance;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
 inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
     const size_t queryIndex,
     TreeType& referenceNode) const
@@ -133,6 +102,35 @@
 {
   const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
       &referenceNode);
+
+  // Calculate the bound on the fly.  This bound will be the minimum of
+  // pointBound (the bounds given by the points in this node) and childBound
+  // (the bounds given by the children of this node).
+  double pointBound = DBL_MAX;
+  double childBound = DBL_MAX;
+  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
+  // Find the bound of the points contained in this node.
+  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+  {
+    // The bound for this point is the k-th best distance plus the furthest
+    // descendant distance of this node.
+    const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
+        maxDescendantDistance;
+    if (bound < pointBound)
+      pointBound = bound;
+  }
+
+  // Find the bound of the children.
+  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+  {
+    const double bound = queryNode.Child(i).Stat().Bound();
+    if (bound < childBound)
+      childBound = bound;
+  }
+
+  // Update our bound.
+  queryNode.Stat().Bound() = std::min(pointBound, childBound);
   const double bestDistance = queryNode.Stat().Bound();
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -146,6 +144,35 @@
 {
   const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
       &referenceNode, baseCaseResult);
+
+  // Calculate the bound on the fly.  This bound will be the minimum of
+  // pointBound (the bounds given by the points in this node) and childBound
+  // (the bounds given by the children of this node).
+  double pointBound = DBL_MAX;
+  double childBound = DBL_MAX;
+  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
+  // Find the bound of the points contained in this node.
+  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+  {
+    // The bound for this point is the k-th best distance plus the furthest
+    // descendant distance of this node.
+    const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
+        maxDescendantDistance;
+    if (bound < pointBound)
+      pointBound = bound;
+  }
+
+  // Find the bound of the children.
+  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+  {
+    const double bound = queryNode.Child(i).Stat().Bound();
+    if (bound < childBound)
+      childBound = bound;
+  }
+
+  // Update our bound.
+  queryNode.Stat().Bound() = std::min(pointBound, childBound);
   const double bestDistance = queryNode.Stat().Bound();
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -160,6 +187,34 @@
   if (oldScore == DBL_MAX)
     return oldScore;
 
+  // Calculate the bound on the fly.  This bound will be the minimum of
+  // pointBound (the bounds given by the points in this node) and childBound
+  // (the bounds given by the children of this node).
+  double pointBound = DBL_MAX;
+  double childBound = DBL_MAX;
+  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
+  // Find the bound of the points contained in this node.
+  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+  {
+    // The bound for this point is the k-th best distance plus the furthest
+    // descendant distance of this node.
+    const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
+        maxDescendantDistance;
+    if (bound < pointBound)
+      pointBound = bound;
+  }
+
+  // Find the bound of the children.
+  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+  {
+    const double bound = queryNode.Child(i).Stat().Bound();
+    if (bound < childBound)
+      childBound = bound;
+  }
+
+  // Update our bound.
+  queryNode.Stat().Bound() = std::min(pointBound, childBound);
   const double bestDistance = queryNode.Stat().Bound();
 
   return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;




More information about the mlpack-svn mailing list