[mlpack-svn] r14540 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Mar 13 15:40:37 EDT 2013


Author: rcurtin
Date: 2013-03-13 15:40:36 -0400 (Wed, 13 Mar 2013)
New Revision: 14540

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Log:
Fix bug in bound function for neighbor search.  It is necessary to keep two
separate bounds for each node.


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp	2013-03-13 18:44:09 UTC (rev 14539)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp	2013-03-13 19:40:36 UTC (rev 14540)
@@ -176,13 +176,12 @@
   // Map the points back to their original locations.
   if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
     Unmap(neighbors, distances, oldFromNewRefs, oldFromNewQueries, neighborsOut,
-        distancesOut, true);
+        distancesOut);
   else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
-    Unmap(neighbors, distances, oldFromNewRefs, neighborsOut, distancesOut,
-        true);
+    Unmap(neighbors, distances, oldFromNewRefs, neighborsOut, distancesOut);
   else
     Unmap(neighbors, distances, oldFromNewRefs, oldFromNewRefs, neighborsOut,
-        distancesOut, true);
+        distancesOut);
 
   // Clean up.
   if (queryTree)

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	2013-03-13 18:44:09 UTC (rev 14539)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	2013-03-13 19:40:36 UTC (rev 14540)
@@ -238,13 +238,12 @@
     // Map the results back to the correct places.
     if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
       Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewQueries,
-          neighbors, distances, true);
+          neighbors, distances);
     else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
-      Unmap(neighborsOut, distancesOut, oldFromNewRefs, neighbors, distances,
-          true);
+      Unmap(neighborsOut, distancesOut, oldFromNewRefs, neighbors, distances);
     else
       Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewRefs,
-          neighbors, distances, true);
+          neighbors, distances);
 
     // Clean up.
     if (queryTree)

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	2013-03-13 18:44:09 UTC (rev 14539)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	2013-03-13 19:40:36 UTC (rev 14540)
@@ -30,7 +30,15 @@
 class QueryStat
 {
  private:
-  //! The bound on the node's neighbor distances.
+  //! The first bound on the node's neighbor distances (B_1).  This represents
+  //! the worst candidate distance of any descendants of this node.
+  double firstBound;
+  //! The second bound on the node's neighbor distances (B_2).  This represents
+  //! a bound on the worst distance of any descendants of this node assembled
+  //! using the best descendant candidate distance modified by the furthest
+  //! descendant distance.
+  double secondBound;
+  //! The better of the two bounds.
   double bound;
 
  public:
@@ -38,19 +46,32 @@
    * Initialize the statistic with the worst possible distance according to
    * our sorting policy.
    */
-  QueryStat() : bound(SortPolicy::WorstDistance()) { }
+  QueryStat() :
+      firstBound(SortPolicy::WorstDistance()),
+      secondBound(SortPolicy::WorstDistance()),
+      bound(SortPolicy::WorstDistance()) { }
 
   /**
    * Initialization for a fully initialized node.  In this case, we don't need
    * to worry about the node.
    */
   template<typename TreeType>
-  QueryStat(TreeType& /* node */)
-      : bound(SortPolicy::WorstDistance()) { }
+  QueryStat(TreeType& /* node */) :
+      firstBound(SortPolicy::WorstDistance()),
+      secondBound(SortPolicy::WorstDistance()),
+      bound(SortPolicy::WorstDistance()) { }
 
-  //! Get the bound.
+  //! Get the first bound.
+  double FirstBound() const { return firstBound; }
+  //! Modify the first bound.
+  double& FirstBound() { return firstBound; }
+  //! Get the second bound.
+  double SecondBound() const { return secondBound; }
+  //! Modify the second bound.
+  double& SecondBound() { return secondBound; }
+  //! Get the overall bound (the better of the two bounds).
   double Bound() const { return bound; }
-  //! Modify the bound.
+  //! Modify the overall bound (it should be the better of the two bounds).
   double& Bound() { return bound; }
 };
 

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2013-03-13 18:44:09 UTC (rev 14539)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2013-03-13 19:40:36 UTC (rev 14540)
@@ -43,7 +43,7 @@
   // If this distance is better than any of the current candidates, the
   // SortDistance() function will give us the position to insert it into.
   arma::vec queryDist = distances.unsafe_col(queryIndex);
-  size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
+  const size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
 
   // SortDistance() returns (size_t() - 1) if we shouldn't add it.
   if (insertPosition != (size_t() - 1))
@@ -63,8 +63,7 @@
       &referenceNode, &referenceChildNode, baseCaseResult);
 
   // Update our bound.
-  queryNode.Stat().Bound() = CalculateBound(queryNode);
-  const double bestDistance = queryNode.Stat().Bound();
+  const double bestDistance = CalculateBound(queryNode);
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
 }
@@ -80,8 +79,7 @@
       &queryNode, &queryChildNode, baseCaseResult);
 
   // Update our bound.
-  queryNode.Stat().Bound() = CalculateBound(queryNode);
-  const double bestDistance = queryNode.Stat().Bound();
+  const double bestDistance = CalculateBound(queryNode);
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
 }
@@ -138,8 +136,7 @@
       &referenceNode);
 
   // Update our bound.
-  queryNode.Stat().Bound() = CalculateBound(queryNode);
-  const double bestDistance = queryNode.Stat().Bound();
+  const double bestDistance = CalculateBound(queryNode);
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
 }
@@ -154,8 +151,7 @@
       &referenceNode, baseCaseResult);
 
   // Update our bound.
-  queryNode.Stat().Bound() = CalculateBound(queryNode);
-  const double bestDistance = queryNode.Stat().Bound();
+  const double bestDistance = CalculateBound(queryNode);
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
 }
@@ -170,40 +166,18 @@
     return oldScore;
 
   // Update our bound.
-  queryNode.Stat().Bound() = CalculateBound(queryNode);
-  const double bestDistance = queryNode.Stat().Bound();
+  const double bestDistance = CalculateBound(queryNode);
 
   return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
 }
-/*
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::FinishNode(
-    TreeType& queryNode) const
-{
-  // Find the bound of points contained in this node.
-  double pointBound = SortPolicy::BestDistance();
-  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
 
-  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
-  {
-    // The bound for this point is the k-th best distance plus the maximum
-    // distance to a child of this node.
-    const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
-        maxDescendantDistance;
-    if (SortPolicy::IsBetter(pointBound, bound))
-      pointBound = bound;
-  }
-
-  // Push bound to parent.
-}
-*/
-
-// Calculate the bound for a given query node in its current state.
+// Calculate the bound for a given query node in its current state and update
+// it.
 template<typename SortPolicy, typename MetricType, typename TreeType>
 inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
     CalculateBound(TreeType& queryNode) const
 {
-  // We have four possible bounds, and we must take the best of them all.  We
+  // We have five possible bounds, and we must take the best of them all.  We
   // don't use min/max here, but instead "best/worst", because this is general
   // to the nearest-neighbors/furthest-neighbors cases.  For nearest neighbors,
   // min = best, max = worst.
@@ -215,11 +189,12 @@
   // (3) best_{all children c in queryNode} B(c) +
   //      2 ( worst descendant distance of queryNode -
   //          worst descendant distance of c );
-  // (4) B(parent of queryNode);
+  // (4) B_1(parent of queryNode);
+  // (5) B_2(parent of queryNode);
   //
   // D_p[k] is the current k'th candidate distance for point p.
   // So we will loop over the points in queryNode and the children in queryNode
-  // to calculate all four of these quantities.
+  // to calculate all five of these quantities.
 
   double worstPointDistance = SortPolicy::BestDistance();
   double bestPointDistance = SortPolicy::WorstDistance();
@@ -245,15 +220,16 @@
 
   for (size_t i = 0; i < queryNode.NumChildren(); ++i)
   {
-    const double bound = queryNode.Child(i).Stat().Bound();
+    const double firstBound = queryNode.Child(i).Stat().FirstBound();
+    const double secondBound = queryNode.Child(i).Stat().SecondBound();
     const double childMaxDescendantDistance =
         queryNode.Child(i).FurthestDescendantDistance();
 
-    if (SortPolicy::IsBetter(worstChildBound, bound))
-      worstChildBound = bound;
+    if (SortPolicy::IsBetter(worstChildBound, firstBound))
+      worstChildBound = firstBound;
 
     // Now calculate adjustment for maximum descendant distances.
-    const double adjustedBound = SortPolicy::CombineWorst(bound,
+    const double adjustedBound = SortPolicy::CombineWorst(secondBound,
         2 * (queryMaxDescendantDistance - childMaxDescendantDistance));
     if (SortPolicy::IsBetter(adjustedBound, bestAdjustedChildBound))
       bestAdjustedChildBound = adjustedBound;
@@ -269,18 +245,38 @@
       2 * queryMaxDescendantDistance);
 
   // Bound (3) is bestAdjustedChildBound.
+
+  // Bounds (4) and (5) are the parent bounds.
   const double fourthBound = (queryNode.Parent() != NULL) ?
-      queryNode.Parent()->Stat().Bound() : SortPolicy::WorstDistance();
+      queryNode.Parent()->Stat().FirstBound() : SortPolicy::WorstDistance();
+  const double fifthBound = (queryNode.Parent() != NULL) ?
+      queryNode.Parent()->Stat().SecondBound() : SortPolicy::WorstDistance();
 
   // Now, we will take the best of these.  Unfortunately due to the way
-  // IsBetter() is defined, this sort of has to be a little ugly...
-  const double interA = (SortPolicy::IsBetter(firstBound, secondBound)) ?
-      firstBound : secondBound;
+  // IsBetter() is defined, this sort of has to be a little ugly.
+  // The variable interA represents the first bound (B_1), which is the worst
+  // candidate distance of any descendants of this node.
+  // The variable interC represents the second bound (B_2), which is a bound on
+  // the worst distance of any descendants of this node assembled using the best
+  // descendant candidate distance modified using the furthest descendant
+  // distance.
+  const double interA = (SortPolicy::IsBetter(firstBound, fourthBound)) ?
+      firstBound : fourthBound;
   const double interB =
-      (SortPolicy::IsBetter(bestAdjustedChildBound, fourthBound)) ?
-      bestAdjustedChildBound : fourthBound;
+      (SortPolicy::IsBetter(bestAdjustedChildBound, secondBound)) ?
+      bestAdjustedChildBound : secondBound;
+  const double interC = (SortPolicy::IsBetter(interB, fifthBound)) ? interB :
+      fifthBound;
 
-  return (SortPolicy::IsBetter(interA, interB)) ? interA : interB;
+  // Update the first and second bounds of the node.
+  queryNode.Stat().FirstBound() = interA;
+  queryNode.Stat().SecondBound() = interC;
+
+  // Update the actual bound of the node.
+  queryNode.Stat().Bound() = (SortPolicy::IsBetter(interA, interC)) ? interA :
+      interC;
+
+  return queryNode.Stat().Bound();
 }
 
 /**




More information about the mlpack-svn mailing list