[mlpack-svn] r10520 - in mlpack/trunk/src/mlpack/methods: . range_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Dec 2 21:17:32 EST 2011


Author: rcurtin
Date: 2011-12-02 21:17:32 -0500 (Fri, 02 Dec 2011)
New Revision: 10520

Added:
   mlpack/trunk/src/mlpack/methods/range_search/
   mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
   mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
Log:
Add RangeSearch to the list of methods we have.


Added: mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/range_search/CMakeLists.txt	2011-12-03 02:17:32 UTC (rev 10520)
@@ -0,0 +1,17 @@
+cmake_minimum_required(VERSION 2.8)
+
+# Source files belonging to the range search method.  Only files listed here
+# end up compiled into MLPACK.
+set(SOURCES
+  range_search.hpp
+  range_search_impl.hpp
+)
+
+# Build a copy of the list where every file is prefixed with this directory.
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+  list(APPEND DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+
+# Export the prefixed files to the parent scope, which collects the complete
+# list of MLPACK sources.
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)

Added: mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp	2011-12-03 02:17:32 UTC (rev 10520)
@@ -0,0 +1,276 @@
+/**
+ * @file range_search.hpp
+ * @author Ryan Curtin
+ *
+ * Defines the RangeSearch class, which performs a generalized range search on
+ * points.
+ */
+#ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_HPP
+#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/bounds.hpp>
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
+#include <mlpack/core/metrics/lmetric.hpp>
+
+namespace mlpack {
+namespace range /** Range-search routines. */ {
+
+/**
+ * The RangeSearch class is a template class for performing range searches:
+ * for each query point it finds every reference point whose distance, as
+ * computed by MetricType, lies inside a given math::Range.  The search can run
+ * naively (all pairs), with a single tree (one reference-tree traversal per
+ * query point), or with two trees (dual-tree, the default).
+ *
+ * NOTE(review): the range is compared directly against MetricType::Evaluate()
+ * and against the distance functions of the tree bounds.  With the default
+ * SquaredEuclideanDistance metric the range bounds are therefore presumably
+ * squared distances -- confirm the bound type uses the same units.
+ *
+ * @tparam MetricType Metric used to compute point-to-point distances.
+ * @tparam TreeType Tree type used to accelerate the search.
+ */
+template<typename MetricType = mlpack::metric::SquaredEuclideanDistance,
+         typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
+                                                   tree::EmptyStatistic> >
+class RangeSearch
+{
+ public:
+  /**
+   * Initialize the RangeSearch object with a reference set and a separate
+   * query set.  Optionally, perform the computation in naive mode or
+   * single-tree mode, and set the leaf size used for tree-building.
+   * Additionally, an instantiated metric can be given, for cases where the
+   * distance metric holds data.  If both naive and singleMode are true, naive
+   * mode is used.
+   *
+   * This method will copy the matrices to internal copies, which are rearranged
+   * during tree-building.  You can avoid this extra copy by pre-constructing
+   * the trees and passing them using a different constructor.
+   *
+   * @param referenceSet Reference dataset.
+   * @param querySet Query dataset.
+   * @param naive Whether the computation should be done in O(n^2) naive mode.
+   * @param singleMode Whether single-tree computation should be used (as
+   *      opposed to dual-tree computation).
+   * @param leafSize The leaf size to be used during tree construction.
+   * @param metric Instantiated distance metric.
+   */
+  RangeSearch(const arma::mat& referenceSet,
+              const arma::mat& querySet,
+              const bool naive = false,
+              const bool singleMode = false,
+              const size_t leafSize = 20,
+              const MetricType metric = MetricType());
+
+  /**
+   * Initialize the RangeSearch object with only a reference set, which will
+   * also be used as a query set.  Optionally, perform the computation in naive
+   * mode or single-tree mode, and set the leaf size used for tree-building.
+   * Additionally an instantiated metric can be given, for cases where the
+   * distance metric holds data.  If both naive and singleMode are true, naive
+   * mode is used.  A point is never reported as its own neighbor.
+   *
+   * This method will copy the reference matrix to an internal copy, which is
+   * rearranged during tree-building.  You can avoid this extra copy by
+   * pre-constructing the reference tree and passing it using a different
+   * constructor.
+   *
+   * @param referenceSet Reference dataset.
+   * @param naive Whether the computation should be done in O(n^2) naive mode.
+   * @param singleMode Whether single-tree computation should be used (as
+   *      opposed to dual-tree computation).
+   * @param leafSize The leaf size to be used during tree construction.
+   * @param metric Instantiated distance metric.
+   */
+  RangeSearch(const arma::mat& referenceSet,
+              const bool naive = false,
+              const bool singleMode = false,
+              const size_t leafSize = 20,
+              const MetricType metric = MetricType());
+
+  /**
+   * Initialize the RangeSearch object with the given datasets and
+   * pre-constructed trees.  It is assumed that the points in referenceSet and
+   * querySet correspond to the points in referenceTree and queryTree,
+   * respectively.  Optionally, choose to use single-tree mode.  Naive
+   * mode is not available as an option for this constructor; instead, to run
+   * naive computation, construct a tree with all the points in one leaf (i.e.
+   * leafSize = number of points).  Additionally, an instantiated distance
+   * metric can be given, for cases where the distance metric holds data.
+   *
+   * There is no copying of the data matrices in this constructor (because
+   * tree-building is not necessary), so this is the constructor to use when
+   * copies absolutely must be avoided.  The trees are not taken over: the
+   * caller keeps ownership and must keep them (and the matrices) alive for the
+   * lifetime of this object.
+   *
+   * @note
+   * Because tree-building (at least with BinarySpaceTree) modifies the ordering
+   * of a matrix, be sure you pass the modified matrix to this object!  In
+   * addition, mapping the points of the matrix back to their original indices
+   * is not done when this constructor is used.
+   *
+   * @param referenceTree Pre-built tree for reference points.
+   * @param queryTree Pre-built tree for query points.
+   * @param referenceSet Set of reference points corresponding to referenceTree.
+   * @param querySet Set of query points corresponding to queryTree.
+   * @param singleMode Whether single-tree computation should be used (as
+   *      opposed to dual-tree computation).
+   * @param metric Instantiated distance metric.
+   */
+  RangeSearch(const TreeType* referenceTree,
+              const TreeType* queryTree,
+              const arma::mat& referenceSet,
+              const arma::mat& querySet,
+              const bool singleMode = false,
+              const MetricType metric = MetricType());
+
+  /**
+   * Initialize the RangeSearch object with the given reference dataset and
+   * pre-constructed tree.  It is assumed that the points in referenceSet
+   * correspond to the points in referenceTree.  The reference set is also used
+   * as the query set.  Optionally, choose to use single-tree mode.  Naive mode
+   * is not available as an option for this constructor; instead, to run naive
+   * computation, construct a tree with all the points in one leaf (i.e.
+   * leafSize = number of points).  Additionally, an instantiated distance
+   * metric can be given, for the case where the distance metric holds data.
+   *
+   * There is no copying of the data matrices in this constructor (because
+   * tree-building is not necessary), so this is the constructor to use when
+   * copies absolutely must be avoided.  The tree is not taken over: the caller
+   * keeps ownership and must keep it (and the matrix) alive for the lifetime
+   * of this object.
+   *
+   * @note
+   * Because tree-building (at least with BinarySpaceTree) modifies the ordering
+   * of a matrix, be sure you pass the modified matrix to this object!  In
+   * addition, mapping the points of the matrix back to their original indices
+   * is not done when this constructor is used.
+   *
+   * @param referenceTree Pre-built tree for reference points.
+   * @param referenceSet Set of reference points corresponding to referenceTree.
+   * @param singleMode Whether single-tree computation should be used (as
+   *      opposed to dual-tree computation).
+   * @param metric Instantiated distance metric.
+   */
+  RangeSearch(const TreeType* referenceTree,
+              const arma::mat& referenceSet,
+              const bool singleMode = false,
+              const MetricType metric = MetricType());
+
+  /**
+   * Destroy the RangeSearch object.  Trees that were built by this object are
+   * deleted; trees passed in by the caller are left untouched.
+   */
+  ~RangeSearch();
+
+  /**
+   * Search for all points in the given range, returning the results in the
+   * neighbors and distances objects.  Each entry in the external vector
+   * corresponds to a query point.  Each of these entries holds a vector which
+   * contains the indices and distances of the reference points falling into the
+   * given range.
+   *
+   * That is:
+   *
+   * - neighbors.size() and distances.size() both equal the number of query
+   *   points.
+   *
+   * - neighbors[i] contains the indices of all the points in the reference set
+   *   which have distances inside the given range to query point i.
+   *
+   * - distances[i] contains all of the distances corresponding to the indices
+   *   contained in neighbors[i].
+   *
+   * - neighbors[i] and distances[i] are not sorted in any particular order.
+   *
+   * When this object built the trees itself, query positions and neighbor
+   * indices refer to the original (unpermuted) datasets.  When pre-built trees
+   * were given, they refer to the ordering of the matrices passed to the
+   * constructor.  Any previous contents of neighbors and distances are
+   * discarded.
+   *
+   * @param range Range of distances in which to search.
+   * @param neighbors Object which will hold the list of neighbors for each
+   *      point which fell into the given range, for each query point.
+   * @param distances Object which will hold the list of distances for each
+   *      point which fell into the given range, for each query point.
+   */
+  void Search(const math::Range& range,
+              std::vector<std::vector<size_t> >& neighbors,
+              std::vector<std::vector<double> >& distances);
+
+ private:
+  // Copying is disabled.  This object may own raw tree pointers, which a
+  // memberwise copy would delete twice.  Its referenceSet/querySet members may
+  // also refer to its own internal matrix copies, and a memberwise copy would
+  // keep them pointing into the original object.  These are declared but
+  // intentionally never defined.
+  RangeSearch(const RangeSearch& other);
+  RangeSearch& operator=(const RangeSearch& other);
+
+  /**
+   * Compute the base case, when both referenceNode and queryNode are leaves
+   * containing points.
+   *
+   * @param referenceNode Reference node (must be a leaf).
+   * @param queryNode Query node (must be a leaf).
+   * @param range Range of distances to search for.
+   * @param neighbors Object holding list of neighbors.
+   * @param distances Object holding list of distances.
+   */
+  void ComputeBaseCase(const TreeType* referenceNode,
+                       const TreeType* queryNode,
+                       const math::Range& range,
+                       std::vector<std::vector<size_t> >& neighbors,
+                       std::vector<std::vector<double> >& distances) const;
+
+  /**
+   * Perform the dual-tree recursion, which will recurse until the base case is
+   * necessary.
+   *
+   * @param referenceNode Reference node.
+   * @param queryNode Query node.
+   * @param range Range of distances to search for.
+   * @param neighbors Object holding list of neighbors.
+   * @param distances Object holding list of distances.
+   */
+  void DualTreeRecursion(const TreeType* referenceNode,
+                         const TreeType* queryNode,
+                         const math::Range& range,
+                         std::vector<std::vector<size_t> >& neighbors,
+                         std::vector<std::vector<double> >& distances);
+
+  /**
+   * Perform the single-tree recursion, which will recurse down the reference
+   * tree to get the results for a single point.
+   *
+   * @param referenceNode Reference node.
+   * @param queryPoint Point to query for.
+   * @param range Range of distances to search for.
+   * @param neighbors Object holding list of neighbors.
+   * @param distances Object holding list of distances.
+   */
+  void SingleTreeRecursion(const TreeType* referenceNode,
+                           const arma::vec& queryPoint,
+                           const math::Range& range,
+                           std::vector<size_t>& neighbors,
+                           std::vector<double>& distances);
+
+  //! Copy of reference matrix; used when a tree is built internally.
+  arma::mat referenceCopy;
+  //! Copy of query matrix; used when a tree is built internally.
+  arma::mat queryCopy;
+
+  //! Reference set (data should be accessed using this).
+  const arma::mat& referenceSet;
+  //! Query set (data should be accessed using this).
+  const arma::mat& querySet;
+
+  //! Reference tree.  Held as a pointer-to-const because the search only reads
+  //! it and the pre-built-tree constructors receive const pointers.
+  const TreeType* referenceTree;
+  //! Query tree (may be NULL, in which case the reference tree is used).
+  const TreeType* queryTree;
+
+  //! Mappings to old reference indices (used when this object builds trees).
+  std::vector<size_t> oldFromNewReferences;
+  //! Mappings to old query indices (used when this object builds trees).
+  std::vector<size_t> oldFromNewQueries;
+
+  //! Indicates ownership of the reference tree (meaning we need to delete it).
+  bool ownReferenceTree;
+  //! Indicates ownership of the query tree (meaning we need to delete it).
+  bool ownQueryTree;
+
+  //! If true, O(n^2) naive computation is used.
+  bool naive;
+  //! If true, single-tree computation is used.
+  bool singleMode;
+
+  //! Instantiated distance metric.
+  MetricType metric;
+
+  //! Number of nodes pruned during the most recent call to Search() (reset at
+  //! the start of every call).
+  size_t numberOfPrunes;
+};
+
+}; // namespace range
+}; // namespace mlpack
+
+// Include implementation.
+#include "range_search_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp	2011-12-03 02:17:32 UTC (rev 10520)
@@ -0,0 +1,452 @@
+/**
+ * @file range_search_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the RangeSearch class.
+ */
+#ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_IMPL_HPP
+#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_IMPL_HPP
+
+// Just in case it hasn't been included.
+#include "range_search.hpp"
+
+namespace mlpack {
+namespace range {
+
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::RangeSearch(const arma::mat& referenceSet,
+                                               const arma::mat& querySet,
+                                               const bool naive,
+                                               const bool singleMode,
+                                               const size_t leafSize,
+                                               const MetricType metric) :
+    referenceCopy(referenceSet),
+    queryCopy(querySet),
+    referenceSet(referenceCopy),
+    querySet(queryCopy),
+    ownReferenceTree(true),
+    ownQueryTree(true),
+    naive(naive),
+    singleMode(!naive && singleMode), // Naive overrides single mode.
+    metric(metric),
+    numberOfPrunes(0)
+{
+  // Build the trees on the internal copies.  Tree-building reorders the
+  // matrices and records the permutations in oldFromNew{References,Queries}.
+  Timers::StartTimer("range_search/tree_building");
+
+  // Naive sets the leaf size such that the entire tree is one node.
+  referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+      (naive ? referenceCopy.n_cols : leafSize));
+
+  // If building the query tree throws, the object is never fully constructed
+  // and the destructor does not run, so free the reference tree here instead
+  // of leaking it.
+  try
+  {
+    queryTree = new TreeType(queryCopy, oldFromNewQueries,
+        (naive ? queryCopy.n_cols : leafSize));
+  }
+  catch (...)
+  {
+    delete referenceTree;
+    throw;
+  }
+
+  Timers::StopTimer("range_search/tree_building");
+}
+
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::RangeSearch(const arma::mat& referenceSet,
+                                               const bool naive,
+                                               const bool singleMode,
+                                               const size_t leafSize,
+                                               const MetricType metric) :
+    referenceCopy(referenceSet),
+    referenceSet(referenceCopy),
+    querySet(referenceCopy),
+    queryTree(NULL),
+    ownReferenceTree(true),
+    ownQueryTree(false),
+    naive(naive),
+    singleMode(!naive && singleMode), // Naive overrides single mode.
+    metric(metric),
+    numberOfPrunes(0)
+{
+  // Only a reference tree is built here; the (reordered) reference copy also
+  // serves as the query set, so there is no separate query tree.
+  Timers::StartTimer("range_search/tree_building");
+
+  // In naive mode the leaf is made as large as the dataset, so the whole tree
+  // is a single node and the base case covers every pair of points.
+  size_t referenceLeafSize = leafSize;
+  if (naive)
+    referenceLeafSize = referenceCopy.n_cols;
+
+  referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+      referenceLeafSize);
+
+  Timers::StopTimer("range_search/tree_building");
+}
+
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::RangeSearch(const TreeType* referenceTree,
+                                               const TreeType* queryTree,
+                                               const arma::mat& referenceSet,
+                                               const arma::mat& querySet,
+                                               const bool singleMode,
+                                               const MetricType metric) :
+    referenceSet(referenceSet),
+    querySet(querySet),
+    referenceTree(referenceTree),
+    queryTree(queryTree),
+    metric(metric)
+{
+  // The caller built both trees, so this object owns neither of them and has
+  // no index mappings to undo.  Naive mode is not offered by this constructor.
+  ownReferenceTree = false;
+  ownQueryTree = false;
+  naive = false;
+  this->singleMode = singleMode;
+  numberOfPrunes = 0;
+}
+
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::RangeSearch(const TreeType* referenceTree,
+                                               const arma::mat& referenceSet,
+                                               const bool singleMode,
+                                               const MetricType metric) :
+    referenceSet(referenceSet),
+    querySet(referenceSet),
+    referenceTree(referenceTree),
+    queryTree(NULL),
+    metric(metric)
+{
+  // The caller built the reference tree, which also plays the role of the
+  // query tree (queryTree stays NULL).  Nothing is owned, nothing is remapped,
+  // and naive mode is not offered by this constructor.
+  ownReferenceTree = false;
+  ownQueryTree = false;
+  naive = false;
+  this->singleMode = singleMode;
+  numberOfPrunes = 0;
+}
+
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::~RangeSearch()
+{
+  // Free only the trees this object allocated itself; trees supplied by the
+  // caller remain the caller's responsibility.
+  if (ownReferenceTree)
+  {
+    delete referenceTree;
+  }
+
+  if (ownQueryTree)
+  {
+    delete queryTree;
+  }
+}
+
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::Search(
+    const math::Range& range,
+    std::vector<std::vector<size_t> >& neighbors,
+    std::vector<std::vector<double> >& distances)
+{
+  Timers::StartTimer("range_search/computing_neighbors");
+
+  // Reset the prune counter for this search.
+  numberOfPrunes = 0;
+
+  // Results are first computed into local buffers, indexed by the tree
+  // ordering of the points.  Afterwards they are swapped (not copied) into the
+  // caller's objects, with indices remapped if this object built the trees.
+  // Using automatic storage instead of heap-allocated temporaries means
+  // nothing leaks if the computation throws.
+  std::vector<std::vector<size_t> > treeNeighbors(querySet.n_cols);
+  std::vector<std::vector<double> > treeDistances(querySet.n_cols);
+
+  // Without a query tree, the reference tree doubles as the query tree.
+  const TreeType* queryRoot = (queryTree == NULL) ? referenceTree : queryTree;
+
+  if (naive)
+  {
+    // Each tree is a single leaf, so one base case covers all pairs.
+    ComputeBaseCase(referenceTree, queryRoot, range, treeNeighbors,
+        treeDistances);
+  }
+  else if (singleMode)
+  {
+    // Traverse the reference tree once per query point.
+    for (size_t i = 0; i < querySet.n_cols; i++)
+    {
+      SingleTreeRecursion(referenceTree, querySet.unsafe_col(i), range,
+          treeNeighbors[i], treeDistances[i]);
+    }
+  }
+  else
+  {
+    DualTreeRecursion(referenceTree, queryRoot, range, treeNeighbors,
+        treeDistances);
+  }
+
+  Timers::StopTimer("range_search/computing_neighbors");
+
+  // Output number of prunes.
+  Log::Info << "Number of pruned nodes during computation: " << numberOfPrunes
+      << "." << std::endl;
+
+  // Work out which permutations must be undone.  Query results are stored by
+  // tree position and must be moved back to original positions whenever this
+  // object reordered the query points: either it built the query tree, or
+  // the queries are the references and it built the reference tree.
+  // Neighbor indices refer to the reference tree ordering and must be mapped
+  // whenever this object built the reference tree.
+  const std::vector<size_t>* queryMapping = NULL;
+  if (ownQueryTree)
+    queryMapping = &oldFromNewQueries;
+  else if (ownReferenceTree && queryTree == NULL)
+    queryMapping = &oldFromNewReferences;
+
+  const std::vector<size_t>* referenceMapping =
+      ownReferenceTree ? &oldFromNewReferences : NULL;
+
+  neighbors.clear();
+  neighbors.resize(querySet.n_cols);
+  distances.clear();
+  distances.resize(querySet.n_cols);
+
+  for (size_t i = 0; i < querySet.n_cols; i++)
+  {
+    // Output slot for the query point that sits at tree position i.
+    const size_t target = (queryMapping == NULL) ? i : (*queryMapping)[i];
+
+    distances[target].swap(treeDistances[i]);
+    neighbors[target].swap(treeNeighbors[i]);
+
+    if (referenceMapping != NULL)
+    {
+      std::vector<size_t>& found = neighbors[target];
+      for (size_t j = 0; j < found.size(); j++)
+        found[j] = (*referenceMapping)[found[j]];
+    }
+  }
+}
+
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::ComputeBaseCase(
+    const TreeType* referenceNode,
+    const TreeType* queryNode,
+    const math::Range& range,
+    std::vector<std::vector<size_t> >& neighbors,
+    std::vector<std::vector<double> >& distances) const
+{
+  // When both arguments are the same node, every query index is also a
+  // reference index.  The matching pair is skipped so that a point is never
+  // reported as its own neighbor.
+  const bool sameNode = (referenceNode == queryNode);
+
+  // Points of a node occupy the index interval [Begin(), End()).
+  for (size_t q = queryNode->Begin(); q < queryNode->End(); q++)
+  {
+    arma::vec queryPoint = querySet.unsafe_col(q);
+
+    // Skip the reference node for this query point when range.Contains()
+    // rejects the interval of possible distances to its bound.
+    const double lowest = referenceNode->Bound().MinDistance(queryPoint);
+    const double highest = referenceNode->Bound().MaxDistance(queryPoint);
+    if (!range.Contains(math::Range(lowest, highest)))
+      continue;
+
+    for (size_t r = referenceNode->Begin(); r < referenceNode->End(); r++)
+    {
+      if (sameNode && (r == q))
+        continue;
+
+      arma::vec referencePoint = referenceSet.unsafe_col(r);
+      const double dist = metric.Evaluate(queryPoint, referencePoint);
+
+      if (range.Contains(dist))
+      {
+        neighbors[q].push_back(r);
+        distances[q].push_back(dist);
+      }
+    }
+  }
+}
+
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::DualTreeRecursion(
+    const TreeType* referenceNode,
+    const TreeType* queryNode,
+    const math::Range& range,
+    std::vector<std::vector<size_t> >& neighbors,
+    std::vector<std::vector<double> >& distances)
+{
+  // Prune this pair of nodes (and everything below them) when range.Contains()
+  // rejects the interval of distances between their bounds.
+  const math::Range nodeDistances =
+      referenceNode->Bound().RangeDistance(queryNode->Bound());
+  if (!range.Contains(nodeDistances))
+  {
+    ++numberOfPrunes;
+    return;
+  }
+
+  const bool referenceIsLeaf = referenceNode->IsLeaf();
+  const bool queryIsLeaf = queryNode->IsLeaf();
+
+  // Two leaves: compare their points directly.
+  if (referenceIsLeaf && queryIsLeaf)
+  {
+    ComputeBaseCase(referenceNode, queryNode, range, neighbors, distances);
+    return;
+  }
+
+  // Otherwise, expand each non-leaf node into its two children; a leaf stands
+  // in for itself.  Pairs are then visited reference-major: left reference
+  // child first, and for each reference child, the left query child first.
+  const TreeType* referenceChildren[2];
+  size_t referenceCount = 0;
+  if (referenceIsLeaf)
+  {
+    referenceChildren[referenceCount++] = referenceNode;
+  }
+  else
+  {
+    referenceChildren[referenceCount++] = referenceNode->Left();
+    referenceChildren[referenceCount++] = referenceNode->Right();
+  }
+
+  const TreeType* queryChildren[2];
+  size_t queryCount = 0;
+  if (queryIsLeaf)
+  {
+    queryChildren[queryCount++] = queryNode;
+  }
+  else
+  {
+    queryChildren[queryCount++] = queryNode->Left();
+    queryChildren[queryCount++] = queryNode->Right();
+  }
+
+  for (size_t r = 0; r < referenceCount; ++r)
+  {
+    for (size_t q = 0; q < queryCount; ++q)
+    {
+      DualTreeRecursion(referenceChildren[r], queryChildren[q], range,
+          neighbors, distances);
+    }
+  }
+}
+
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::SingleTreeRecursion(
+    const TreeType* referenceNode,
+    const arma::vec& queryPoint,
+    const math::Range& range,
+    std::vector<size_t>& neighbors,
+    std::vector<double>& distances)
+{
+  if (!referenceNode->IsLeaf())
+  {
+    // Internal node: compute both children's distance intervals first, then
+    // visit left before right.  A child whose interval is rejected by
+    // range.Contains() is skipped and counted as pruned.
+    const TreeType* children[2] = { referenceNode->Left(),
+                                    referenceNode->Right() };
+    const math::Range childDistances[2] = {
+        children[0]->Bound().RangeDistance(queryPoint),
+        children[1]->Bound().RangeDistance(queryPoint) };
+
+    for (size_t c = 0; c < 2; ++c)
+    {
+      if (range.Contains(childDistances[c]))
+        SingleTreeRecursion(children[c], queryPoint, range, neighbors,
+            distances);
+      else
+        ++numberOfPrunes;
+    }
+    return;
+  }
+
+  // Leaf: test every reference point directly.  A reference column whose
+  // memory is the query point's own memory is the query point itself.  This
+  // happens when the query set is the reference set and queryPoint came from
+  // unsafe_col(), which aliases the matrix.  Such a column is skipped.
+  const double* queryMemory = queryPoint.memptr();
+  for (size_t r = referenceNode->Begin(); r != referenceNode->End(); ++r)
+  {
+    if (referenceSet.colptr(r) == queryMemory)
+      continue;
+
+    arma::vec referencePoint = referenceSet.unsafe_col(r);
+    const double dist = metric.Evaluate(queryPoint, referencePoint);
+
+    if (range.Contains(dist))
+    {
+      neighbors.push_back(r);
+      distances.push_back(dist);
+    }
+  }
+}
+
+}; // namespace range
+}; // namespace mlpack
+
+#endif




More information about the mlpack-svn mailing list