[mlpack-svn] r13356 - mlpack/trunk/src/mlpack/methods/neighbor_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Aug 7 01:27:05 EDT 2012
Author: rcurtin
Date: 2012-08-07 01:27:05 -0400 (Tue, 07 Aug 2012)
New Revision: 13356
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Log:
Remove a few unused methods (LeftFirst() variants). Add Score() which accepts a
pre-calculated base case, which is useful for trees with self-children, like the
cover tree.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2012-08-07 05:26:32 UTC (rev 13355)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2012-08-07 05:27:05 UTC (rev 13356)
@@ -21,7 +21,7 @@
arma::mat& distances,
MetricType& metric);
- void BaseCase(const size_t queryIndex, const size_t referenceIndex);
+ double BaseCase(const size_t queryIndex, const size_t referenceIndex);
// For single-tree traversal.
bool CanPrune(const size_t queryIndex, TreeType& referenceNode);
@@ -29,10 +29,6 @@
// For dual-tree traversal.
bool CanPrune(TreeType& queryNode, TreeType& referenceNode);
- // Get the order of points to recurse to.
- bool LeftFirst(const size_t queryIndex, TreeType& referenceNode);
- bool LeftFirst(TreeType& staticNode, TreeType& recurseNode);
-
// Update bounds. Needs a better name.
void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
@@ -47,6 +43,20 @@
double Score(const size_t queryIndex, TreeType& referenceNode) const;
/**
+ * Get the score for recursion order, passing the base case result (in the
+ * situation where it may be needed to calculate the recursion order). A low
+ * score indicates priority for recursion, while DBL_MAX indicates that the
+ * node should not be recursed into at all (it should be pruned).
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+ */
+ double Score(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult) const;
+
+ /**
* Re-evaluate the score for recursion order. A low score indicates priority
* for recursion, while DBL_MAX indicates that the node should not be recursed
* into at all (it should be pruned). This is used when the score has already
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2012-08-07 05:26:32 UTC (rev 13355)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2012-08-07 05:27:05 UTC (rev 13356)
@@ -28,13 +28,14 @@
{ /* Nothing left to do. */ }
template<typename SortPolicy, typename MetricType, typename TreeType>
-inline force_inline void NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+inline force_inline // Absolutely MUST be inline so optimizations can happen.
+double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
BaseCase(const size_t queryIndex, const size_t referenceIndex)
{
// If the datasets are the same, then this search is only using one dataset
// and we should not return identical points.
if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
- return;
+ return 0.0;
double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
referenceSet.unsafe_col(referenceIndex));
@@ -47,6 +48,8 @@
// SortDistance() returns (size_t() - 1) if we shouldn't add it.
if (insertPosition != (size_t() - 1))
InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
+
+ return distance;
}
template<typename SortPolicy, typename MetricType, typename TreeType>
@@ -62,10 +65,7 @@
// If this is better than the best distance we've seen so far, maybe there
// will be something down this node.
- if (SortPolicy::IsBetter(distance, bestDistance))
- return false; // We cannot prune.
- else
- return true; // There cannot be anything better in this node. So prune it.
+ return !(SortPolicy::IsBetter(distance, bestDistance));
}
template<typename SortPolicy, typename MetricType, typename TreeType>
@@ -77,43 +77,10 @@
&queryNode, &referenceNode);
const double bestDistance = queryNode.Stat().Bound();
- if (SortPolicy::IsBetter(distance, bestDistance))
- return false; // Can't prune.
- else
- return true;
+ return !(SortPolicy::IsBetter(distance, bestDistance));
}
template<typename SortPolicy, typename MetricType, typename TreeType>
-inline bool NeighborSearchRules<SortPolicy, MetricType, TreeType>::LeftFirst(
- const size_t queryIndex,
- TreeType& referenceNode)
-{
- // This ends up with us calculating this distance twice (it will be done again
- // in CanPrune()), but because single-neighbors recursion is not the most
- // important in this method, we can let it slide.
- const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
- const double leftDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
- referenceNode.Left());
- const double rightDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
- referenceNode.Right());
-
- return SortPolicy::IsBetter(leftDistance, rightDistance);
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline bool NeighborSearchRules<SortPolicy, MetricType, TreeType>::LeftFirst(
- TreeType& staticNode,
- TreeType& recurseNode)
-{
- const double leftDistance = SortPolicy::BestNodeToNodeDistance(&staticNode,
- recurseNode.Left());
- const double rightDistance = SortPolicy::BestNodeToNodeDistance(&staticNode,
- recurseNode.Right());
-
- return SortPolicy::IsBetter(leftDistance, rightDistance);
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
void NeighborSearchRules<
SortPolicy,
MetricType,
@@ -158,6 +125,20 @@
}
template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+ const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult) const
+{
+ const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+ const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
+ &referenceNode, baseCaseResult);
+ const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+
+ return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
const size_t queryIndex,
TreeType& /* referenceNode */,
More information about the mlpack-svn mailing list