[mlpack-svn] r13975 - mlpack/trunk/src/mlpack/methods/neighbor_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Dec 4 18:22:29 EST 2012
Author: rcurtin
Date: 2012-12-04 18:22:28 -0500 (Tue, 04 Dec 2012)
New Revision: 13975
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Log:
Add Prescore() function, which can be used to prune based on parent
distances.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2012-12-04 23:22:13 UTC (rev 13974)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2012-12-04 23:22:28 UTC (rev 13975)
@@ -24,6 +24,20 @@
double BaseCase(const size_t queryIndex, const size_t referenceIndex);
/**
+ * Get the score for the recursion order, in general before the base case is
+ * computed. This is useful for cover trees or other trees that can cache
+ * some statistic that could be used to prune a child before its base case
+ * is computed.
+ *
+ * As a side effect, the bound cached in queryNode.Stat().Bound() is
+ * recalculated from queryNode's own points and its children's cached bounds.
+ *
+ * @param queryNode Query node.
+ * @param referenceNode Reference node.
+ * @param referenceChildNode Reference node being scored (presumably a child
+ *     of referenceNode -- confirm against callers).
+ * @param baseCaseResult Value forwarded to
+ *     SortPolicy::BestNodeToNodeDistance(); presumably the result of the
+ *     most recent base case involving queryNode and referenceNode -- confirm.
+ * @return The distance computed by SortPolicy::BestNodeToNodeDistance(), or
+ *     DBL_MAX if referenceChildNode should be pruned.
+ */
+ double Prescore(TreeType& queryNode,
+ TreeType& referenceNode,
+ TreeType& referenceChildNode,
+ const double baseCaseResult) const;
+
+ /**
* Get the score for recursion order. A low score indicates priority for
* recursion, while DBL_MAX indicates that the node should not be recursed
* into at all (it should be pruned).
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2012-12-04 23:22:13 UTC (rev 13974)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2012-12-04 23:22:28 UTC (rev 13975)
@@ -53,6 +53,49 @@
}
+/**
+ * Score a reference child before its base case is computed.  The query
+ * node's pruning bound is recomputed and cached in queryNode.Stat().Bound();
+ * the child is pruned (DBL_MAX) unless its best-case distance is better than
+ * that bound.
+ */
template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Prescore(
+    TreeType& queryNode,
+    TreeType& referenceNode,
+    TreeType& referenceChildNode,
+    const double baseCaseResult) const
+{
+  const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+      &referenceNode, &referenceChildNode, baseCaseResult);
+
+  // Calculate the bound on the fly.  A pruning bound must hold for *every*
+  // descendant of the query node, so it is the minimum of two bounds that are
+  // each valid on their own:
+  //  - pointBound: derived from the points held directly in this node;
+  //  - childBound: derived from the bounds cached in the children.
+  double pointBound = DBL_MAX;
+  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
+  // Find the bound of the points contained in this node.
+  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+  {
+    // The bound for this point is its k-th best distance so far, loosened by
+    // the furthest descendant distance of this node.
+    // NOTE(review): adding the descendant distance assumes nearest-neighbor
+    // semantics and that FurthestDescendantDistance() is measured from this
+    // node's point (as in cover trees) -- confirm for other SortPolicy /
+    // TreeType combinations.
+    const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
+        maxDescendantDistance;
+    if (bound < pointBound)
+      pointBound = bound;
+  }
+
+  // Find the bound of the children.  Each child's bound only covers that
+  // child's descendants, so the bound for this node is the *worst* (largest)
+  // child bound, not the best one.  Taking the minimum here would over-prune.
+  // With no children there is no child bound, so DBL_MAX leaves pointBound in
+  // charge.
+  double childBound = DBL_MAX;
+  if (queryNode.NumChildren() > 0)
+  {
+    childBound = queryNode.Child(0).Stat().Bound();
+    for (size_t i = 1; i < queryNode.NumChildren(); ++i)
+      childBound = std::max(childBound, queryNode.Child(i).Stat().Bound());
+  }
+
+  // Update our bound; it is cached so this node's parent can read it through
+  // Child(i).Stat().Bound().
+  queryNode.Stat().Bound() = std::min(pointBound, childBound);
+  const double bestDistance = queryNode.Stat().Bound();
+
+  // Prune (DBL_MAX) unless the reference child could still improve a result.
+  return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
const size_t queryIndex,
TreeType& referenceNode) const
More information about the mlpack-svn
mailing list