[mlpack-svn] r15016 - mlpack/trunk/src/mlpack/methods/range_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue May 7 11:44:01 EDT 2013
Author: rcurtin
Date: 2013-05-07 11:44:01 -0400 (Tue, 07 May 2013)
New Revision: 15016
Added:
mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp
mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp
Log:
Incremental checkin so I can work from another system. Begins outline of the
RangeSearchRules class.
Added: mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp 2013-05-07 15:44:01 UTC (rev 15016)
@@ -0,0 +1,140 @@
+/**
+ * @file range_search_rules.hpp
+ * @author Ryan Curtin
+ *
+ * Rules for range search, so that it can be done with arbitrary tree types.
+ */
+#ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP
+#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP
+
+namespace mlpack {
+namespace neighbor {
+
+template<typename MetricType, typename TreeType>
+class RangeSearchRules
+{
+ public:
+ RangeSearchRules(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ std::vector<std::vector<size_t> >& neighbors,
+ std::vector<std::vector<double> >& distances,
+ math::Range& range,
+ MetricType& metric);
+
+ /**
+ * Compute the base case between the given query point and reference point.
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceIndex Index of reference point.
+ */
+ double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ */
+ double Score(const size_t queryIndex, TreeType& referenceNode);
+
+ /**
+ * Get the score for recursion order, passing the base case result (in the
+ * situation where it may be needed to calculate the recursion order). A low
+ * score indicates priority for recursion, while DBL_MAX indicates that the
+ * node should not be recursed into at all (it should be pruned).
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param baseCaseResult Result of BaseCase(); passed to RangeDistance().
+ */
+ double Score(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult);
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned). This is used when the score has already
+ * been calculated. The implementation returns oldScore unchanged: a node
+ * that was not pruned by Score() is not pruned by Rescore() either.
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param oldScore Old score produced by Score() (or Rescore()).
+ */
+ double Rescore(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double oldScore);
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ */
+ double Score(TreeType& queryNode, TreeType& referenceNode);
+
+ /**
+ * Get the score for recursion order, passing the base case result (in the
+ * situation where it may be needed to calculate the recursion order). A low
+ * score indicates priority for recursion, while DBL_MAX indicates that the
+ * node should not be recursed into at all (it should be pruned).
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ * @param baseCaseResult Result of BaseCase() (for which points: TODO confirm).
+ */
+ double Score(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double baseCaseResult);
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned). This is used when the score has already
+ * been calculated, but another recursion may have modified the bounds for
+ * pruning. So the old score is checked against the new pruning bound.
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ * @param oldScore Old score produced by Score() (or Rescore()).
+ */
+ double Rescore(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double oldScore);
+
+ private:
+ //! The reference set.
+ const arma::mat& referenceSet;
+
+ //! The query set.
+ const arma::mat& querySet;
+
+ //! The vector the resultant neighbor indices should be stored in.
+ std::vector<std::vector<size_t> >& neighbors;
+
+ //! The vector the resultant neighbor distances should be stored in.
+ std::vector<std::vector<double> >& distances;
+
+ //! The range of distances for which we are searching.
+ math::Range& range;
+
+ //! The instantiated metric.
+ MetricType& metric;
+
+ //! Add all the points in the given node to the results for the given query
+ //! point.
+ void AddResult(const size_t queryIndex, TreeType& referenceNode);
+};
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+// Include implementation.
+#include "range_search_rules_impl.hpp"
+
+#endif
Added: mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp 2013-05-07 15:44:01 UTC (rev 15016)
@@ -0,0 +1,119 @@
+/**
+ * @file range_search_rules_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of rules for range search with generic trees.
+ */
+#ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_IMPL_HPP
+#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "range_search_rules.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+//! Bind the datasets, the result vectors, the search range and the metric;
+//! all of them are held by reference, so nothing is copied here.
+template<typename MetricType, typename TreeType>
+RangeSearchRules<MetricType, TreeType>::RangeSearchRules(
+    const arma::mat& referenceSet,
+    const arma::mat& querySet,
+    std::vector<std::vector<size_t> >& neighbors,
+    std::vector<std::vector<double> >& distances,
+    math::Range& range,
+    MetricType& metric) :
+    referenceSet(referenceSet), querySet(querySet),
+    neighbors(neighbors), distances(distances),
+    range(range), metric(metric)
+{
+  // Nothing to do.
+}
+
+//! The base case.  Evaluate the distance between the two points and, if it
+//! lies in the search range, record the reference point for the query point.
+template<typename MetricType, typename TreeType>
+double RangeSearchRules<MetricType, TreeType>::BaseCase(
+    const size_t queryIndex, const size_t referenceIndex)
+{
+  // If the datasets are the same, don't return the point as in its own range.
+  if ((&referenceSet == &querySet) && (queryIndex == referenceIndex))
+    return 0.0;
+
+  const double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
+      referenceSet.unsafe_col(referenceIndex));
+
+  if (range.Contains(distance))
+  {
+    neighbors[queryIndex].push_back(referenceIndex);
+    distances[queryIndex].push_back(distance);
+  }
+
+  return distance;
+}
+
+//! Single-tree scoring function.  Returns DBL_MAX if the node need not be
+//! recursed into, and 0.0 otherwise.
+template<typename MetricType, typename TreeType>
+double RangeSearchRules<MetricType, TreeType>::Score(const size_t queryIndex,
+                                                     TreeType& referenceNode)
+{
+  // Named so that it does not shadow the 'distances' result-vector member.
+  const math::Range nodeDistances =
+      referenceNode.RangeDistance(querySet.unsafe_col(queryIndex));
+
+  // If the ranges do not overlap, prune this node.
+  if (!nodeDistances.Contains(range))
+    return DBL_MAX;
+
+  // The whole node is in range: add all of its points to the results.
+  if ((nodeDistances.Lo() >= range.Lo()) && (nodeDistances.Hi() <= range.Hi()))
+  {
+    AddResult(queryIndex, referenceNode);
+    return DBL_MAX; // We don't need to go any deeper.
+  }
+
+  // Otherwise the score doesn't matter: recursion order is irrelevant here.
+  return 0.0;
+}
+
+//! Single-tree scoring function, given the base case result.  Returns DBL_MAX
+//! if the node need not be recursed into, and 0.0 otherwise.
+template<typename MetricType, typename TreeType>
+double RangeSearchRules<MetricType, TreeType>::Score(
+    const size_t queryIndex,
+    TreeType& referenceNode,
+    const double baseCaseResult)
+{
+  // Named so that it does not shadow the 'distances' result-vector member.
+  const math::Range nodeDistances = referenceNode.RangeDistance(
+      querySet.unsafe_col(queryIndex), baseCaseResult);
+
+  // If the ranges do not overlap, prune this node.
+  if (!nodeDistances.Contains(range))
+    return DBL_MAX;
+
+  // The whole node is in range: add all of its points to the results.
+  if ((nodeDistances.Lo() >= range.Lo()) && (nodeDistances.Hi() <= range.Hi()))
+  {
+    AddResult(queryIndex, referenceNode);
+    return DBL_MAX; // We don't need to go any deeper.
+  }
+
+  // Otherwise the score doesn't matter: recursion order is irrelevant here.
+  return 0.0;
+}
+
+//! Single-tree rescoring function.  The old score is returned unchanged.
+template<typename MetricType, typename TreeType>
+double RangeSearchRules<MetricType, TreeType>::Rescore(
+    const size_t /* queryIndex */, TreeType& /* referenceNode */,
+    const double oldScore)
+{
+  // If it wasn't pruned before, it isn't pruned now.
+  return oldScore;
+}
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif
More information about the mlpack-svn
mailing list