[mlpack-svn] r13339 - mlpack/trunk/src/mlpack/methods/neighbor_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Aug 6 14:17:56 EDT 2012
Author: rcurtin
Date: 2012-08-06 14:17:55 -0400 (Mon, 06 Aug 2012)
New Revision: 13339
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Log:
Add Score() and Rescore(), which will be used for a "new" type of tree traversal
which should be faster.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2012-08-05 21:47:28 UTC (rev 13338)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2012-08-06 18:17:55 UTC (rev 13339)
@@ -36,6 +36,56 @@
// Update bounds. Needs a better name.
void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * The score is the best possible distance between the query point and
+ * referenceNode. It is compared against the distance currently stored in the
+ * last row of the distances matrix for this query point.
+ *
+ * NOTE(review): the raw distance is used as the score, so "low score means
+ * high priority" only holds when SortPolicy prefers smaller distances --
+ * confirm for furthest-neighbor sort policies.
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @return The best point-to-node distance, or DBL_MAX if the node should be
+ *     pruned.
+ */
+ double Score(const size_t queryIndex, TreeType& referenceNode) const;
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned). This is used when the score has already
+ * been calculated, but another recursion may have modified the bounds for
+ * pruning. So the old score is checked against the new pruning bound.
+ *
+ * The reference node is not re-examined; only oldScore is compared with the
+ * query point's current pruning distance. An old score of DBL_MAX stays
+ * DBL_MAX.
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param oldScore Old score produced by Score() (or Rescore()).
+ * @return oldScore if it can still improve the results, otherwise DBL_MAX.
+ */
+ double Rescore(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double oldScore) const;
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * The score is the best possible distance between the two nodes. It is
+ * compared against the bound stored in queryNode's statistic.
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ * @return The best node-to-node distance, or DBL_MAX if the pair should be
+ *     pruned.
+ */
+ double Score(TreeType& queryNode, TreeType& referenceNode) const;
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned). This is used when the score has already
+ * been calculated, but another recursion may have modified the bounds for
+ * pruning. So the old score is checked against the new pruning bound.
+ *
+ * The reference node is not re-examined; only oldScore is compared with the
+ * bound currently stored in queryNode's statistic.
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ * @param oldScore Old score produced by Score() (or Rescore()).
+ * @return oldScore if it can still improve the results, otherwise DBL_MAX.
+ */
+ double Rescore(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double oldScore) const;
+
private:
//! The reference set.
const arma::mat& referenceSet;
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2012-08-05 21:47:28 UTC (rev 13338)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2012-08-06 18:17:55 UTC (rev 13339)
@@ -145,6 +145,58 @@
queryNode.Stat().Bound() = worstDistance;
}
+/**
+ * Score a (query point, reference node) pair for traversal ordering.
+ *
+ * The score is the best possible distance between the query point and
+ * referenceNode, as reported by SortPolicy::BestPointToNodeDistance(). The
+ * pruning threshold is the entry in the last row of the distances matrix for
+ * this query (presumably the worst distance currently kept -- confirm against
+ * the insertion helper).
+ *
+ * @return The node distance if it can still improve the results, otherwise
+ *     DBL_MAX, meaning referenceNode can be pruned.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+    const size_t queryIndex,
+    TreeType& referenceNode) const
+{
+  // Best distance that any point in referenceNode could achieve.
+  const arma::vec query = querySet.unsafe_col(queryIndex);
+  const double nodeDistance =
+      SortPolicy::BestPointToNodeDistance(query, &referenceNode);
+
+  // The pruning threshold for this query point.
+  const double threshold = distances(distances.n_rows - 1, queryIndex);
+
+  if (!SortPolicy::IsBetter(nodeDistance, threshold))
+    return DBL_MAX; // Nothing in this node can improve the result: prune.
+
+  return nodeDistance;
+}
+
+/**
+ * Re-check a score previously produced by Score() (or Rescore()) for a query
+ * point. Other recursions may have tightened that point's pruning threshold
+ * since the score was computed. The reference node itself is not consulted.
+ *
+ * @return oldScore if it can still improve the results, otherwise DBL_MAX.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
+    const size_t queryIndex,
+    TreeType& /* referenceNode */,
+    const double oldScore) const
+{
+  // Anything that was already pruned stays pruned.
+  if (oldScore == DBL_MAX)
+    return DBL_MAX;
+
+  // Compare the old score with the (possibly updated) threshold.
+  const bool stillUseful = SortPolicy::IsBetter(oldScore,
+      distances(distances.n_rows - 1, queryIndex));
+
+  return stillUseful ? oldScore : DBL_MAX;
+}
+
+/**
+ * Score a (query node, reference node) pair for traversal ordering.
+ *
+ * The score is the best possible distance between the two nodes, as reported
+ * by SortPolicy::BestNodeToNodeDistance(). It is checked against the bound
+ * stored in queryNode's statistic.
+ *
+ * @return The node-to-node distance if the pair can still improve the
+ *     results, otherwise DBL_MAX, meaning the pair can be pruned.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+    TreeType& queryNode,
+    TreeType& referenceNode) const
+{
+  // Best distance between any point of queryNode and any point of
+  // referenceNode.
+  const double nodeDistance =
+      SortPolicy::BestNodeToNodeDistance(&queryNode, &referenceNode);
+
+  // Keep the pair only if it can beat the query node's current bound.
+  if (SortPolicy::IsBetter(nodeDistance, queryNode.Stat().Bound()))
+    return nodeDistance;
+
+  return DBL_MAX;
+}
+
+/**
+ * Re-check a score previously produced by Score() (or Rescore()) for a
+ * (query node, reference node) pair. Other recursions may have tightened the
+ * bound stored in queryNode's statistic since the score was computed. The
+ * reference node itself is not consulted.
+ *
+ * @return oldScore if it can still improve the results, otherwise DBL_MAX.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
+    TreeType& queryNode,
+    TreeType& /* referenceNode */,
+    const double oldScore) const
+{
+  // If we are already pruning, still prune. This mirrors the point-to-node
+  // Rescore() and avoids reading the node bound for pairs that are already
+  // pruned.
+  if (oldScore == DBL_MAX)
+    return oldScore;
+
+  // Just check the old score again against the node's current bound.
+  const double bestDistance = queryNode.Stat().Bound();
+
+  return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
+}
+
/**
* Helper function to insert a point into the neighbors and distances matrices.
*
More information about the mlpack-svn
mailing list