[mlpack-git] master: Preliminary refactoring of RangeSearch. (3673579)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Apr 8 18:30:40 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/aaa19c9dcaed6e859d6565b8c1bfb07f755f0d7e...367357985828e5a007b3b9ccf6c279e983cd7fad

>---------------------------------------------------------------

commit 367357985828e5a007b3b9ccf6c279e983cd7fad
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Apr 8 18:30:24 2015 -0400

    Preliminary refactoring of RangeSearch.


>---------------------------------------------------------------

367357985828e5a007b3b9ccf6c279e983cd7fad
 src/mlpack/methods/range_search/range_search.hpp   | 159 ++++-----
 .../methods/range_search/range_search_impl.hpp     | 377 ++++++++++++---------
 .../methods/range_search/range_search_main.cpp     | 115 +++----
 src/mlpack/tests/range_search_test.cpp             | 167 ++++-----
 4 files changed, 428 insertions(+), 390 deletions(-)

diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index 969166a..9504fd3 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -48,31 +48,6 @@ class RangeSearch
    * @param metric Instantiated distance metric.
    */
   RangeSearch(const typename TreeType::Mat& referenceSet,
-              const typename TreeType::Mat& querySet,
-              const bool naive = false,
-              const bool singleMode = false,
-              const MetricType metric = MetricType());
-
-  /**
-   * Initialize the RangeSearch object with only a reference set, which will
-   * also be used as a query set.  Optionally, perform the computation in naive
-   * mode or single-tree mode, and set the leaf size used for tree-building.
-   * Additionally an instantiated metric can be given, for cases where the
-   * distance metric holds data.
-   *
-   * This method will copy the reference matrix to an internal copy, which is
-   * rearranged during tree-building.  You can avoid this extra copy by
-   * pre-constructing the reference tree and passing it using a different
-   * constructor.
-   *
-   * @param referenceSet Reference dataset.
-   * @param naive Whether the computation should be done in O(n^2) naive mode.
-   * @param singleMode Whether single-tree computation should be used (as
-   *      opposed to dual-tree computation).
-   * @param leafSize The leaf size to be used during tree construction.
-   * @param metric Instantiated distance metric.
-   */
-  RangeSearch(const typename TreeType::Mat& referenceSet,
               const bool naive = false,
               const bool singleMode = false,
               const MetricType metric = MetricType());
@@ -107,56 +82,97 @@ class RangeSearch
    * @param metric Instantiated distance metric.
    */
   RangeSearch(TreeType* referenceTree,
-              TreeType* queryTree,
-              const typename TreeType::Mat& referenceSet,
-              const typename TreeType::Mat& querySet,
               const bool singleMode = false,
               const MetricType metric = MetricType());
 
   /**
-   * Initialize the RangeSearch object with the given reference dataset and
-   * pre-constructed tree.  It is assumed that the points in referenceSet
-   * correspond to the points in referenceTree.  Optionally, choose to use
-   * single-tree mode.  Naive mode is not available as an option for this
-   * constructor; instead, to run naive computation, construct a tree with all
-   * the points in one leaf (i.e. leafSize = number of points).  Additionally,
-   * an instantiated distance metric can be given, for the case where the
-   * distance metric holds data.
+   * Destroy the RangeSearch object.  If trees were created, they will be
+   * deleted.
+   */
+  ~RangeSearch();
+
+  /**
+   * Search for all reference points in the given range for each point in the
+   * query set, returning the results in the neighbors and distances objects.
+   * Each entry in the external vector corresponds to a query point.  Each of
+   * these entries holds a vector which contains the indices and distances of
+   * the reference points falling into the given range.
    *
-   * There is no copying of the data matrices in this constructor (because
-   * tree-building is not necessary), so this is the constructor to use when
-   * copies absolutely must be avoided.
+   * That is:
    *
-   * @note
-   * Because tree-building (at least with BinarySpaceTree) modifies the ordering
-   * of a matrix, be sure you pass the modified matrix to this object!  In
-   * addition, mapping the points of the matrix back to their original indices
-   * is not done when this constructor is used.
-   * @endnote
+   * - neighbors.size() and distances.size() both equal the number of query
+   *   points.
    *
-   * @param referenceTree Pre-built tree for reference points.
-   * @param referenceSet Set of reference points corresponding to referenceTree.
-   * @param singleMode Whether single-tree computation should be used (as
-   *      opposed to dual-tree computation).
-   * @param metric Instantiated distance metric.
+   * - neighbors[i] contains the indices of all the points in the reference set
+   *   which have distances inside the given range to query point i.
+   *
+   * - distances[i] contains all of the distances corresponding to the indices
+   *   contained in neighbors[i].
+   *
+   * - neighbors[i] and distances[i] are not sorted in any particular order.
+   *
+   * @param querySet Set of query points to search with.
+   * @param range Range of distances in which to search.
+   * @param neighbors Object which will hold the list of neighbors for each
+   *      point which fell into the given range, for each query point.
+   * @param distances Object which will hold the list of distances for each
+   *      point which fell into the given range, for each query point.
    */
-  RangeSearch(TreeType* referenceTree,
-              const typename TreeType::Mat& referenceSet,
-              const bool singleMode = false,
-              const MetricType metric = MetricType());
+  void Search(const typename TreeType::Mat& querySet,
+              const math::Range& range,
+              std::vector<std::vector<size_t>>& neighbors,
+              std::vector<std::vector<double>>& distances);
 
   /**
-   * Destroy the RangeSearch object.  If trees were created, they will be
-   * deleted.
+   * Given a pre-built query tree, search for all reference points in the given
+   * range for each point in the query set, returning the results in the
+   * neighbors and distances objects.
+   *
+   * Each entry in the external vector corresponds to a query point.  Each of
+   * these entries holds a vector which contains the indices and distances of
+   * the reference points falling into the given range.
+   *
+   * That is:
+   *
+   * - neighbors.size() and distances.size() both equal the number of query
+   *   points.
+   *
+   * - neighbors[i] contains the indices of all the points in the reference set
+   *   which have distances inside the given range to query point i.
+   *
+   * - distances[i] contains all of the distances corresponding to the indices
+   *   contained in neighbors[i].
+   *
+   * - neighbors[i] and distances[i] are not sorted in any particular order.
+   *
+   * If either naive or singleMode are set to true, this will throw an
+   * invalid_argument exception; passing in a query tree implies dual-tree
+   * search.
+   *
+   * If you want to use the reference tree as the query tree, instead call the
+   * overload of Search() that does not take a query set.
+   *
+   * @param queryTree Tree built on query points.
+   * @param range Range of distances in which to search.
+   * @param neighbors Object which will hold the list of neighbors for each
+   *      point which fell into the given range, for each query point.
+   * @param distances Object which will hold the list of distances for each
+   *      point which fell into the given range, for each query point.
    */
-  ~RangeSearch();
+  void Search(TreeType* queryTree,
+              const math::Range& range,
+              std::vector<std::vector<size_t>>& neighbors,
+              std::vector<std::vector<double>>& distances);
 
   /**
-   * Search for all points in the given range, returning the results in the
-   * neighbors and distances objects.  Each entry in the external vector
-   * corresponds to a query point.  Each of these entries holds a vector which
-   * contains the indices and distances of the reference points falling into the
-   * given range.
+   * Search for all points in the given range for each point in the reference
+   * set (which was passed to the constructor), returning the results in the
+   * neighbors and distances objects.  This means that the query set and the
+   * reference set are the same.
+   *
+   * Each entry in the external vector corresponds to a query point.  Each of
+   * these entries holds a vector which contains the indices and distances of
+   * the reference points falling into the given range.
    *
    * That is:
    *
@@ -171,6 +187,7 @@ class RangeSearch
    *
    * - neighbors[i] and distances[i] are not sorted in any particular order.
    *
+   * The reference set is also used as the query set.
    * @param range Range of distances in which to search.
    * @param neighbors Object which will hold the list of neighbors for each
    *      point which fell into the given range, for each query point.
@@ -178,8 +195,8 @@ class RangeSearch
    *      point which fell into the given range, for each query point.
    */
   void Search(const math::Range& range,
-              std::vector<std::vector<size_t> >& neighbors,
-              std::vector<std::vector<double> >& distances);
+              std::vector<std::vector<size_t>>& neighbors,
+              std::vector<std::vector<double>>& distances);
 
   // Returns a string representation of this object.
   std::string ToString() const;
@@ -187,23 +204,12 @@ class RangeSearch
  private:
   //! Copy of reference matrix; used when a tree is built internally.
   typename TreeType::Mat referenceCopy;
-  //! Copy of query matrix; used when a tree is built internally.
-  typename TreeType::Mat queryCopy;
-
   //! Reference set (data should be accessed using this).
   const typename TreeType::Mat& referenceSet;
-  //! Query set (data should be accessed using this).
-  const typename TreeType::Mat& querySet;
-
   //! Reference tree.
   TreeType* referenceTree;
-  //! Query tree (may be NULL).
-  TreeType* queryTree;
-
   //! Mappings to old reference indices (used when this object builds trees).
   std::vector<size_t> oldFromNewReferences;
-  //! Mappings to old query indices (used when this object builds trees).
-  std::vector<size_t> oldFromNewQueries;
 
   //! If true, this object is responsible for deleting the trees.
   bool treeOwner;
@@ -218,9 +224,6 @@ class RangeSearch
 
   //! Instantiated distance metric.
   MetricType metric;
-
-  //! The number of pruned nodes during computation.
-  size_t numPrunes;
 };
 
 }; // namespace range
diff --git a/src/mlpack/methods/range_search/range_search_impl.hpp b/src/mlpack/methods/range_search/range_search_impl.hpp
index 38e76f1..ae443ac 100644
--- a/src/mlpack/methods/range_search/range_search_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_impl.hpp
@@ -42,179 +42,91 @@ TreeType* BuildTree(
 template<typename MetricType, typename TreeType>
 RangeSearch<MetricType, TreeType>::RangeSearch(
     const typename TreeType::Mat& referenceSetIn,
-    const typename TreeType::Mat& querySetIn,
     const bool naive,
     const bool singleMode,
     const MetricType metric) :
-    referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
-        : referenceSetIn),
-    querySet(tree::TreeTraits<TreeType>::RearrangesDataset ? queryCopy
-        : querySetIn),
+    referenceSet((tree::TreeTraits<TreeType>::RearrangesDataset && !naive)
+        ? referenceCopy : referenceSetIn),
     referenceTree(NULL),
-    queryTree(NULL),
     treeOwner(!naive), // If in naive mode, we are not building any trees.
-    hasQuerySet(true),
     naive(naive),
     singleMode(!naive && singleMode), // Naive overrides single mode.
-    metric(metric),
-    numPrunes(0)
+    metric(metric)
 {
   // Build the trees.
   Timer::Start("range_search/tree_building");
 
-  // Copy the datasets, if they will be modified during tree building.
-  if (tree::TreeTraits<TreeType>::RearrangesDataset)
-  {
-    referenceCopy = referenceSetIn;
-    queryCopy = querySetIn;
-  }
-
   // If in naive mode, then we do not need to build trees.
   if (!naive)
   {
-    // The const_cast is safe; if RearrangesDataset == false, then it'll be
-    // casted back to const anyway, and if not, referenceSet points to
-    // referenceCopy, which isn't const.
-    referenceTree = BuildTree<TreeType>(
-        const_cast<typename TreeType::Mat&>(referenceSet),
-        oldFromNewReferences);
-
-    if (!singleMode)
-      queryTree = BuildTree<TreeType>(
-          const_cast<typename TreeType::Mat&>(querySet), oldFromNewQueries);
-  }
+    // If tree building rearranges points, build on a copy (referenceCopy).
+    if (tree::TreeTraits<TreeType>::RearrangesDataset)
+      referenceCopy = referenceSetIn;
 
-  Timer::Stop("range_search/tree_building");
-}
-
-template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(
-    const typename TreeType::Mat& referenceSetIn,
-    const bool naive,
-    const bool singleMode,
-    const MetricType metric) :
-    referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
-        : referenceSetIn),
-    querySet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
-        : referenceSetIn),
-    referenceTree(NULL),
-    queryTree(NULL),
-    treeOwner(!naive), // If in naive mode, we are not building any trees.
-    hasQuerySet(false),
-    naive(naive),
-    singleMode(!naive && singleMode), // Naive overrides single mode.
-    metric(metric),
-    numPrunes(0)
-{
-  // Build the trees.
-  Timer::Start("range_search/tree_building");
-
-  // Copy the dataset, if it will be modified during tree building.
-  if (tree::TreeTraits<TreeType>::RearrangesDataset)
-    referenceCopy = referenceSetIn;
-
-  // If in naive mode, then we do not need to build trees.
-  if (!naive)
-  {
     // The const_cast is safe; if RearrangesDataset == false, then it'll be
     // casted back to const anyway, and if not, referenceSet points to
     // referenceCopy, which isn't const.
     referenceTree = BuildTree<TreeType>(
         const_cast<typename TreeType::Mat&>(referenceSet),
         oldFromNewReferences);
-
-    if (!singleMode)
-      queryTree = new TreeType(*referenceTree);
   }
+
   Timer::Stop("range_search/tree_building");
 }
 
 template<typename MetricType, typename TreeType>
 RangeSearch<MetricType, TreeType>::RangeSearch(
     TreeType* referenceTree,
-    TreeType* queryTree,
-    const typename TreeType::Mat& referenceSet,
-    const typename TreeType::Mat& querySet,
     const bool singleMode,
     const MetricType metric) :
-    referenceSet(referenceSet),
-    querySet(querySet),
+    referenceSet(referenceTree->Dataset()), // Alias the tree's dataset.
     referenceTree(referenceTree),
-    queryTree(queryTree),
     treeOwner(false),
-    hasQuerySet(true),
     naive(false),
     singleMode(singleMode),
-    metric(metric),
-    numPrunes(0)
+    metric(metric)
 {
   // Nothing else to initialize.
 }
 
 template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(
-    TreeType* referenceTree,
-    const typename TreeType::Mat& referenceSet,
-    const bool singleMode,
-    const MetricType metric) :
-    referenceSet(referenceSet),
-    querySet(referenceSet),
-    referenceTree(referenceTree),
-    queryTree(NULL),
-    treeOwner(false),
-    hasQuerySet(false),
-    naive(false),
-    singleMode(singleMode),
-    metric(metric),
-    numPrunes(0)
-{
-  // If doing dual-tree range search, we must clone the reference tree.
-  if (!singleMode)
-    queryTree = new TreeType(*referenceTree);
-}
-
-template<typename MetricType, typename TreeType>
 RangeSearch<MetricType, TreeType>::~RangeSearch()
 {
-  if (treeOwner)
-  {
-    if (referenceTree)
-      delete referenceTree;
-    if (queryTree)
-      delete queryTree;
-  }
-
-  // If doing dual-tree search with one dataset, we cloned the reference tree.
-  if (!treeOwner && !hasQuerySet && !(singleMode || naive))
-    delete queryTree;
+  if (treeOwner && referenceTree) // Only delete a tree this object built.
+    delete referenceTree;
 }
 
 template<typename MetricType, typename TreeType>
 void RangeSearch<MetricType, TreeType>::Search(
+    const typename TreeType::Mat& querySet,
     const math::Range& range,
-    std::vector<std::vector<size_t> >& neighbors,
-    std::vector<std::vector<double> >& distances)
+    std::vector<std::vector<size_t>>& neighbors,
+    std::vector<std::vector<double>>& distances)
 {
   Timer::Start("range_search/computing_neighbors");
 
-  // Set size of prunes to 0.
-  numPrunes = 0;
+  // This will hold mappings for query points, if necessary.
+  std::vector<size_t> oldFromNewQueries;
 
   // If we have built the trees ourselves, then we will have to map all the
   // indices back to their original indices when this computation is finished.
   // To avoid extra copies, we will store the unmapped neighbors and distances
   // in a separate object.
-  std::vector<std::vector<size_t> >* neighborPtr = &neighbors;
-  std::vector<std::vector<double> >* distancePtr = &distances;
+  std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
+  std::vector<std::vector<double>>* distancePtr = &distances;
 
   // Mapping is only necessary if the tree rearranges points.
   if (tree::TreeTraits<TreeType>::RearrangesDataset)
   {
-    if (treeOwner && !(singleMode && hasQuerySet))
-      distancePtr = new std::vector<std::vector<double> >; // Query indices need to be mapped.
+    // Query indices only need to be mapped if we are building the query tree
+    // ourselves.
+    if (!singleMode && !naive)
+      distancePtr = new std::vector<std::vector<double>>;
 
+    // Reference indices only need to be mapped if we built the reference tree
+    // ourselves.
     if (treeOwner)
-      neighborPtr = new std::vector<std::vector<size_t> >; // All indices need mapping.
+      neighborPtr = new std::vector<std::vector<size_t>>;
   }
 
   // Resize each vector.
@@ -243,63 +155,224 @@ void RangeSearch<MetricType, TreeType>::Search(
     // Now have it traverse for each point.
     for (size_t i = 0; i < querySet.n_cols; ++i)
       traverser.Traverse(i, *referenceTree);
-
-    numPrunes = traverser.NumPrunes();
   }
   else // Dual-tree recursion.
   {
+    // Build the query tree (on a copy, if tree building rearranges the data).
+    Timer::Stop("range_search/computing_neighbors");
+    Timer::Start("range_search/tree_building");
+    typename TreeType::Mat queryCopy;
+    if (tree::TreeTraits<TreeType>::RearrangesDataset)
+      queryCopy = querySet;
+    const typename TreeType::Mat& querySetRef =
+        (tree::TreeTraits<TreeType>::RearrangesDataset) ? queryCopy : querySet;
+    TreeType* queryTree = BuildTree<TreeType>(
+        const_cast<typename TreeType::Mat&>(querySetRef), oldFromNewQueries);
+    Timer::Stop("range_search/tree_building");
+    Timer::Start("range_search/computing_neighbors");
+    // Shadow the outer rules: base cases must index the query tree's data.
+    RuleType rules(referenceSet, queryTree->Dataset(), range, *neighborPtr,
+        *distancePtr, metric);
     typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
 
     traverser.Traverse(*queryTree, *referenceTree);
 
-    numPrunes = traverser.NumPrunes();
+    // Clean up tree memory.
+    delete queryTree;
   }
 
   Timer::Stop("range_search/computing_neighbors");
 
-  // Output number of prunes.
-  Log::Info << "Number of pruned nodes during computation: " << numPrunes
-      << "." << std::endl;
-
   // Map points back to original indices, if necessary.
-
-  if (!treeOwner || !tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (tree::TreeTraits<TreeType>::RearrangesDataset)
   {
-    // No mapping needed.  We are done.
-    return;
+    if (!singleMode && !naive && treeOwner)
+    {
+      // We must map both query and reference indices.
+      neighbors.clear();
+      neighbors.resize(querySet.n_cols);
+      distances.clear();
+      distances.resize(querySet.n_cols);
+
+      for (size_t i = 0; i < distances.size(); i++)
+      {
+        // Map distances (copy a column).
+        const size_t queryMapping = oldFromNewQueries[i];
+        distances[queryMapping] = (*distancePtr)[i];
+
+        // Copy each neighbor individually, because we need to map it.
+        neighbors[queryMapping].resize(distances[queryMapping].size());
+        for (size_t j = 0; j < distances[queryMapping].size(); j++)
+          neighbors[queryMapping][j] =
+              oldFromNewReferences[(*neighborPtr)[i][j]];
+      }
+
+      // Finished with temporary objects.
+      delete neighborPtr;
+      delete distancePtr;
+    }
+    else if (!singleMode && !naive)
+    {
+      // We must map query indices only.  The reference tree is not ours here,
+      // so neighborPtr aliases neighbors: move the unmapped results aside.
+      std::vector<std::vector<size_t>> unmappedNeighbors;
+      unmappedNeighbors.swap(neighbors);
+      neighbors.resize(querySet.n_cols);
+      distances.clear();
+      distances.resize(querySet.n_cols);
+      for (size_t i = 0; i < distances.size(); ++i)
+      {
+        // Map distances and neighbors (move a column into place).
+        const size_t queryMapping = oldFromNewQueries[i];
+        distances[queryMapping] = (*distancePtr)[i];
+        neighbors[queryMapping].swap(unmappedNeighbors[i]);
+      }
+
+      // Only distancePtr was allocated; neighborPtr must not be deleted.
+      delete distancePtr;
+    }
+    else if (treeOwner)
+    {
+      // We must map reference indices only.
+      neighbors.clear();
+      neighbors.resize(querySet.n_cols);
+
+      for (size_t i = 0; i < neighbors.size(); i++)
+      {
+        neighbors[i].resize((*neighborPtr)[i].size());
+        for (size_t j = 0; j < neighbors[i].size(); j++)
+          neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
+      }
+
+      // Finished with temporary object.
+      delete neighborPtr;
+    }
   }
-  else if (treeOwner && hasQuerySet && !singleMode) // Map both sets.
+}
+
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::Search(
+    TreeType* queryTree,
+    const math::Range& range,
+    std::vector<std::vector<size_t>>& neighbors,
+    std::vector<std::vector<double>>& distances)
+{
+  // Check for dual-tree mode first, so a throw can't leave the timer running.
+  if (singleMode || naive)
+    throw std::invalid_argument("cannot call RangeSearch::Search() with a "
+        "query tree when naive or singleMode are set to true");
+
+  Timer::Start("range_search/computing_neighbors");
+
+  // Get a reference to the query set.
+  const typename TreeType::Mat& querySet = queryTree->Dataset();
+
+  // Query indices are never mapped here; reference indices may need mapping.
+  std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
+
+  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+    neighborPtr = new std::vector<std::vector<size_t>>;
+
+  // Resize each vector.
+  neighborPtr->clear(); // Just in case there was anything in it.
+  neighborPtr->resize(querySet.n_cols);
+  distances.clear();
+  distances.resize(querySet.n_cols);
+
+  // Create the helper object for the traversal.
+  typedef RangeSearchRules<MetricType, TreeType> RuleType;
+  RuleType rules(referenceSet, queryTree->Dataset(), range, *neighborPtr,
+      distances, metric);
+
+  // Create the traverser.
+  typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+  traverser.Traverse(*queryTree, *referenceTree);
+
+  Timer::Stop("range_search/computing_neighbors");
+
+  // Do we need to map indices?
+  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
   {
+    // We must map reference indices only.
     neighbors.clear();
     neighbors.resize(querySet.n_cols);
-    distances.clear();
-    distances.resize(querySet.n_cols);
 
-    for (size_t i = 0; i < distances.size(); i++)
+    for (size_t i = 0; i < neighbors.size(); i++)
     {
-      // Map distances (copy a column).
-      size_t queryMapping = oldFromNewQueries[i];
-      distances[queryMapping] = (*distancePtr)[i];
-
-      // Copy each neighbor individually, because we need to map it.
-      neighbors[queryMapping].resize(distances[queryMapping].size());
-      for (size_t j = 0; j < distances[queryMapping].size(); j++)
-      {
-        neighbors[queryMapping][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
-      }
+      neighbors[i].resize((*neighborPtr)[i].size());
+      for (size_t j = 0; j < neighbors[i].size(); j++)
+        neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
     }
 
-    // Finished with temporary objects.
+    // Finished with temporary object.
     delete neighborPtr;
-    delete distancePtr;
   }
-  else if (treeOwner && !hasQuerySet)
+}
+
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::Search(
+    const math::Range& range,
+    std::vector<std::vector<size_t>>& neighbors,
+    std::vector<std::vector<double>>& distances)
+{
+  Timer::Start("range_search/computing_neighbors");
+
+  // Here, we will use the query set as the reference set.
+  std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
+  std::vector<std::vector<double>>* distancePtr = &distances;
+
+  if (tree::TreeTraits<TreeType>::RearrangesDataset && treeOwner)
+  {
+    // We will always need to rearrange in this case.
+    distancePtr = new std::vector<std::vector<double>>;
+    neighborPtr = new std::vector<std::vector<size_t>>;
+  }
+
+  // Resize each vector.
+  neighborPtr->clear(); // Just in case there was anything in it.
+  neighborPtr->resize(referenceSet.n_cols);
+  distancePtr->clear();
+  distancePtr->resize(referenceSet.n_cols);
+
+  // Create the helper object for the traversal.
+  typedef RangeSearchRules<MetricType, TreeType> RuleType;
+  RuleType rules(referenceSet, referenceSet, range, *neighborPtr, *distancePtr,
+      metric);
+
+  if (naive)
+  {
+    // The naive brute-force solution.
+    for (size_t i = 0; i < referenceSet.n_cols; ++i)
+      for (size_t j = 0; j < referenceSet.n_cols; ++j)
+        rules.BaseCase(i, j);
+  }
+  else if (singleMode)
+  {
+    // Create the traverser.
+    typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+
+    // Now have it traverse for each point.
+    for (size_t i = 0; i < referenceSet.n_cols; ++i)
+      traverser.Traverse(i, *referenceTree);
+  }
+  else // Dual-tree recursion.
+  {
+    // Create the traverser.
+    typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+    traverser.Traverse(*referenceTree, *referenceTree);
+  }
+
+  Timer::Stop("range_search/computing_neighbors");
+
+  // Do we need to map the reference indices?
+  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
   {
     neighbors.clear();
-    neighbors.resize(querySet.n_cols);
+    neighbors.resize(referenceSet.n_cols);
     distances.clear();
-    distances.resize(querySet.n_cols);
+    distances.resize(referenceSet.n_cols);
 
     for (size_t i = 0; i < distances.size(); i++)
     {
@@ -319,24 +392,6 @@ void RangeSearch<MetricType, TreeType>::Search(
     delete neighborPtr;
     delete distancePtr;
   }
-  else if (treeOwner && hasQuerySet && singleMode) // Map only references.
-  {
-    neighbors.clear();
-    neighbors.resize(querySet.n_cols);
-
-    // Map indices of neighbors.
-    for (size_t i = 0; i < neighbors.size(); i++)
-    {
-      neighbors[i].resize((*neighborPtr)[i].size());
-      for (size_t j = 0; j < neighbors[i].size(); j++)
-      {
-        neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
-      }
-    }
-
-    // Finished with temporary object.
-    delete neighborPtr;
-  }
 }
 
 template<typename MetricType, typename TreeType>
diff --git a/src/mlpack/methods/range_search/range_search_main.cpp b/src/mlpack/methods/range_search/range_search_main.cpp
index fe91b61..31eba17 100644
--- a/src/mlpack/methods/range_search/range_search_main.cpp
+++ b/src/mlpack/methods/range_search/range_search_main.cpp
@@ -103,6 +103,7 @@ int main(int argc, char *argv[])
     Log::Fatal << "Invalid range: maximum (" << max << ") must be greater than "
         << "minimum (" << min << ")." << endl;
   }
+  const math::Range r(min, max);
 
   // Sanity check on leaf size.
   if (lsInt < 0)
@@ -110,7 +111,7 @@ int main(int argc, char *argv[])
     Log::Fatal << "Invalid leaf size: " << lsInt << ".  Must be greater "
         "than or equal to 0." << endl;
   }
-  size_t leafSize = lsInt;
+  const size_t leafSize = lsInt;
 
   // Naive mode overrides single mode.
   if (singleMode && naive)
@@ -118,9 +119,6 @@ int main(int argc, char *argv[])
     Log::Warn << "--single_mode ignored because --naive is present." << endl;
   }
 
-  if (naive)
-    leafSize = referenceData.n_cols;
-
   if (coverTree && naive)
   {
     Log::Warn << "--cover_tree ignored because --naive is present." << endl;
@@ -131,106 +129,92 @@ int main(int argc, char *argv[])
   vector<vector<double> > distances;
 
   // The cover tree implies different types, so we must split this section.
-  if (coverTree)
+  if (naive)
+  {
+    Log::Info << "Performing naive search (no trees)." << endl;
+    RangeSearch<> rangeSearch(referenceData, naive, singleMode); // No trees.
+    if (CLI::GetParam<string>("query_file") == "")
+      rangeSearch.Search(r, neighbors, distances);
+    else if (data::Load(CLI::GetParam<string>("query_file"), queryData, true))
+      rangeSearch.Search(queryData, r, neighbors, distances);
+  }
+  else if (coverTree)
   {
     Log::Info << "Using cover trees." << endl;
 
     // This is significantly simpler than kd-tree construction because the data
     // matrix is not modified.
-    RSCoverType* rangeSearch = NULL;
-    CoverTreeType referenceTree(referenceData);
-    CoverTreeType* queryTree = NULL;
+    RSCoverType rangeSearch(referenceData, false /* naive */, singleMode);
 
     if (CLI::GetParam<string>("query_file") == "")
     {
       // Single dataset.
-      rangeSearch = new RSCoverType(&referenceTree, referenceData, singleMode);
+      rangeSearch.Search(r, neighbors, distances);
     }
     else
     {
       // Two datasets.
       const string queryFile = CLI::GetParam<string>("query_file");
       data::Load(queryFile, queryData, true);
-      queryTree = new CoverTreeType(queryData);
 
-      rangeSearch = new RSCoverType(&referenceTree, queryTree, referenceData,
-          queryData, singleMode);
+      rangeSearch.Search(queryData, r, neighbors, distances);
     }
-
-    Log::Info << "Trees built." << endl;
-
-    const math::Range r(min, max);
-    rangeSearch->Search(r, neighbors, distances);
-
-    if (queryTree)
-      delete queryTree;
-    delete rangeSearch;
   }
   else
   {
-    // Because we may construct it differently, we need a pointer.
-    RSType* rangeSearch = NULL;
+    typedef BinarySpaceTree<bound::HRectBound<2>, RangeSearchStat> TreeType;
 
-    // Mappings for when we build the tree.
-    vector<size_t> oldFromNewRefs;
-
-    // Build trees by hand, so we can save memory: if we pass a tree to
-    // NeighborSearch, it does not copy the matrix.
+    // Track mappings.
     Log::Info << "Building reference tree..." << endl;
     Timer::Start("tree_building");
-
-    BinarySpaceTree<bound::HRectBound<2>, RangeSearchStat>
-        refTree(referenceData, oldFromNewRefs, leafSize);
-    BinarySpaceTree<bound::HRectBound<2>, RangeSearchStat>*
-        queryTree = NULL; // Empty for now.
-
+    vector<size_t> oldFromNewRefs;
+    vector<size_t> oldFromNewQueries; // Filled only if a query tree is built.
+    TreeType refTree(referenceData, oldFromNewRefs, leafSize);
     Timer::Stop("tree_building");
 
-    vector<size_t> oldFromNewQueries;
+    // Collect the results in these vectors before remapping.
+    vector<vector<double> > distancesOut;
+    vector<vector<size_t> > neighborsOut;
+
+    RSType rangeSearch(&refTree, singleMode);
 
     if (CLI::GetParam<string>("query_file") != "")
     {
       const string queryFile = CLI::GetParam<string>("query_file");
       data::Load(queryFile, queryData, true);
 
-      if (naive && leafSize < queryData.n_cols)
-        leafSize = queryData.n_cols;
-
       Log::Info << "Loaded query data from '" << queryFile << "'." << endl;
 
-      Log::Info << "Building query tree..." << endl;
-
-      // Build trees by hand, so we can save memory: if we pass a tree to
-      // NeighborSearch, it does not copy the matrix.
-      Timer::Start("tree_building");
-
-      queryTree = new BinarySpaceTree<bound::HRectBound<2>,
-          RangeSearchStat>(queryData, oldFromNewQueries, leafSize);
+      if (singleMode)
+      {
+        Log::Info << "Computing neighbors within range [" << min << ", " << max
+            << "]." << endl;
+        rangeSearch.Search(queryData, r, neighborsOut, distancesOut);
+      }
+      else
+      {
+        Log::Info << "Building query tree..." << endl;
 
-      Timer::Stop("tree_building");
+        // Build trees by hand, so we can save memory: if we pass a tree to
+        // RangeSearch, it does not copy the matrix.
+        Timer::Start("tree_building");
+        TreeType queryTree(queryData, oldFromNewQueries, leafSize);
+        Timer::Stop("tree_building");
 
-      rangeSearch = new RSType(&refTree, queryTree, referenceData, queryData,
-          singleMode);
+        Log::Info << "Tree built." << endl;
 
-      Log::Info << "Tree built." << endl;
+        Log::Info << "Computing neighbors within range [" << min << ", " << max
+            << "]." << endl;
+        rangeSearch.Search(&queryTree, r, neighborsOut, distancesOut);
+      }
     }
     else
     {
-      rangeSearch = new RSType(&refTree, referenceData, singleMode);
-
-      Log::Info << "Trees built." << endl;
+      Log::Info << "Computing neighbors within range [" << min << ", " << max
+          << "]." << endl;
+      rangeSearch.Search(r, neighborsOut, distancesOut);
     }
 
-    Log::Info << "Computing neighbors within range [" << min << ", " << max
-        << "]." << endl;
-
-    // Collect the results in these vectors before remapping.
-    vector<vector<double> > distancesOut;
-    vector<vector<size_t> > neighborsOut;
-
-    const math::Range r(min, max);
-    rangeSearch->Search(r, neighborsOut, distancesOut);
-
     Log::Info << "Neighbors computed." << endl;
 
     // We have to map back to the original indices from before the tree
@@ -272,11 +256,6 @@ int main(int argc, char *argv[])
         }
       }
     }
-
-    // Clean up.
-    if (queryTree)
-      delete queryTree;
-    delete rangeSearch;
   }
 
   // Save output.  We have to do this by hand.
diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index b497d66..f1cd60d 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -21,9 +21,9 @@ BOOST_AUTO_TEST_SUITE(RangeSearchTest);
 
 // Get our results into a sorted format, so we can actually then test for
 // correctness.
-void SortResults(const vector<vector<size_t> >& neighbors,
-                 const vector<vector<double> >& distances,
-                 vector<vector<pair<double, size_t> > >& output)
+void SortResults(const vector<vector<size_t>>& neighbors,
+                 const vector<vector<double>>& distances,
+                 vector<vector<pair<double, size_t>>>& output)
 {
   output.resize(neighbors.size());
   for (size_t i = 0; i < neighbors.size(); i++)
@@ -90,20 +90,20 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
         rs = new RangeSearch<>(dataMutable, true);
         break;
       case 1: // Use the single-tree method.
-        rs = new RangeSearch<>(tree, dataMutable, true);
+        rs = new RangeSearch<>(tree, true);
         break;
       case 2: // Use the dual-tree method.
-        rs = new RangeSearch<>(tree, dataMutable);
+        rs = new RangeSearch<>(tree);
         break;
     }
 
     // Now perform the first calculation.  Points within 0.50.
-    vector<vector<size_t> > neighbors;
-    vector<vector<double> > distances;
+    vector<vector<size_t>> neighbors;
+    vector<vector<double>> distances;
     rs->Search(Range(0.0, sqrt(0.5)), neighbors, distances);
 
     // Now the exhaustive check for correctness.  This will be long.
-    vector<vector<pair<double, size_t> > > sortedOutput;
+    vector<vector<pair<double, size_t>>> sortedOutput;
     SortResults(neighbors, distances, sortedOutput);
 
     BOOST_REQUIRE(sortedOutput[newFromOld[0]].size() == 4);
@@ -460,20 +460,20 @@ BOOST_AUTO_TEST_CASE(DualTreeVsNaive1)
   arma::mat naiveQuery(dataForTree);
   arma::mat naiveReferences(dataForTree);
 
-  RangeSearch<> rs(dualQuery, dualReferences);
+  RangeSearch<> rs(dualReferences);
 
-  RangeSearch<> naive(naiveQuery, naiveReferences, true);
+  RangeSearch<> naive(naiveReferences, true);
 
-  vector<vector<size_t> > neighborsTree;
-  vector<vector<double> > distancesTree;
-  rs.Search(Range(0.25, 1.05), neighborsTree, distancesTree);
-  vector<vector<pair<double, size_t> > > sortedTree;
+  vector<vector<size_t>> neighborsTree;
+  vector<vector<double>> distancesTree;
+  rs.Search(dualQuery, Range(0.25, 1.05), neighborsTree, distancesTree);
+  vector<vector<pair<double, size_t>>> sortedTree;
   SortResults(neighborsTree, distancesTree, sortedTree);
 
-  vector<vector<size_t> > neighborsNaive;
-  vector<vector<double> > distancesNaive;
-  naive.Search(Range(0.25, 1.05), neighborsNaive, distancesNaive);
-  vector<vector<pair<double, size_t> > > sortedNaive;
+  vector<vector<size_t>> neighborsNaive;
+  vector<vector<double>> distancesNaive;
+  naive.Search(naiveQuery, Range(0.25, 1.05), neighborsNaive, distancesNaive);
+  vector<vector<pair<double, size_t>>> sortedNaive;
   SortResults(neighborsNaive, distancesNaive, sortedNaive);
 
   for (size_t i = 0; i < sortedTree.size(); i++)
@@ -513,16 +513,16 @@ BOOST_AUTO_TEST_CASE(DualTreeVsNaive2)
   // Set naive mode.
   RangeSearch<> naive(naiveQuery, true);
 
-  vector<vector<size_t> > neighborsTree;
-  vector<vector<double> > distancesTree;
+  vector<vector<size_t>> neighborsTree;
+  vector<vector<double>> distancesTree;
   rs.Search(Range(0.25, 1.05), neighborsTree, distancesTree);
-  vector<vector<pair<double, size_t> > > sortedTree;
+  vector<vector<pair<double, size_t>>> sortedTree;
   SortResults(neighborsTree, distancesTree, sortedTree);
 
-  vector<vector<size_t> > neighborsNaive;
-  vector<vector<double> > distancesNaive;
+  vector<vector<size_t>> neighborsNaive;
+  vector<vector<double>> distancesNaive;
   naive.Search(Range(0.25, 1.05), neighborsNaive, distancesNaive);
-  vector<vector<pair<double, size_t> > > sortedNaive;
+  vector<vector<pair<double, size_t>>> sortedNaive;
   SortResults(neighborsNaive, distancesNaive, sortedNaive);
 
   for (size_t i = 0; i < sortedTree.size(); i++)
@@ -562,16 +562,16 @@ BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
   // Set up computation for naive mode.
   RangeSearch<> naive(naiveQuery, true);
 
-  vector<vector<size_t> > neighborsSingle;
-  vector<vector<double> > distancesSingle;
+  vector<vector<size_t>> neighborsSingle;
+  vector<vector<double>> distancesSingle;
   single.Search(Range(0.25, 1.05), neighborsSingle, distancesSingle);
-  vector<vector<pair<double, size_t> > > sortedTree;
+  vector<vector<pair<double, size_t>>> sortedTree;
   SortResults(neighborsSingle, distancesSingle, sortedTree);
 
-  vector<vector<size_t> > neighborsNaive;
-  vector<vector<double> > distancesNaive;
+  vector<vector<size_t>> neighborsNaive;
+  vector<vector<double>> distancesNaive;
   naive.Search(Range(0.25, 1.05), neighborsNaive, distancesNaive);
-  vector<vector<pair<double, size_t> > > sortedNaive;
+  vector<vector<pair<double, size_t>>> sortedNaive;
   SortResults(neighborsNaive, distancesNaive, sortedNaive);
 
   for (size_t i = 0; i < sortedTree.size(); i++)
@@ -600,8 +600,7 @@ BOOST_AUTO_TEST_CASE(CoverTreeTest)
   typedef tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
       RangeSearchStat> CoverTreeType;
   CoverTreeType tree(data);
-  RangeSearch<metric::EuclideanDistance, CoverTreeType> coversearch(&tree,
-      data);
+  RangeSearch<metric::EuclideanDistance, CoverTreeType> coversearch(&tree);
 
   // Four trials with different ranges.
   for (size_t r = 0; r < 4; ++r)
@@ -631,12 +630,12 @@ BOOST_AUTO_TEST_CASE(CoverTreeTest)
     }
 
     // Results for kd-tree search.
-    vector<vector<size_t> > kdNeighbors;
-    vector<vector<double> > kdDistances;
+    vector<vector<size_t>> kdNeighbors;
+    vector<vector<double>> kdDistances;
 
     // Results for cover tree search.
-    vector<vector<size_t> > coverNeighbors;
-    vector<vector<double> > coverDistances;
+    vector<vector<size_t>> coverNeighbors;
+    vector<vector<double>> coverDistances;
 
     // Clean the tree statistics.
     CleanTree(tree);
@@ -646,8 +645,8 @@ BOOST_AUTO_TEST_CASE(CoverTreeTest)
     coversearch.Search(range, coverNeighbors, coverDistances);
 
     // Sort before comparison.
-    vector<vector<pair<double, size_t> > > kdSorted;
-    vector<vector<pair<double, size_t> > > coverSorted;
+    vector<vector<pair<double, size_t>>> kdSorted;
+    vector<vector<pair<double, size_t>>> coverSorted;
     SortResults(kdNeighbors, kdDistances, kdSorted);
     SortResults(coverNeighbors, coverDistances, coverSorted);
 
@@ -680,16 +679,16 @@ BOOST_AUTO_TEST_CASE(CoverTreeTwoDatasetsTest)
   typedef tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
       RangeSearchStat> CoverTreeType;
   CoverTreeType tree(data);
-  CoverTreeType queryTree(queries);
+  CoverTreeType* queryTree = new CoverTreeType(queries);
   RangeSearch<metric::EuclideanDistance, CoverTreeType>
-      coversearch(&tree, &queryTree, data, queries);
+      coversearch(&tree);
 
   // Four trials with different ranges.
   for (size_t r = 0; r < 4; ++r)
   {
     // Set up kd-tree range search.  We don't have an easy way to rebuild the
     // tree, so we'll just reinstantiate it here each loop time.
-    RangeSearch<> kdsearch(data, queries);
+    RangeSearch<> kdsearch(data);
 
     Range range;
     switch (r)
@@ -713,24 +712,25 @@ BOOST_AUTO_TEST_CASE(CoverTreeTwoDatasetsTest)
     }
 
     // Results for kd-tree search.
-    vector<vector<size_t> > kdNeighbors;
-    vector<vector<double> > kdDistances;
+    vector<vector<size_t>> kdNeighbors;
+    vector<vector<double>> kdDistances;
 
     // Results for cover tree search.
-    vector<vector<size_t> > coverNeighbors;
-    vector<vector<double> > coverDistances;
+    vector<vector<size_t>> coverNeighbors;
+    vector<vector<double>> coverDistances;
 
     // Clean the trees.
     CleanTree(tree);
-    CleanTree(queryTree);
+    delete queryTree;
+    queryTree = new CoverTreeType(queries);
 
     // Run the searches.
-    coversearch.Search(range, coverNeighbors, coverDistances);
-    kdsearch.Search(range, kdNeighbors, kdDistances);
+    coversearch.Search(queryTree, range, coverNeighbors, coverDistances);
+    kdsearch.Search(queries, range, kdNeighbors, kdDistances);
 
     // Sort before comparison.
-    vector<vector<pair<double, size_t> > > kdSorted;
-    vector<vector<pair<double, size_t> > > coverSorted;
+    vector<vector<pair<double, size_t>>> kdSorted;
+    vector<vector<pair<double, size_t>>> coverSorted;
     SortResults(kdNeighbors, kdDistances, kdSorted);
     SortResults(coverNeighbors, coverDistances, coverSorted);
 
@@ -746,6 +746,8 @@ BOOST_AUTO_TEST_CASE(CoverTreeTwoDatasetsTest)
       BOOST_REQUIRE_EQUAL(kdSorted[i].size(), coverSorted[i].size());
     }
   }
+
+  delete queryTree;
 }
 
 /**
@@ -760,8 +762,8 @@ BOOST_AUTO_TEST_CASE(CoverTreeSingleTreeTest)
   typedef tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
       RangeSearchStat> CoverTreeType;
   CoverTreeType tree(data);
-  RangeSearch<metric::EuclideanDistance, CoverTreeType>
-      coversearch(&tree, data, true);
+  RangeSearch<metric::EuclideanDistance, CoverTreeType> coversearch(&tree,
+      true);
 
   // Four trials with different ranges.
   for (size_t r = 0; r < 4; ++r)
@@ -791,12 +793,12 @@ BOOST_AUTO_TEST_CASE(CoverTreeSingleTreeTest)
     }
 
     // Results for kd-tree search.
-    vector<vector<size_t> > kdNeighbors;
-    vector<vector<double> > kdDistances;
+    vector<vector<size_t>> kdNeighbors;
+    vector<vector<double>> kdDistances;
 
     // Results for cover tree search.
-    vector<vector<size_t> > coverNeighbors;
-    vector<vector<double> > coverDistances;
+    vector<vector<size_t>> coverNeighbors;
+    vector<vector<double>> coverDistances;
 
     // Clean the tree statistics.
     CleanTree(tree);
@@ -806,8 +808,8 @@ BOOST_AUTO_TEST_CASE(CoverTreeSingleTreeTest)
     coversearch.Search(range, coverNeighbors, coverDistances);
 
     // Sort before comparison.
-    vector<vector<pair<double, size_t> > > kdSorted;
-    vector<vector<pair<double, size_t> > > coverSorted;
+    vector<vector<pair<double, size_t>>> kdSorted;
+    vector<vector<pair<double, size_t>>> coverSorted;
     SortResults(kdNeighbors, kdDistances, kdSorted);
     SortResults(coverNeighbors, coverDistances, coverSorted);
 
@@ -837,7 +839,7 @@ BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
   typedef BinarySpaceTree<BallBound<>, RangeSearchStat> TreeType;
   TreeType tree(data);
   RangeSearch<metric::EuclideanDistance, TreeType>
-      ballsearch(&tree, data, true);
+      ballsearch(&tree, true);
 
   // Four trials with different ranges.
   for (size_t r = 0; r < 4; ++r)
@@ -867,12 +869,12 @@ BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
     }
 
     // Results for kd-tree search.
-    vector<vector<size_t> > kdNeighbors;
-    vector<vector<double> > kdDistances;
+    vector<vector<size_t>> kdNeighbors;
+    vector<vector<double>> kdDistances;
 
     // Results for ball tree search.
-    vector<vector<size_t> > ballNeighbors;
-    vector<vector<double> > ballDistances;
+    vector<vector<size_t>> ballNeighbors;
+    vector<vector<double>> ballDistances;
 
     // Clean the tree statistics.
     CleanTree(tree);
@@ -882,8 +884,8 @@ BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
     ballsearch.Search(range, ballNeighbors, ballDistances);
 
     // Sort before comparison.
-    vector<vector<pair<double, size_t> > > kdSorted;
-    vector<vector<pair<double, size_t> > > ballSorted;
+    vector<vector<pair<double, size_t>>> kdSorted;
+    vector<vector<pair<double, size_t>>> ballSorted;
     SortResults(kdNeighbors, kdDistances, kdSorted);
     SortResults(ballNeighbors, ballDistances, ballSorted);
 
@@ -913,7 +915,7 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest)
   // Set up ball tree range search.
   typedef BinarySpaceTree<BallBound<>, RangeSearchStat> TreeType;
   TreeType tree(data);
-  RangeSearch<metric::EuclideanDistance, TreeType> ballsearch(&tree, data);
+  RangeSearch<metric::EuclideanDistance, TreeType> ballsearch(&tree);
 
   // Four trials with different ranges.
   for (size_t r = 0; r < 4; ++r)
@@ -943,12 +945,12 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest)
     }
 
     // Results for kd-tree search.
-    vector<vector<size_t> > kdNeighbors;
-    vector<vector<double> > kdDistances;
+    vector<vector<size_t>> kdNeighbors;
+    vector<vector<double>> kdDistances;
 
     // Results for ball tree search.
-    vector<vector<size_t> > ballNeighbors;
-    vector<vector<double> > ballDistances;
+    vector<vector<size_t>> ballNeighbors;
+    vector<vector<double>> ballDistances;
 
     // Clean the tree statistics.
     CleanTree(tree);
@@ -958,8 +960,8 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest)
     ballsearch.Search(range, ballNeighbors, ballDistances);
 
     // Sort before comparison.
-    vector<vector<pair<double, size_t> > > kdSorted;
-    vector<vector<pair<double, size_t> > > ballSorted;
+    vector<vector<pair<double, size_t>>> kdSorted;
+    vector<vector<pair<double, size_t>>> ballSorted;
     SortResults(kdNeighbors, kdDistances, kdSorted);
     SortResults(ballNeighbors, ballDistances, ballSorted);
 
@@ -993,15 +995,14 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest2)
   typedef BinarySpaceTree<BallBound<>, RangeSearchStat> TreeType;
   TreeType tree(data);
   TreeType queryTree(queries);
-  RangeSearch<metric::EuclideanDistance, TreeType>
-      ballsearch(&tree, &queryTree, data, queries);
+  RangeSearch<metric::EuclideanDistance, TreeType> ballsearch(&tree);
 
   // Four trials with different ranges.
   for (size_t r = 0; r < 4; ++r)
   {
     // Set up kd-tree range search.  We don't have an easy way to rebuild the
     // tree, so we'll just reinstantiate it here each loop time.
-    RangeSearch<> kdsearch(data, queries);
+    RangeSearch<> kdsearch(data);
 
     Range range;
     switch (r)
@@ -1025,24 +1026,24 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest2)
     }
 
     // Results for kd-tree search.
-    vector<vector<size_t> > kdNeighbors;
-    vector<vector<double> > kdDistances;
+    vector<vector<size_t>> kdNeighbors;
+    vector<vector<double>> kdDistances;
 
     // Results for ball tree search.
-    vector<vector<size_t> > ballNeighbors;
-    vector<vector<double> > ballDistances;
+    vector<vector<size_t>> ballNeighbors;
+    vector<vector<double>> ballDistances;
 
     // Clean the trees.
     CleanTree(tree);
     CleanTree(queryTree);
 
     // Run the searches.
-    ballsearch.Search(range, ballNeighbors, ballDistances);
-    kdsearch.Search(range, kdNeighbors, kdDistances);
+    ballsearch.Search(&queryTree, range, ballNeighbors, ballDistances);
+    kdsearch.Search(queries, range, kdNeighbors, kdDistances);
 
     // Sort before comparison.
-    vector<vector<pair<double, size_t> > > kdSorted;
-    vector<vector<pair<double, size_t> > > ballSorted;
+    vector<vector<pair<double, size_t>>> kdSorted;
+    vector<vector<pair<double, size_t>>> ballSorted;
     SortResults(kdNeighbors, kdDistances, kdSorted);
     SortResults(ballNeighbors, ballDistances, ballSorted);
 



More information about the mlpack-git mailing list