[mlpack-svn] r10520 - in mlpack/trunk/src/mlpack/methods: . range_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Dec 2 21:17:32 EST 2011
Author: rcurtin
Date: 2011-12-02 21:17:32 -0500 (Fri, 02 Dec 2011)
New Revision: 10520
Added:
mlpack/trunk/src/mlpack/methods/range_search/
mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt
mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
Log:
Add RangeSearch to the list of methods we have.
Added: mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt (rev 0)
+++ mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt 2011-12-03 02:17:32 UTC (rev 10520)
@@ -0,0 +1,17 @@
+cmake_minimum_required(VERSION 2.8)
+
+# Source files belonging to the range search method.  Only the files listed
+# here are compiled into MLPACK.
+set(SOURCES
+  range_search.hpp
+  range_search_impl.hpp
+)
+
+# Build DIR_SRCS: every entry of SOURCES prefixed with this directory's path.
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+  list(APPEND DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+
+# Hand the prefixed sources up to the parent scope by appending them to the
+# global MLPACK_SRCS list.
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
Added: mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp 2011-12-03 02:17:32 UTC (rev 10520)
@@ -0,0 +1,276 @@
+/**
+ * @file range_search.hpp
+ * @author Ryan Curtin
+ *
+ * Defines the RangeSearch class, which performs a generalized range search on
+ * points.
+ */
+#ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_HPP
+#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/bounds.hpp>
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
+#include <mlpack/core/metrics/lmetric.hpp>
+
+namespace mlpack {
+namespace range /** Range-search routines. */ {
+
+/**
+ * The RangeSearch class is a template class for performing range searches:
+ * for each query point it finds every reference point whose distance (as
+ * evaluated by MetricType) falls inside a given math::Range.  The search can
+ * be run naively, with a single tree, or with two trees (dual-tree).
+ *
+ * NOTE(review): this class relies on the implicitly generated copy
+ * constructor.  A copy shares the tree pointers and ownership flags of the
+ * original (so an object that built its own trees would have them deleted
+ * twice), and its referenceSet/querySet references keep pointing at the
+ * original object's internal copies.  Avoid copying objects that built their
+ * own trees.
+ *
+ * @tparam MetricType Metric used to evaluate point-to-point distances.
+ * @tparam TreeType Type of tree used to organize the points.
+ */
+template<typename MetricType = mlpack::metric::SquaredEuclideanDistance,
+         typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
+                                                   tree::EmptyStatistic> >
+class RangeSearch
+{
+ public:
+  /**
+   * Initialize the RangeSearch object with separate reference and query sets.
+   * Optionally, perform the computation in naive mode or single-tree mode,
+   * and set the leaf size used for tree-building.  Additionally, an
+   * instantiated metric can be given, for cases where the distance metric
+   * holds data.
+   *
+   * This method will copy the matrices to internal copies, which are
+   * rearranged during tree-building.  You can avoid this extra copy by
+   * pre-constructing the trees and passing them using a different
+   * constructor.
+   *
+   * @param referenceSet Reference dataset.
+   * @param querySet Query dataset.
+   * @param naive Whether the computation should be done in O(n^2) naive mode.
+   * @param singleMode Whether single-tree computation should be used (as
+   *     opposed to dual-tree computation).
+   * @param leafSize The leaf size to be used during tree construction.
+   * @param metric Instantiated distance metric.
+   */
+  RangeSearch(const arma::mat& referenceSet,
+              const arma::mat& querySet,
+              const bool naive = false,
+              const bool singleMode = false,
+              const size_t leafSize = 20,
+              const MetricType metric = MetricType());
+
+  /**
+   * Initialize the RangeSearch object with only a reference set, which will
+   * also be used as a query set.  Optionally, perform the computation in
+   * naive mode or single-tree mode, and set the leaf size used for
+   * tree-building.  Additionally an instantiated metric can be given, for
+   * cases where the distance metric holds data.
+   *
+   * This method will copy the reference matrix to an internal copy, which is
+   * rearranged during tree-building.  You can avoid this extra copy by
+   * pre-constructing the reference tree and passing it using a different
+   * constructor.
+   *
+   * @param referenceSet Reference dataset.
+   * @param naive Whether the computation should be done in O(n^2) naive mode.
+   * @param singleMode Whether single-tree computation should be used (as
+   *     opposed to dual-tree computation).
+   * @param leafSize The leaf size to be used during tree construction.
+   * @param metric Instantiated distance metric.
+   */
+  RangeSearch(const arma::mat& referenceSet,
+              const bool naive = false,
+              const bool singleMode = false,
+              const size_t leafSize = 20,
+              const MetricType metric = MetricType());
+
+  /**
+   * Initialize the RangeSearch object with the given datasets and
+   * pre-constructed trees.  It is assumed that the points in referenceSet and
+   * querySet correspond to the points in referenceTree and queryTree,
+   * respectively.  Optionally, choose to use single-tree mode.  Naive mode is
+   * not available as an option for this constructor; instead, to run naive
+   * computation, construct a tree with all the points in one leaf (i.e.
+   * leafSize = number of points).  Additionally, an instantiated distance
+   * metric can be given, for cases where the distance metric holds data.
+   *
+   * There is no copying of the data matrices in this constructor (because
+   * tree-building is not necessary), so this is the constructor to use when
+   * copies absolutely must be avoided.  The trees are not taken over: the
+   * caller keeps ownership of them.
+   *
+   * @note
+   * Because tree-building (at least with BinarySpaceTree) modifies the
+   * ordering of a matrix, be sure you pass the modified matrix to this
+   * object!  In addition, mapping the points of the matrix back to their
+   * original indices is not done when this constructor is used.
+   *
+   * @param referenceTree Pre-built tree for reference points.
+   * @param queryTree Pre-built tree for query points.
+   * @param referenceSet Set of reference points corresponding to
+   *     referenceTree.
+   * @param querySet Set of query points corresponding to queryTree.
+   * @param singleMode Whether single-tree computation should be used (as
+   *     opposed to dual-tree computation).
+   * @param metric Instantiated distance metric.
+   */
+  RangeSearch(const TreeType* referenceTree,
+              const TreeType* queryTree,
+              const arma::mat& referenceSet,
+              const arma::mat& querySet,
+              const bool singleMode = false,
+              const MetricType metric = MetricType());
+
+  /**
+   * Initialize the RangeSearch object with the given reference dataset and
+   * pre-constructed tree.  It is assumed that the points in referenceSet
+   * correspond to the points in referenceTree.  Optionally, choose to use
+   * single-tree mode.  Naive mode is not available as an option for this
+   * constructor; instead, to run naive computation, construct a tree with all
+   * the points in one leaf (i.e. leafSize = number of points).  Additionally,
+   * an instantiated distance metric can be given, for the case where the
+   * distance metric holds data.
+   *
+   * There is no copying of the data matrices in this constructor (because
+   * tree-building is not necessary), so this is the constructor to use when
+   * copies absolutely must be avoided.  The tree is not taken over: the
+   * caller keeps ownership of it.
+   *
+   * @note
+   * Because tree-building (at least with BinarySpaceTree) modifies the
+   * ordering of a matrix, be sure you pass the modified matrix to this
+   * object!  In addition, mapping the points of the matrix back to their
+   * original indices is not done when this constructor is used.
+   *
+   * @param referenceTree Pre-built tree for reference points.
+   * @param referenceSet Set of reference points corresponding to
+   *     referenceTree.
+   * @param singleMode Whether single-tree computation should be used (as
+   *     opposed to dual-tree computation).
+   * @param metric Instantiated distance metric.
+   */
+  RangeSearch(const TreeType* referenceTree,
+              const arma::mat& referenceSet,
+              const bool singleMode = false,
+              const MetricType metric = MetricType());
+
+  /**
+   * Destroy the RangeSearch object.  If trees were created by this object,
+   * they will be deleted; trees passed in by the caller are left alone.
+   */
+  ~RangeSearch();
+
+  /**
+   * Search for all points in the given range, returning the results in the
+   * neighbors and distances objects.  Each entry in the external vector
+   * corresponds to a query point.  Each of these entries holds a vector which
+   * contains the indices and distances of the reference points falling into
+   * the given range.
+   *
+   * That is:
+   *
+   * - neighbors.size() and distances.size() both equal the number of query
+   *   points.
+   *
+   * - neighbors[i] contains the indices of all the points in the reference
+   *   set which have distances inside the given range to query point i.
+   *
+   * - distances[i] contains all of the distances corresponding to the indices
+   *   contained in neighbors[i].
+   *
+   * - neighbors[i] and distances[i] are not sorted in any particular order.
+   *
+   * @param range Range of distances in which to search.
+   * @param neighbors Object which will hold the list of neighbors for each
+   *     point which fell into the given range, for each query point.
+   * @param distances Object which will hold the list of distances for each
+   *     point which fell into the given range, for each query point.
+   */
+  void Search(const math::Range& range,
+              std::vector<std::vector<size_t> >& neighbors,
+              std::vector<std::vector<double> >& distances);
+
+ private:
+  /**
+   * Compute the base case, when both referenceNode and queryNode are leaves
+   * containing points.
+   *
+   * @param referenceNode Reference node (must be a leaf).
+   * @param queryNode Query node (must be a leaf).
+   * @param range Range of distances to search for.
+   * @param neighbors Object holding list of neighbors.
+   * @param distances Object holding list of distances.
+   */
+  void ComputeBaseCase(const TreeType* referenceNode,
+                       const TreeType* queryNode,
+                       const math::Range& range,
+                       std::vector<std::vector<size_t> >& neighbors,
+                       std::vector<std::vector<double> >& distances) const;
+
+  /**
+   * Perform the dual-tree recursion, which will recurse until the base case
+   * is necessary.
+   *
+   * @param referenceNode Reference node.
+   * @param queryNode Query node.
+   * @param range Range of distances to search for.
+   * @param neighbors Object holding list of neighbors.
+   * @param distances Object holding list of distances.
+   */
+  void DualTreeRecursion(const TreeType* referenceNode,
+                         const TreeType* queryNode,
+                         const math::Range& range,
+                         std::vector<std::vector<size_t> >& neighbors,
+                         std::vector<std::vector<double> >& distances);
+
+  /**
+   * Perform the single-tree recursion, which will recurse down the reference
+   * tree to get the results for a single point.
+   *
+   * @param referenceNode Reference node.
+   * @param queryPoint Point to query for.
+   * @param range Range of distances to search for.
+   * @param neighbors Object holding list of neighbors.
+   * @param distances Object holding list of distances.
+   */
+  void SingleTreeRecursion(const TreeType* referenceNode,
+                           const arma::vec& queryPoint,
+                           const math::Range& range,
+                           std::vector<size_t>& neighbors,
+                           std::vector<double>& distances);
+
+  //! Copy of reference matrix; used when a tree is built internally.
+  arma::mat referenceCopy;
+  //! Copy of query matrix; used when a tree is built internally.
+  arma::mat queryCopy;
+
+  //! Reference set (data should be accessed using this).
+  const arma::mat& referenceSet;
+  //! Query set (data should be accessed using this).
+  const arma::mat& querySet;
+
+  //! Reference tree.  Pointer-to-const: the search never modifies the tree,
+  //! and the pre-built-tree constructors receive it as a const pointer.
+  const TreeType* referenceTree;
+  //! Query tree (may be NULL, in which case the reference tree is used).
+  const TreeType* queryTree;
+
+  //! Mappings to old reference indices (used when this object builds trees).
+  std::vector<size_t> oldFromNewReferences;
+  //! Mappings to old query indices (used when this object builds trees).
+  std::vector<size_t> oldFromNewQueries;
+
+  //! Indicates ownership of the reference tree (meaning we need to delete it).
+  bool ownReferenceTree;
+  //! Indicates ownership of the query tree (meaning we need to delete it).
+  bool ownQueryTree;
+
+  //! If true, O(n^2) naive computation is used.
+  bool naive;
+  //! If true, single-tree computation is used.
+  bool singleMode;
+
+  //! Instantiated distance metric.
+  MetricType metric;
+
+  //! The number of pruned nodes during computation.
+  size_t numberOfPrunes;
+};
+
+}; // namespace range
+}; // namespace mlpack
+
+// Include implementation.
+#include "range_search_impl.hpp"
+
+#endif
Added: mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp 2011-12-03 02:17:32 UTC (rev 10520)
@@ -0,0 +1,452 @@
+/**
+ * @file range_search_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the RangeSearch class.
+ */
+#ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_IMPL_HPP
+#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_IMPL_HPP
+
+// Just in case it hasn't been included.
+#include "range_search.hpp"
+
+namespace mlpack {
+namespace range {
+
+/**
+ * Bichromatic constructor: copies both datasets and builds a tree over each
+ * copy.  Naive mode forces single-tree mode off.
+ */
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::RangeSearch(
+    const arma::mat& referenceSet,
+    const arma::mat& querySet,
+    const bool naive,
+    const bool singleMode,
+    const size_t leafSize,
+    const MetricType metric) :
+    referenceCopy(referenceSet),
+    queryCopy(querySet),
+    referenceSet(referenceCopy),
+    querySet(queryCopy),
+    ownReferenceTree(true),
+    ownQueryTree(true),
+    naive(naive),
+    singleMode(!naive && singleMode), // Naive overrides single mode.
+    metric(metric),
+    numberOfPrunes(0)
+{
+  Timers::StartTimer("range_search/tree_building");
+
+  // In naive mode each tree is a single leaf holding every point.
+  const size_t referenceLeafSize = (naive ? referenceCopy.n_cols : leafSize);
+  const size_t queryLeafSize = (naive ? queryCopy.n_cols : leafSize);
+
+  // The trees are built over (and rearrange) this object's private copies;
+  // the permutations are recorded so results can be mapped back later.
+  referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+                               referenceLeafSize);
+  queryTree = new TreeType(queryCopy, oldFromNewQueries, queryLeafSize);
+
+  Timers::StopTimer("range_search/tree_building");
+}
+
+/**
+ * Monochromatic constructor: copies the reference set, builds one tree over
+ * the copy, and uses the same (rearranged) points as queries.
+ */
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::RangeSearch(
+    const arma::mat& referenceSet,
+    const bool naive,
+    const bool singleMode,
+    const size_t leafSize,
+    const MetricType metric) :
+    referenceCopy(referenceSet),
+    referenceSet(referenceCopy),
+    querySet(referenceCopy),
+    queryTree(NULL),
+    ownReferenceTree(true),
+    ownQueryTree(false),
+    naive(naive),
+    singleMode(!naive && singleMode), // Naive overrides single mode.
+    metric(metric),
+    numberOfPrunes(0)
+{
+  Timers::StartTimer("range_search/tree_building");
+
+  // In naive mode the tree is a single leaf holding every point.
+  const size_t treeLeafSize = (naive ? referenceCopy.n_cols : leafSize);
+  referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+                               treeLeafSize);
+
+  Timers::StopTimer("range_search/tree_building");
+}
+
+/**
+ * Constructor taking caller-built reference and query trees.  Nothing is
+ * copied or built, and the trees remain owned by the caller.
+ */
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::RangeSearch(
+    const TreeType* referenceTree,
+    const TreeType* queryTree,
+    const arma::mat& referenceSet,
+    const arma::mat& querySet,
+    const bool singleMode,
+    const MetricType metric) :
+    referenceSet(referenceSet),    // Caller's (already rearranged) data.
+    querySet(querySet),
+    referenceTree(referenceTree),  // Borrowed, not owned.
+    queryTree(queryTree),
+    ownReferenceTree(false),
+    ownQueryTree(false),
+    naive(false),                  // Not offered with pre-built trees.
+    singleMode(singleMode),
+    metric(metric),
+    numberOfPrunes(0)
+{
+  // All state is set up by the initializer list above.
+}
+
+/**
+ * Constructor taking a caller-built reference tree; the reference points also
+ * serve as queries.  Nothing is copied or built, and the tree remains owned by
+ * the caller.
+ */
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::RangeSearch(
+    const TreeType* referenceTree,
+    const arma::mat& referenceSet,
+    const bool singleMode,
+    const MetricType metric) :
+    referenceSet(referenceSet),    // Caller's (already rearranged) data.
+    querySet(referenceSet),        // Queries are the reference points.
+    referenceTree(referenceTree),  // Borrowed, not owned.
+    queryTree(NULL),               // No separate query tree.
+    ownReferenceTree(false),
+    ownQueryTree(false),
+    naive(false),                  // Not offered with pre-built trees.
+    singleMode(singleMode),
+    metric(metric),
+    numberOfPrunes(0)
+{
+  // All state is set up by the initializer list above.
+}
+
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::~RangeSearch()
+{
+  // Free only the trees this object built itself; trees supplied by the
+  // caller stay the caller's responsibility.
+  if (ownReferenceTree)
+  {
+    delete referenceTree;
+  }
+
+  if (ownQueryTree)
+  {
+    delete queryTree;
+  }
+}
+
+/**
+ * Run the range search (naive, single-tree or dual-tree, as configured) and
+ * return the results indexed by the ORIGINAL query/reference indices when this
+ * object rearranged the data while building its own trees.
+ */
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::Search(
+    const math::Range& range,
+    std::vector<std::vector<size_t> >& neighbors,
+    std::vector<std::vector<double> >& distances)
+{
+  Timers::StartTimer("range_search/computing_neighbors");
+
+  // Set size of prunes to 0.
+  numberOfPrunes = 0;
+
+  // Building a tree rearranges its data matrix, so results computed in tree
+  // order may need mapping back to the original indices:
+  //  - query order changed if we built the query tree, or if the queries are
+  //    the (rearranged) reference points themselves;
+  //  - reference indices changed if we built the reference tree.
+  const bool queriesPermuted = ownQueryTree || (ownReferenceTree && !queryTree);
+  const bool referencesPermuted = ownReferenceTree;
+
+  // Results that need remapping go to local buffers first; otherwise they are
+  // written straight into the caller's objects.  Locals (rather than
+  // heap-allocated vectors) cannot leak if the computation throws.
+  std::vector<std::vector<size_t> > permutedNeighbors;
+  std::vector<std::vector<double> > permutedDistances;
+  std::vector<std::vector<size_t> >& neighborResults =
+      (queriesPermuted || referencesPermuted) ? permutedNeighbors : neighbors;
+  std::vector<std::vector<double> >& distanceResults =
+      queriesPermuted ? permutedDistances : distances;
+
+  // One (initially empty) result list per query point.
+  neighborResults.clear();
+  neighborResults.resize(querySet.n_cols);
+  distanceResults.clear();
+  distanceResults.resize(querySet.n_cols);
+
+  // Without a separate query tree, the reference tree doubles as query tree.
+  const TreeType* queryRoot = (queryTree ? queryTree : referenceTree);
+
+  if (naive)
+  {
+    // Naive trees are a single leaf, so one base case covers every pair.
+    ComputeBaseCase(referenceTree, queryRoot, range, neighborResults,
+        distanceResults);
+  }
+  else if (singleMode)
+  {
+    // Descend the reference tree once per query point.
+    for (size_t i = 0; i < querySet.n_cols; ++i)
+    {
+      SingleTreeRecursion(referenceTree, querySet.unsafe_col(i), range,
+          neighborResults[i], distanceResults[i]);
+    }
+  }
+  else
+  {
+    DualTreeRecursion(referenceTree, queryRoot, range, neighborResults,
+        distanceResults);
+  }
+
+  Timers::StopTimer("range_search/computing_neighbors");
+
+  // Output number of prunes.
+  Log::Info << "Number of pruned nodes during computation: " << numberOfPrunes
+      << "." << std::endl;
+
+  // Results were written directly into the caller's objects; nothing to map.
+  if (!queriesPermuted && !referencesPermuted)
+    return;
+
+  // Maps a tree-order query index to its original index.  Without a query
+  // tree the queries are the reference points, so the reference mapping is
+  // the one that applies.
+  const std::vector<size_t>& queryMapping =
+      (queryTree ? oldFromNewQueries : oldFromNewReferences);
+
+  neighbors.clear();
+  neighbors.resize(querySet.n_cols);
+  if (queriesPermuted)
+  {
+    distances.clear();
+    distances.resize(querySet.n_cols);
+  }
+
+  for (size_t i = 0; i < querySet.n_cols; ++i)
+  {
+    const size_t originalQuery = (queriesPermuted ? queryMapping[i] : i);
+
+    // The local buffers are discarded afterwards, so swap instead of copying.
+    if (queriesPermuted)
+      distances[originalQuery].swap(distanceResults[i]);
+
+    std::vector<size_t>& mapped = neighbors[originalQuery];
+    if (referencesPermuted)
+    {
+      // Translate each neighbor index back to its original reference index.
+      const std::vector<size_t>& found = neighborResults[i];
+      mapped.resize(found.size());
+      for (size_t j = 0; j < found.size(); ++j)
+        mapped[j] = oldFromNewReferences[found[j]];
+    }
+    else
+    {
+      mapped.swap(neighborResults[i]);
+    }
+  }
+}
+
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::ComputeBaseCase(
+    const TreeType* referenceNode,
+    const TreeType* queryNode,
+    const math::Range& range,
+    std::vector<std::vector<size_t> >& neighbors,
+    std::vector<std::vector<double> >& distances) const
+{
+  // If both nodes are the same node, a point must not be reported as its own
+  // neighbor.
+  const bool sameNode = (referenceNode == queryNode);
+
+  // A node's points occupy columns [Begin(), End()) of the data matrix.
+  for (size_t q = queryNode->Begin(); q < queryNode->End(); ++q)
+  {
+    arma::vec queryPoint = querySet.unsafe_col(q);
+
+    // Skip this query point if no point of the reference node can be in range.
+    const double lo = referenceNode->Bound().MinDistance(queryPoint);
+    const double hi = referenceNode->Bound().MaxDistance(queryPoint);
+    if (!range.Contains(math::Range(lo, hi)))
+      continue;
+
+    for (size_t r = referenceNode->Begin(); r < referenceNode->End(); ++r)
+    {
+      if (sameNode && r == q)
+        continue;
+
+      arma::vec referencePoint = referenceSet.unsafe_col(r);
+      const double distance = metric.Evaluate(queryPoint, referencePoint);
+      if (!range.Contains(distance))
+        continue;
+
+      neighbors[q].push_back(r);
+      distances[q].push_back(distance);
+    }
+  }
+}
+
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::DualTreeRecursion(
+    const TreeType* referenceNode,
+    const TreeType* queryNode,
+    const math::Range& range,
+    std::vector<std::vector<size_t> >& neighbors,
+    std::vector<std::vector<double> >& distances)
+{
+  // Prune when the possible node-to-node distances miss the search range.
+  const math::Range nodeDistances =
+      referenceNode->Bound().RangeDistance(queryNode->Bound());
+  if (!range.Contains(nodeDistances))
+  {
+    ++numberOfPrunes;
+    return;
+  }
+
+  const bool splitReference = !referenceNode->IsLeaf();
+  const bool splitQuery = !queryNode->IsLeaf();
+
+  if (!splitReference && !splitQuery)
+  {
+    // Two leaves: evaluate the point pairs directly.
+    ComputeBaseCase(referenceNode, queryNode, range, neighbors, distances);
+    return;
+  }
+
+  // Each non-leaf node is replaced by its two children; a leaf stands in for
+  // itself.  Reference children form the outer loop, left before right.
+  const TreeType* referenceChildren[2] = { referenceNode, NULL };
+  const TreeType* queryChildren[2] = { queryNode, NULL };
+  size_t numReferenceChildren = 1;
+  size_t numQueryChildren = 1;
+
+  if (splitReference)
+  {
+    referenceChildren[0] = referenceNode->Left();
+    referenceChildren[1] = referenceNode->Right();
+    numReferenceChildren = 2;
+  }
+
+  if (splitQuery)
+  {
+    queryChildren[0] = queryNode->Left();
+    queryChildren[1] = queryNode->Right();
+    numQueryChildren = 2;
+  }
+
+  for (size_t r = 0; r < numReferenceChildren; ++r)
+  {
+    for (size_t q = 0; q < numQueryChildren; ++q)
+    {
+      DualTreeRecursion(referenceChildren[r], queryChildren[q], range,
+          neighbors, distances);
+    }
+  }
+}
+
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::SingleTreeRecursion(
+    const TreeType* referenceNode,
+    const arma::vec& queryPoint,
+    const math::Range& range,
+    std::vector<size_t>& neighbors,
+    std::vector<double>& distances)
+{
+  if (!referenceNode->IsLeaf())
+  {
+    // Compute the possible distance range to both children first, then visit
+    // (left before right) only those whose range meets the search range.
+    const TreeType* children[2] = { referenceNode->Left(),
+                                    referenceNode->Right() };
+    const math::Range childDistances[2] = {
+        children[0]->Bound().RangeDistance(queryPoint),
+        children[1]->Bound().RangeDistance(queryPoint) };
+
+    for (size_t c = 0; c < 2; ++c)
+    {
+      if (range.Contains(childDistances[c]))
+        SingleTreeRecursion(children[c], queryPoint, range, neighbors,
+            distances);
+      else
+        ++numberOfPrunes;
+    }
+    return;
+  }
+
+  // Leaf: test every reference point.  A reference column stored at the same
+  // address as the query point is the query point itself, so it is skipped.
+  const double* queryMemory = queryPoint.memptr();
+  for (size_t r = referenceNode->Begin(); r != referenceNode->End(); ++r)
+  {
+    if (referenceSet.colptr(r) == queryMemory)
+      continue;
+
+    arma::vec referencePoint = referenceSet.unsafe_col(r);
+    const double distance = metric.Evaluate(queryPoint, referencePoint);
+    if (range.Contains(distance))
+    {
+      neighbors.push_back(r);
+      distances.push_back(distance);
+    }
+  }
+}
+
+}; // namespace range
+}; // namespace mlpack
+
+#endif
More information about the mlpack-svn
mailing list