[mlpack-svn] r14376 - in mlpack/trunk/src/mlpack/methods/neighbor_search: . sort_policies
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Feb 22 18:30:47 EST 2013
Author: rcurtin
Date: 2013-02-22 18:30:47 -0500 (Fri, 22 Feb 2013)
New Revision: 14376
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
Log:
Use the bound in the ICML paper.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2013-02-22 23:15:20 UTC (rev 14375)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2013-02-22 23:30:47 UTC (rev 14376)
@@ -203,76 +203,84 @@
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
CalculateBound(TreeType& queryNode) const
{
- double pointBound = SortPolicy::BestDistance();
- double childBound = SortPolicy::BestDistance();
- const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+ // We have four possible bounds, and we must take the best of them all. We
+ // don't use min/max here, but instead "best/worst", because this is general
+ // to the nearest-neighbors/furthest-neighbors cases. For nearest neighbors,
+ // min = best, max = worst.
+ //
+ // (1) worst ( worst_{all points p in queryNode} D_p[k],
+ // worst_{all children c in queryNode} B(c) );
+ // (2) best_{all points p in queryNode} D_p[k] + worst child distance +
+ // worst descendant distance;
+ // (3) best_{all children c in queryNode} B(c) +
+ // 2 ( worst descendant distance of queryNode -
+ // worst descendant distance of c );
+ // (4) B(parent of queryNode);
+ //
+ // D_p[k] is the current k'th candidate distance for point p.
+ // So we will loop over the points in queryNode and the children in queryNode
+ // to calculate all four of these quantities.
- // Find the bound of the points contained in this node.
+ double worstPointDistance = SortPolicy::BestDistance();
+ double bestPointDistance = SortPolicy::WorstDistance();
+
+ // Loop over all points in this node to find the best and worst distance
+ // candidates (for (1) and (2)).
for (size_t i = 0; i < queryNode.NumPoints(); ++i)
{
- // The bound for this point is the k-th best distance plus the maximum
- // distance to a child of this node.
- const double bound = distances(distances.n_rows - 1, queryNode.Point(i));
- if (SortPolicy::IsBetter(pointBound, bound))
- pointBound = bound;
+ const double distance = distances(distances.n_rows - 1, queryNode.Point(i));
+ if (SortPolicy::IsBetter(distance, bestPointDistance))
+ bestPointDistance = distance;
+ if (SortPolicy::IsBetter(worstPointDistance, distance))
+ worstPointDistance = distance;
}
- // Find the bound of the children.
+ // Loop over all the children in this node to find the worst bound (for (1))
+ // and the best bound with the correcting factor for descendant distances (for
+ // (3)).
+ double worstChildBound = SortPolicy::BestDistance();
+ double bestAdjustedChildBound = SortPolicy::WorstDistance();
+ const double queryMaxDescendantDistance =
+ queryNode.FurthestDescendantDistance();
+
for (size_t i = 0; i < queryNode.NumChildren(); ++i)
{
const double bound = queryNode.Child(i).Stat().Bound();
- if (SortPolicy::IsBetter(childBound, bound))
- childBound = bound;
- }
+ const double childMaxDescendantDistance =
+ queryNode.Child(i).FurthestDescendantDistance();
- // If there are no points, then break; the bound must be the child bound.
- if (queryNode.NumPoints() == 0)
- return childBound;
+ if (SortPolicy::IsBetter(worstChildBound, bound))
+ worstChildBound = bound;
- // If there are no children, then break; the bound must be the point bound.
- if (queryNode.NumChildren() == 0)
- return pointBound;
+ // Now calculate adjustment for maximum descendant distances.
+ const double adjustedBound = SortPolicy::CombineWorst(bound,
+ 2 * (queryMaxDescendantDistance - childMaxDescendantDistance));
+ if (SortPolicy::IsBetter(adjustedBound, bestAdjustedChildBound))
+ bestAdjustedChildBound = adjustedBound;
+ }
-// Log::Debug << "Point bound " << pointBound << std::endl;
-// Log::Debug << "Child bound " << childBound << std::endl;
-// Log::Debug << "Furthest descendant distance " << maxDescendantDistance <<
-// std::endl;
+ // This is bound (1).
+ const double firstBound =
+ (SortPolicy::IsBetter(worstPointDistance, worstChildBound)) ?
+ worstChildBound : worstPointDistance;
- // If the bound of the children is uninitialized
- // (SortPolicy::WorstDistance()), then maybe we can create a bound for the
- // children. But this requires a point bound to exist.
+ // This is bound (2); the offset applied is 2 * worst descendant distance.
+ const double secondBound = SortPolicy::CombineWorst(bestPointDistance,
+ 2 * queryMaxDescendantDistance);
- // It is possible that we could calculate a better bound for the children.
- if (pointBound != SortPolicy::WorstDistance())
- {
- const double pointChildBound = pointBound + maxDescendantDistance;
-// Log::Debug << "Point-child bound is " << pointChildBound << std::endl;
+ // Bound (3) is bestAdjustedChildBound; bound (4) is the parent's bound.
+ const double fourthBound = (queryNode.Parent() != NULL) ?
+ queryNode.Parent()->Stat().Bound() : SortPolicy::WorstDistance();
- if (SortPolicy::IsBetter(pointChildBound, childBound))
- {
- // The calculated bound is a tighter bound than the existing child bounds.
- // Update all of the child bounds to this new, tighter bound.
- for (size_t i = 0; i < queryNode.NumChildren(); ++i)
- {
-// Log::Debug << "Update child " << i << " bound from " <<
-// queryNode.Child(i).Stat().Bound() << " to " << pointChildBound <<
-// std::endl;
- if (SortPolicy::IsBetter(pointChildBound,
- queryNode.Child(i).Stat().Bound()))
- queryNode.Child(i).Stat().Bound() = pointChildBound;
-// else
-// Log::Debug << "Did not update child!\n";
- }
+ // Now, we will take the best of these. Unfortunately due to the way
+ // IsBetter() is defined, this sort of has to be a little ugly...
+ const double interA = (SortPolicy::IsBetter(firstBound, secondBound)) ?
+ firstBound : secondBound;
+ const double interB =
+ (SortPolicy::IsBetter(bestAdjustedChildBound, fourthBound)) ?
+ bestAdjustedChildBound : fourthBound;
- childBound = pointChildBound;
- }
- }
-
- // Return the worse of the two bounds.
- if (SortPolicy::IsBetter(childBound, pointBound))
- return pointBound;
- else
- return childBound;
+ return (SortPolicy::IsBetter(interA, interB)) ? interA : interB;
}
/**
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp 2013-02-22 23:15:20 UTC (rev 14375)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp 2013-02-22 23:30:47 UTC (rev 14376)
@@ -127,6 +127,12 @@
* @return DBL_MAX
*/
static inline double BestDistance() { return DBL_MAX; }
+
+ /**
+ * Return the worst combination of the two distances: a - b, floored at 0.0.
+ */
+ static inline double CombineWorst(const double a, const double b)
+ { return std::max(a - b, 0.0); }
};
}; // namespace neighbor
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp 2013-02-22 23:15:20 UTC (rev 14375)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp 2013-02-22 23:30:47 UTC (rev 14376)
@@ -130,6 +130,15 @@
* @return 0.0
*/
static inline double BestDistance() { return 0.0; }
+
+ /**
+ * Return the worst combination: a + b, or DBL_MAX if either is DBL_MAX.
+ */
+ static inline double CombineWorst(const double a, const double b)
+ {
+ if (a == DBL_MAX || b == DBL_MAX)
+ return DBL_MAX;
+ return a + b; }
};
}; // namespace neighbor
More information about the mlpack-svn
mailing list