[mlpack-svn] r14082 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Jan 5 18:27:05 EST 2013


Author: rcurtin
Date: 2013-01-05 18:27:05 -0500 (Sat, 05 Jan 2013)
New Revision: 14082

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Log:
Refactor calculation of bound.  It may need a further fix.


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2013-01-05 19:14:36 UTC (rev 14081)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2013-01-05 23:27:05 UTC (rev 14082)
@@ -136,6 +136,11 @@
   MetricType& metric;
 
   /**
+   * Recalculate the bound for a given query node.
+   */
+  double CalculateBound(TreeType& queryNode) const;
+
+  /**
    * Insert a point into the neighbors and distances matrices; this is a helper
    * function.
    *

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2013-01-05 19:14:36 UTC (rev 14081)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2013-01-05 23:27:05 UTC (rev 14082)
@@ -62,34 +62,8 @@
   const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
       &referenceNode, &referenceChildNode, baseCaseResult);
 
-  // Calculate the bound on the fly.  This bound will be the minimum of
-  // pointBound (the bounds given by the points in this node) and childBound
-  // (the bounds given by the children of this node).
-  double pointBound = DBL_MAX;
-  double childBound = DBL_MAX;
-  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
-
-  // Find the bound of the points contained in this node.
-  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
-  {
-    // The bound for this point is the k-th best distance plus the maximum
-    // distance to a child of this node.
-    const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
-        maxDescendantDistance;
-    if (bound < pointBound)
-      pointBound = bound;
-  }
-
-  // Find the bound of the children.
-  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
-  {
-    const double bound = queryNode.Child(i).Stat().Bound();
-    if (bound < childBound)
-      childBound = bound;
-  }
-
   // Update our bound.
-  queryNode.Stat().Bound() = std::min(pointBound, childBound);
+  queryNode.Stat().Bound() = CalculateBound(queryNode);
   const double bestDistance = queryNode.Stat().Bound();
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -105,34 +79,8 @@
   const double distance = SortPolicy::BestNodeToNodeDistance(&referenceNode,
       &queryNode, &queryChildNode, baseCaseResult);
 
-  // Calculate the bound on the fly.  This bound will be the minimum of
-  // pointBound (the bounds given by the points in this node) and childBound
-  // (the bounds given by the children of this node).
-  double pointBound = DBL_MAX;
-  double childBound = DBL_MAX;
-  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
-
-  // Find the bound of the points contained in this node.
-  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
-  {
-    // The bound for this point is the k-th best distance plus the maximum
-    // distance to a child of this node.
-    const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
-        maxDescendantDistance;
-    if (bound < pointBound)
-      pointBound = bound;
-  }
-
-  // Find the bound of the children.
-  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
-  {
-    const double bound = queryNode.Child(i).Stat().Bound();
-    if (bound < childBound)
-      childBound = bound;
-  }
-
   // Update our bound.
-  queryNode.Stat().Bound() = std::min(pointBound, childBound);
+  queryNode.Stat().Bound() = CalculateBound(queryNode);
   const double bestDistance = queryNode.Stat().Bound();
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -189,34 +137,8 @@
   const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
       &referenceNode);
 
-  // Calculate the bound on the fly.  This bound will be the minimum of
-  // pointBound (the bounds given by the points in this node) and childBound
-  // (the bounds given by the children of this node).
-  double pointBound = DBL_MAX;
-  double childBound = DBL_MAX;
-  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
-
-  // Find the bound of the points contained in this node.
-  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
-  {
-    // The bound for this point is the k-th best distance plus the maximum
-    // distance to a child of this node.
-    const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
-        maxDescendantDistance;
-    if (bound < pointBound)
-      pointBound = bound;
-  }
-
-  // Find the bound of the children.
-  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
-  {
-    const double bound = queryNode.Child(i).Stat().Bound();
-    if (bound < childBound)
-      childBound = bound;
-  }
-
   // Update our bound.
-  queryNode.Stat().Bound() = std::min(pointBound, childBound);
+  queryNode.Stat().Bound() = CalculateBound(queryNode);
   const double bestDistance = queryNode.Stat().Bound();
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -231,34 +153,8 @@
   const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
       &referenceNode, baseCaseResult);
 
-  // Calculate the bound on the fly.  This bound will be the minimum of
-  // pointBound (the bounds given by the points in this node) and childBound
-  // (the bounds given by the children of this node).
-  double pointBound = DBL_MAX;
-  double childBound = DBL_MAX;
-  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
-
-  // Find the bound of the points contained in this node.
-  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
-  {
-    // The bound for this point is the k-th best distance plus the maximum
-    // distance to a child of this node.
-    const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
-        maxDescendantDistance;
-    if (bound < pointBound)
-      pointBound = bound;
-  }
-
-  // Find the bound of the children.
-  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
-  {
-    const double bound = queryNode.Child(i).Stat().Bound();
-    if (bound < childBound)
-      childBound = bound;
-  }
-
   // Update our bound.
-  queryNode.Stat().Bound() = std::min(pointBound, childBound);
+  queryNode.Stat().Bound() = CalculateBound(queryNode);
   const double bestDistance = queryNode.Stat().Bound();
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -273,13 +169,44 @@
   if (oldScore == DBL_MAX)
     return oldScore;
 
-  // Calculate the bound on the fly.  This bound will be the minimum of
-  // pointBound (the bounds given by the points in this node) and childBound
-  // (the bounds given by the children of this node).
-  double pointBound = DBL_MAX;
-  double childBound = DBL_MAX;
+  // Update our bound.
+  queryNode.Stat().Bound() = CalculateBound(queryNode);
+  const double bestDistance = queryNode.Stat().Bound();
+
+  return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
+}
+/*
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::FinishNode(
+    TreeType& queryNode) const
+{
+  // Find the bound of points contained in this node.
+  double pointBound = SortPolicy::BestDistance();
   const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
 
+  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+  {
+    // The bound for this point is the k-th best distance plus the maximum
+    // distance to a child of this node.
+    const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
+        maxDescendantDistance;
+    if (SortPolicy::IsBetter(pointBound, bound))
+      pointBound = bound;
+  }
+
+  // Push bound to parent.
+}
+*/
+
+// Calculate the bound for a given query node in its current state.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+    CalculateBound(TreeType& queryNode) const
+{
+  double pointBound = SortPolicy::BestDistance();
+  double childBound = SortPolicy::BestDistance();
+  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
   // Find the bound of the points contained in this node.
   for (size_t i = 0; i < queryNode.NumPoints(); ++i)
   {
@@ -287,7 +214,7 @@
     // distance to a child of this node.
     const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
         maxDescendantDistance;
-    if (bound < pointBound)
+    if (SortPolicy::IsBetter(pointBound, bound))
       pointBound = bound;
   }
 
@@ -295,15 +222,27 @@
   for (size_t i = 0; i < queryNode.NumChildren(); ++i)
   {
     const double bound = queryNode.Child(i).Stat().Bound();
-    if (bound < childBound)
+    if (SortPolicy::IsBetter(childBound, bound))
       childBound = bound;
   }
 
-  // Update our bound.
-  queryNode.Stat().Bound() = std::min(pointBound, childBound);
-  const double bestDistance = queryNode.Stat().Bound();
+  // If the bound of the children is uninitialized
+  // (SortPolicy::WorstDistance()), then maybe we can create a bound for the
+  // children.  But this requires a point bound to exist.
+  if (childBound == SortPolicy::WorstDistance() &&
+      pointBound != SortPolicy::BestDistance()) // This could fail!
+      // SortPolicy::BestDistance() could be a valid bound!
+  {
+    // Should we be considering queryNode.Stat().Bound() too?
+    childBound = pointBound + maxDescendantDistance;
+    Log::Debug << "Child bound is " << childBound << std::endl;
+  }
 
-  return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
+  // Return the worse of the two bounds.
+  if (SortPolicy::IsBetter(childBound, pointBound))
+    return pointBound;
+  else
+    return childBound;
 }
 
 /**




More information about the mlpack-svn mailing list