[mlpack-svn] r15016 - mlpack/trunk/src/mlpack/methods/range_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue May 7 11:44:01 EDT 2013


Author: rcurtin
Date: 2013-05-07 11:44:01 -0400 (Tue, 07 May 2013)
New Revision: 15016

Added:
   mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp
   mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp
Log:
Incremental checkin so I can work from another system.  Begins outline of the
RangeSearchRules class.


Added: mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp	2013-05-07 15:44:01 UTC (rev 15016)
@@ -0,0 +1,140 @@
+/**
+ * @file range_search_rules.hpp
+ * @author Ryan Curtin
+ *
+ * Rules for range search, so that it can be done with arbitrary tree types.
+ */
+#ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP
+#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * Rules used by a tree traverser to perform range search: for each query
+ * point, find every reference point whose distance to it (as computed by
+ * MetricType) falls inside a given math::Range.  Results are appended to the
+ * neighbors and distances vectors passed to the constructor.
+ *
+ * Every constructor argument is stored by reference, so all of them must
+ * outlive the RangeSearchRules object.
+ *
+ * @tparam MetricType Metric used to evaluate point-to-point distances.
+ * @tparam TreeType Type of tree being traversed; its nodes are asked for their
+ *     RangeDistance() to a query point.
+ */
+template<typename MetricType, typename TreeType>
+class RangeSearchRules
+{
+ public:
+  /**
+   * Construct the rules object.  No computation is performed here.
+   *
+   * @param referenceSet Set of reference points (one point per column).
+   * @param querySet Set of query points (one point per column).  If this is
+   *     the very same object as referenceSet, BaseCase() never reports a point
+   *     as lying in its own range.
+   * @param neighbors Output: neighbors[i] receives the indices of reference
+   *     points found in range of query point i.  It is indexed by query index
+   *     without bounds checking, so it must already hold one entry per query
+   *     point.
+   * @param distances Output: distances[i] receives the distances matching
+   *     neighbors[i]; same size requirement as neighbors.
+   * @param range Range of distances to search for.
+   * @param metric Instantiated metric used to compute distances.
+   */
+  RangeSearchRules(const arma::mat& referenceSet,
+                   const arma::mat& querySet,
+                   std::vector<std::vector<size_t> >& neighbors,
+                   std::vector<std::vector<double> >& distances,
+                   math::Range& range,
+                   MetricType& metric);
+
+  /**
+   * Compute the base case between the given query point and reference point.
+   * If the distance lies within the search range, the reference index and the
+   * distance are appended to the results for the query point.
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceIndex Index of reference point.
+   * @return The distance between the two points, or 0 (with nothing recorded)
+   *     when the query and reference sets are the same object and the two
+   *     indices are equal.
+   */
+  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+
+  /**
+   * Get the score for recursion order.  A low score indicates priority for
+   * recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).  DBL_MAX is also returned when every
+   * point in referenceNode lies within the search range; in that case the
+   * whole node has already been added to the results.
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   * @return 0 if the node must be recursed into, DBL_MAX otherwise.
+   */
+  double Score(const size_t queryIndex, TreeType& referenceNode);
+
+  /**
+   * Get the score for recursion order, passing the base case result (in the
+   * situation where it may be needed to calculate the recursion order).  A low
+   * score indicates priority for recursion, while DBL_MAX indicates that the
+   * node should not be recursed into at all (it should be pruned).  As with
+   * the overload above, DBL_MAX is also returned when the whole node has been
+   * added to the results.
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   * @param baseCaseResult Distance already computed by a BaseCase() call
+   *     between the query point and a point of referenceNode (presumably the
+   *     point held by the node -- TODO confirm against the traverser); it is
+   *     forwarded to TreeType::RangeDistance() as a precomputed distance.
+   * @return 0 if the node must be recursed into, DBL_MAX otherwise.
+   */
+  double Score(const size_t queryIndex,
+               TreeType& referenceNode,
+               const double baseCaseResult);
+
+  /**
+   * Re-evaluate the score for recursion order.  A low score indicates priority
+   * for recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).  This is used when the score has already
+   * been calculated, but another recursion may have modified the bounds for
+   * pruning.  So the old score is checked against the new pruning bound.
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   * @param oldScore Old score produced by Score() (or Rescore()).
+   * @return oldScore, unchanged: these rules never modify the search range, so
+   *     a node that was not pruned before cannot become prunable.
+   */
+  double Rescore(const size_t queryIndex,
+                 TreeType& referenceNode,
+                 const double oldScore);
+
+  /**
+   * Get the score for recursion order.  A low score indicates priority for
+   * recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).
+   *
+   * NOTE(review): no definition of this dual-tree overload exists yet in
+   * range_search_rules_impl.hpp in this revision.
+   *
+   * @param queryNode Candidate query node to recurse into.
+   * @param referenceNode Candidate reference node to recurse into.
+   */
+  double Score(TreeType& queryNode, TreeType& referenceNode);
+
+  /**
+   * Get the score for recursion order, passing the base case result (in the
+   * situation where it may be needed to calculate the recursion order).  A low
+   * score indicates priority for recursion, while DBL_MAX indicates that the
+   * node should not be recursed into at all (it should be pruned).
+   *
+   * NOTE(review): no definition of this dual-tree overload exists yet in
+   * range_search_rules_impl.hpp in this revision.
+   *
+   * @param queryNode Candidate query node to recurse into.
+   * @param referenceNode Candidate reference node to recurse into.
+   * @param baseCaseResult Distance already computed by a BaseCase() call for
+   *     this node pair (which two points it was evaluated between depends on
+   *     the traverser -- TODO confirm).
+   */
+  double Score(TreeType& queryNode,
+               TreeType& referenceNode,
+               const double baseCaseResult);
+
+  /**
+   * Re-evaluate the score for recursion order.  A low score indicates priority
+   * for recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).  This is used when the score has already
+   * been calculated, but another recursion may have modified the bounds for
+   * pruning.  So the old score is checked against the new pruning bound.
+   *
+   * NOTE(review): no definition of this dual-tree overload exists yet in
+   * range_search_rules_impl.hpp in this revision.
+   *
+   * @param queryNode Candidate query node to recurse into.
+   * @param referenceNode Candidate reference node to recurse into.
+   * @param oldScore Old score produced by Score() (or Rescore()).
+   */
+  double Rescore(TreeType& queryNode,
+                 TreeType& referenceNode,
+                 const double oldScore);
+
+ private:
+  //! The reference set (stored by reference).
+  const arma::mat& referenceSet;
+
+  //! The query set (stored by reference).
+  const arma::mat& querySet;
+
+  //! The vector the resultant neighbor indices should be stored in.
+  std::vector<std::vector<size_t> >& neighbors;
+
+  //! The vector the resultant neighbor distances should be stored in.
+  std::vector<std::vector<double> >& distances;
+
+  //! The range of distances for which we are searching.
+  math::Range& range;
+
+  //! The instantiated metric.
+  MetricType& metric;
+
+  //! Add all the points in the given node to the results for the given query
+  //! point.  Called by the single-tree Score() overloads when the whole node
+  //! lies within the search range.  NOTE(review): declared but not yet defined
+  //! in this revision; it must also skip the query point itself when querySet
+  //! and referenceSet are the same object, to match BaseCase().
+  void AddResult(const size_t queryIndex, TreeType& referenceNode);
+};
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+// Include implementation.
+#include "range_search_rules_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp	2013-05-07 15:44:01 UTC (rev 15016)
@@ -0,0 +1,119 @@
+/**
+ * @file range_search_rules_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of rules for range search with generic trees.
+ */
+#ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_IMPL_HPP
+#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "range_search_rules.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * Construct the rules object by binding references to the given datasets,
+ * output containers, search range, and metric.  Nothing is computed here.
+ *
+ * The out-of-class definition of a class-template member must name the
+ * template with its arguments (RangeSearchRules<MetricType, TreeType>::).
+ */
+template<typename MetricType, typename TreeType>
+RangeSearchRules<MetricType, TreeType>::RangeSearchRules(
+    const arma::mat& referenceSet,
+    const arma::mat& querySet,
+    std::vector<std::vector<size_t> >& neighbors,
+    std::vector<std::vector<double> >& distances,
+    math::Range& range,
+    MetricType& metric) :
+    referenceSet(referenceSet),
+    querySet(querySet),
+    neighbors(neighbors),
+    distances(distances),
+    range(range),
+    metric(metric)
+{
+  // Nothing to do.
+}
+
+//! The base case.  Evaluate the distance between the two points and, if it
+//! lies within the search range, append the reference index and the distance
+//! to the results for the query point.  Returns the evaluated distance, or 0
+//! (recording nothing) for the self-pair when the query and reference sets are
+//! the same object.
+template<typename MetricType, typename TreeType>
+double RangeSearchRules<MetricType, TreeType>::BaseCase(
+    const size_t queryIndex,
+    const size_t referenceIndex)
+{
+  // If the datasets are the same, don't return the point as in its own range.
+  if ((&referenceSet == &querySet) && (queryIndex == referenceIndex))
+    return 0.0;
+
+  const double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
+      referenceSet.unsafe_col(referenceIndex));
+
+  if (range.Contains(distance))
+  {
+    // neighbors/distances are indexed without bounds checks; they must hold
+    // one entry per query point.
+    neighbors[queryIndex].push_back(referenceIndex);
+    distances[queryIndex].push_back(distance);
+  }
+
+  return distance;
+}
+
+//! Single-tree scoring function.  Returns DBL_MAX (prune) when the node's
+//! distance range does not overlap the search range; adds the whole node to
+//! the results and returns DBL_MAX when the node lies entirely inside the
+//! search range; otherwise returns 0, since recursion order is irrelevant in
+//! range search.
+template<typename MetricType, typename TreeType>
+double RangeSearchRules<MetricType, TreeType>::Score(const size_t queryIndex,
+                                                     TreeType& referenceNode)
+{
+  // Named distanceRange (not "distances") so that it does not shadow the
+  // member vector holding the result distances.
+  const math::Range distanceRange =
+      referenceNode.RangeDistance(querySet.unsafe_col(queryIndex));
+
+  // If the ranges do not overlap, prune this node.  (Relies on
+  // math::Range::Contains(const Range&) being an overlap test -- confirm.)
+  if (!distanceRange.Contains(range))
+    return DBL_MAX;
+
+  // In this case, all of the points in the reference node will be part of the
+  // results.
+  // NOTE(review): unlike BaseCase(), this path does not itself exclude the
+  // query point when querySet and referenceSet are the same object; AddResult()
+  // has to handle that case.
+  if ((distanceRange.Lo() >= range.Lo()) &&
+      (distanceRange.Hi() <= range.Hi()))
+  {
+    AddResult(queryIndex, referenceNode);
+    return DBL_MAX; // We don't need to go any deeper.
+  }
+
+  // Otherwise the score doesn't matter.  Recursion order is irrelevant in range
+  // search.
+  return 0.0;
+}
+
+//! Single-tree scoring function, given a precomputed base case result that is
+//! forwarded to RangeDistance().  Returns DBL_MAX (prune) when the node's
+//! distance range does not overlap the search range; adds the whole node to
+//! the results and returns DBL_MAX when the node lies entirely inside the
+//! search range; otherwise returns 0, since recursion order is irrelevant in
+//! range search.
+template<typename MetricType, typename TreeType>
+double RangeSearchRules<MetricType, TreeType>::Score(
+    const size_t queryIndex,
+    TreeType& referenceNode,
+    const double baseCaseResult)
+{
+  // Named distanceRange (not "distances") so that it does not shadow the
+  // member vector holding the result distances.
+  const math::Range distanceRange = referenceNode.RangeDistance(
+      querySet.unsafe_col(queryIndex), baseCaseResult);
+
+  // If the ranges do not overlap, prune this node.  (Relies on
+  // math::Range::Contains(const Range&) being an overlap test -- confirm.)
+  if (!distanceRange.Contains(range))
+    return DBL_MAX;
+
+  // In this case, all of the points in the reference node will be part of the
+  // results.
+  // NOTE(review): unlike BaseCase(), this path does not itself exclude the
+  // query point when querySet and referenceSet are the same object; AddResult()
+  // has to handle that case.
+  if ((distanceRange.Lo() >= range.Lo()) &&
+      (distanceRange.Hi() <= range.Hi()))
+  {
+    AddResult(queryIndex, referenceNode);
+    return DBL_MAX; // We don't need to go any deeper.
+  }
+
+  // Otherwise the score doesn't matter.  Recursion order is irrelevant in range
+  // search.
+  return 0.0;
+}
+
+//! Single-tree rescoring function.  These rules never modify the search range
+//! during a traversal, so a node that survived Score() cannot have become
+//! prunable since; the previous score is handed back untouched.
+template<typename MetricType, typename TreeType>
+double RangeSearchRules<MetricType, TreeType>::Rescore(
+    const size_t /* queryIndex */,
+    TreeType& /* referenceNode */,
+    const double oldScore)
+{
+  return oldScore;
+}
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif




More information about the mlpack-svn mailing list