[mlpack-svn] r14082 - mlpack/trunk/src/mlpack/methods/neighbor_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Jan 5 18:27:05 EST 2013
Author: rcurtin
Date: 2013-01-05 18:27:05 -0500 (Sat, 05 Jan 2013)
New Revision: 14082
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Log:
Refactor the duplicated query-node bound calculation into a single CalculateBound() helper. It may need a further fix.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2013-01-05 19:14:36 UTC (rev 14081)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2013-01-05 23:27:05 UTC (rev 14082)
@@ -136,6 +136,11 @@
MetricType& metric;
/**
+ * Recalculate the bound for a given query node.
+ */
+ double CalculateBound(TreeType& queryNode) const;
+
+ /**
* Insert a point into the neighbors and distances matrices; this is a helper
* function.
*
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2013-01-05 19:14:36 UTC (rev 14081)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2013-01-05 23:27:05 UTC (rev 14082)
@@ -62,34 +62,8 @@
const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
&referenceNode, &referenceChildNode, baseCaseResult);
- // Calculate the bound on the fly. This bound will be the minimum of
- // pointBound (the bounds given by the points in this node) and childBound
- // (the bounds given by the children of this node).
- double pointBound = DBL_MAX;
- double childBound = DBL_MAX;
- const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
-
- // Find the bound of the points contained in this node.
- for (size_t i = 0; i < queryNode.NumPoints(); ++i)
- {
- // The bound for this point is the k-th best distance plus the maximum
- // distance to a child of this node.
- const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
- maxDescendantDistance;
- if (bound < pointBound)
- pointBound = bound;
- }
-
- // Find the bound of the children.
- for (size_t i = 0; i < queryNode.NumChildren(); ++i)
- {
- const double bound = queryNode.Child(i).Stat().Bound();
- if (bound < childBound)
- childBound = bound;
- }
-
// Update our bound.
- queryNode.Stat().Bound() = std::min(pointBound, childBound);
+ queryNode.Stat().Bound() = CalculateBound(queryNode);
const double bestDistance = queryNode.Stat().Bound();
return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -105,34 +79,8 @@
const double distance = SortPolicy::BestNodeToNodeDistance(&referenceNode,
&queryNode, &queryChildNode, baseCaseResult);
- // Calculate the bound on the fly. This bound will be the minimum of
- // pointBound (the bounds given by the points in this node) and childBound
- // (the bounds given by the children of this node).
- double pointBound = DBL_MAX;
- double childBound = DBL_MAX;
- const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
-
- // Find the bound of the points contained in this node.
- for (size_t i = 0; i < queryNode.NumPoints(); ++i)
- {
- // The bound for this point is the k-th best distance plus the maximum
- // distance to a child of this node.
- const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
- maxDescendantDistance;
- if (bound < pointBound)
- pointBound = bound;
- }
-
- // Find the bound of the children.
- for (size_t i = 0; i < queryNode.NumChildren(); ++i)
- {
- const double bound = queryNode.Child(i).Stat().Bound();
- if (bound < childBound)
- childBound = bound;
- }
-
// Update our bound.
- queryNode.Stat().Bound() = std::min(pointBound, childBound);
+ queryNode.Stat().Bound() = CalculateBound(queryNode);
const double bestDistance = queryNode.Stat().Bound();
return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -189,34 +137,8 @@
const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
&referenceNode);
- // Calculate the bound on the fly. This bound will be the minimum of
- // pointBound (the bounds given by the points in this node) and childBound
- // (the bounds given by the children of this node).
- double pointBound = DBL_MAX;
- double childBound = DBL_MAX;
- const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
-
- // Find the bound of the points contained in this node.
- for (size_t i = 0; i < queryNode.NumPoints(); ++i)
- {
- // The bound for this point is the k-th best distance plus the maximum
- // distance to a child of this node.
- const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
- maxDescendantDistance;
- if (bound < pointBound)
- pointBound = bound;
- }
-
- // Find the bound of the children.
- for (size_t i = 0; i < queryNode.NumChildren(); ++i)
- {
- const double bound = queryNode.Child(i).Stat().Bound();
- if (bound < childBound)
- childBound = bound;
- }
-
// Update our bound.
- queryNode.Stat().Bound() = std::min(pointBound, childBound);
+ queryNode.Stat().Bound() = CalculateBound(queryNode);
const double bestDistance = queryNode.Stat().Bound();
return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -231,34 +153,8 @@
const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
&referenceNode, baseCaseResult);
- // Calculate the bound on the fly. This bound will be the minimum of
- // pointBound (the bounds given by the points in this node) and childBound
- // (the bounds given by the children of this node).
- double pointBound = DBL_MAX;
- double childBound = DBL_MAX;
- const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
-
- // Find the bound of the points contained in this node.
- for (size_t i = 0; i < queryNode.NumPoints(); ++i)
- {
- // The bound for this point is the k-th best distance plus the maximum
- // distance to a child of this node.
- const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
- maxDescendantDistance;
- if (bound < pointBound)
- pointBound = bound;
- }
-
- // Find the bound of the children.
- for (size_t i = 0; i < queryNode.NumChildren(); ++i)
- {
- const double bound = queryNode.Child(i).Stat().Bound();
- if (bound < childBound)
- childBound = bound;
- }
-
// Update our bound.
- queryNode.Stat().Bound() = std::min(pointBound, childBound);
+ queryNode.Stat().Bound() = CalculateBound(queryNode);
const double bestDistance = queryNode.Stat().Bound();
return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -273,13 +169,44 @@
if (oldScore == DBL_MAX)
return oldScore;
- // Calculate the bound on the fly. This bound will be the minimum of
- // pointBound (the bounds given by the points in this node) and childBound
- // (the bounds given by the children of this node).
- double pointBound = DBL_MAX;
- double childBound = DBL_MAX;
+ // Update our bound.
+ queryNode.Stat().Bound() = CalculateBound(queryNode);
+ const double bestDistance = queryNode.Stat().Bound();
+
+ return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
+}
+/*
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::FinishNode(
+ TreeType& queryNode) const
+{
+ // Find the bound of points contained in this node.
+ double pointBound = SortPolicy::BestDistance();
const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+ for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+ {
+ // The bound for this point is the k-th best distance plus the maximum
+ // distance to a child of this node.
+ const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
+ maxDescendantDistance;
+ if (SortPolicy::IsBetter(pointBound, bound))
+ pointBound = bound;
+ }
+
+ // Push bound to parent.
+}
+*/
+
+// Calculate the bound for a given query node in its current state.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+ CalculateBound(TreeType& queryNode) const
+{
+ double pointBound = SortPolicy::BestDistance();
+ double childBound = SortPolicy::BestDistance();
+ const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
// Find the bound of the points contained in this node.
for (size_t i = 0; i < queryNode.NumPoints(); ++i)
{
@@ -287,7 +214,7 @@
// distance to a child of this node.
const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
maxDescendantDistance;
- if (bound < pointBound)
+ if (SortPolicy::IsBetter(pointBound, bound))
pointBound = bound;
}
@@ -295,15 +222,27 @@
for (size_t i = 0; i < queryNode.NumChildren(); ++i)
{
const double bound = queryNode.Child(i).Stat().Bound();
- if (bound < childBound)
+ if (SortPolicy::IsBetter(childBound, bound))
childBound = bound;
}
- // Update our bound.
- queryNode.Stat().Bound() = std::min(pointBound, childBound);
- const double bestDistance = queryNode.Stat().Bound();
+ // If the bound of the children is uninitialized
+ // (SortPolicy::WorstDistance()), then maybe we can create a bound for the
+ // children. But this requires a point bound to exist.
+ if (childBound == SortPolicy::WorstDistance() &&
+ pointBound != SortPolicy::BestDistance()) // This could fail!
+ // SortPolicy::BestDistance() could be a valid bound!
+ {
+ // Should we be considering queryNode.Stat().Bound() too?
+ childBound = pointBound + maxDescendantDistance;
+ Log::Debug << "Child bound is " << childBound << std::endl;
+ }
- return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
+ // Return the worse of the two bounds.
+ if (SortPolicy::IsBetter(childBound, pointBound))
+ return pointBound;
+ else
+ return childBound;
}
/**
More information about the mlpack-svn
mailing list