[mlpack-git] master: Implement B_aux according to what was discussed in #642. (6877607)

gitdub at mlpack.org gitdub at mlpack.org
Tue May 31 14:51:56 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6a94eb6efb37eca09e95fd53bd7c21334abf7614...281555a9c6a460623dd1337800e74d2aa7c9efcc

>---------------------------------------------------------------

commit 68776071c232fb63a6a3f7aff611dc5d8825155d
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Mon May 30 11:49:58 2016 -0300

    Implement B_aux according to what was discussed in #642.


>---------------------------------------------------------------

68776071c232fb63a6a3f7aff611dc5d8825155d
 .../neighbor_search/neighbor_search_rules_impl.hpp | 41 ++++++++++++++--------
 .../neighbor_search/neighbor_search_stat.hpp       | 10 ++++++
 2 files changed, 37 insertions(+), 14 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index c4767ec..cc2b957 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -344,6 +344,8 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
 
   double worstDistance = SortPolicy::BestDistance();
   double bestDistance = SortPolicy::WorstDistance();
+  double bestPointDistance = SortPolicy::WorstDistance();
+  double auxDistance = SortPolicy::WorstDistance();
 
   // Loop over points held in the node.
   for (size_t i = 0; i < queryNode.NumPoints(); ++i)
@@ -351,33 +353,43 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
     const double distance = distances(distances.n_rows - 1, queryNode.Point(i));
     if (SortPolicy::IsBetter(worstDistance, distance))
       worstDistance = distance;
-    if (SortPolicy::IsBetter(distance, bestDistance))
-      bestDistance = distance;
+    if (SortPolicy::IsBetter(distance, bestPointDistance))
+      bestPointDistance = distance;
   }
 
-  // Add triangle inequality adjustment to best distance.  It is possible this
-  // could be tighter for some certain types of trees.
-  bestDistance = SortPolicy::CombineWorst(bestDistance,
-      queryNode.FurthestPointDistance() +
-      queryNode.FurthestDescendantDistance());
+  auxDistance = bestPointDistance;
 
   // Loop over children of the node, and use their cached information to
   // assemble bounds.
   for (size_t i = 0; i < queryNode.NumChildren(); ++i)
   {
     const double firstBound = queryNode.Child(i).Stat().FirstBound();
-    const double adjustment = std::max(0.0,
-        queryNode.FurthestDescendantDistance() -
-        queryNode.Child(i).FurthestDescendantDistance());
-    const double adjustedSecondBound = SortPolicy::CombineWorst(
-        queryNode.Child(i).Stat().SecondBound(), 2 * adjustment);
+    const double auxBound = queryNode.Child(i).Stat().AuxBound();
 
     if (SortPolicy::IsBetter(worstDistance, firstBound))
       worstDistance = firstBound;
-    if (SortPolicy::IsBetter(adjustedSecondBound, bestDistance))
-      bestDistance = adjustedSecondBound;
+    if (SortPolicy::IsBetter(auxBound, auxDistance))
+      auxDistance = auxBound;
   }
 
+  // Add triangle inequality adjustment to best distance.  It is possible this
+  // could be tighter for certain types of trees.
+  bestDistance = SortPolicy::CombineWorst(auxDistance,
+      2 * queryNode.FurthestDescendantDistance());
+
+  // Add triangle inequality adjustment to best distance of points in node.
+  bestPointDistance = SortPolicy::CombineWorst(bestPointDistance,
+      queryNode.FurthestPointDistance() +
+      queryNode.FurthestDescendantDistance());
+
+  if (SortPolicy::IsBetter(bestPointDistance, bestDistance))
+    bestDistance = bestPointDistance;
+
+  // At this point:
+  // worstDistance holds the value of B_1(N_q).
+  // bestDistance holds the value of B_2(N_q).
+  // auxDistance holds the value of B_aux(N_q).
+
   // Now consider the parent bounds.
   if (queryNode.Parent() != NULL)
   {
@@ -405,6 +417,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
   // Cache bounds for later.
   queryNode.Stat().FirstBound() = worstDistance;
   queryNode.Stat().SecondBound() = bestDistance;
+  queryNode.Stat().AuxBound() = auxDistance;
 
   if (SortPolicy::IsBetter(worstDistance, bestDistance))
     return worstDistance;
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_stat.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_stat.hpp
index 90b3f76..c125369 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_stat.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_stat.hpp
@@ -29,6 +29,9 @@ class NeighborSearchStat
   //! using the best descendant candidate distance modified by the furthest
   //! descendant distance.
   double secondBound;
+  //! The aux bound on the node's neighbor distances (B_aux). This represents
+  //! the best descendant candidate distance (used to calculate secondBound).
+  double auxBound;
   //! The better of the two bounds.
   double bound;
 
@@ -45,6 +48,7 @@ class NeighborSearchStat
   NeighborSearchStat() :
       firstBound(SortPolicy::WorstDistance()),
       secondBound(SortPolicy::WorstDistance()),
+      auxBound(SortPolicy::WorstDistance()),
       bound(SortPolicy::WorstDistance()),
       lastDistance(0.0) { }
 
@@ -56,6 +60,7 @@ class NeighborSearchStat
   NeighborSearchStat(TreeType& /* node */) :
       firstBound(SortPolicy::WorstDistance()),
       secondBound(SortPolicy::WorstDistance()),
+      auxBound(SortPolicy::WorstDistance()),
       bound(SortPolicy::WorstDistance()),
       lastDistance(0.0) { }
 
@@ -67,6 +72,10 @@ class NeighborSearchStat
   double SecondBound() const { return secondBound; }
   //! Modify the second bound.
   double& SecondBound() { return secondBound; }
+  //! Get the aux bound.
+  double AuxBound() const { return auxBound; }
+  //! Modify the aux bound.
+  double& AuxBound() { return auxBound; }
   //! Get the overall bound (the better of the two bounds).
   double Bound() const { return bound; }
   //! Modify the overall bound (it should be the better of the two bounds).
@@ -84,6 +93,7 @@ class NeighborSearchStat
 
     ar & CreateNVP(firstBound, "firstBound");
     ar & CreateNVP(secondBound, "secondBound");
+    ar & CreateNVP(auxBound, "auxBound");
     ar & CreateNVP(bound, "bound");
     ar & CreateNVP(lastDistance, "lastDistance");
   }




More information about the mlpack-git mailing list