[mlpack-git] master: Refactor NeighborSearch like RangeSearch. Now Search() takes the query set. (80650d4)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Apr 22 16:32:27 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/8f85309ae9be40e819b301b39c9a940aa28f3bb2...57d0567dddff01feea73b348f38cc040dc3cf8e3

>---------------------------------------------------------------

commit 80650d46316345367ec9b88fbd45df27beb854ee
Author: ryan <ryan at ratml.org>
Date:   Wed Apr 22 14:18:21 2015 -0400

    Refactor NeighborSearch like RangeSearch.
    Now Search() takes the query set.


>---------------------------------------------------------------

80650d46316345367ec9b88fbd45df27beb854ee
 .../methods/neighbor_search/neighbor_search.hpp    | 154 +++-----
 .../neighbor_search/neighbor_search_impl.hpp       | 405 +++++++++++----------
 .../neighbor_search/neighbor_search_rules.hpp      |   6 +-
 .../neighbor_search/neighbor_search_rules_impl.hpp |   6 +-
 4 files changed, 284 insertions(+), 287 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 9ae8b6a..18111a3 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -50,66 +50,35 @@ class NeighborSearch
 {
  public:
   /**
-   * Initialize the NeighborSearch object, passing both a query and reference
-   * dataset.  Optionally, perform the computation in naive mode or single-tree
-   * mode, and set the leaf size used for tree-building.  An initialized
-   * distance metric can be given, for cases where the metric has internal data
-   * (i.e. the distance::MahalanobisDistance class).
+   * Initialize the NeighborSearch object, passing a reference dataset (this is
+   * the dataset which is searched).  Optionally, perform the computation in
+   * naive mode or single-tree mode.  An initialized distance metric can be
+   * given, for cases where the metric has internal data (i.e. the
+   * distance::MahalanobisDistance class).
    *
    * This method will copy the matrices to internal copies, which are rearranged
    * during tree-building.  You can avoid this extra copy by pre-constructing
    * the trees and passing them using a diferent constructor.
    *
    * @param referenceSet Set of reference points.
-   * @param querySet Set of query points.
    * @param naive If true, O(n^2) naive search will be used (as opposed to
    *      dual-tree search).  This overrides singleMode (if it is set to true).
    * @param singleMode If true, single-tree search will be used (as opposed to
    *      dual-tree search).
-   * @param leafSize Leaf size for tree construction (ignored if tree is given).
    * @param metric An optional instance of the MetricType class.
    */
   NeighborSearch(const typename TreeType::Mat& referenceSet,
-                 const typename TreeType::Mat& querySet,
                  const bool naive = false,
                  const bool singleMode = false,
                  const MetricType metric = MetricType());
 
   /**
-   * Initialize the NeighborSearch object, passing only one dataset, which is
-   * used as both the query and the reference dataset.  Optionally, perform the
-   * computation in naive mode or single-tree mode, and set the leaf size used
-   * for tree-building.  An initialized distance metric can be given, for cases
-   * where the metric has internal data (i.e. the distance::MahalanobisDistance
-   * class).
-   *
-   * If naive mode is being used and a pre-built tree is given, it may not work:
-   * naive mode operates by building a one-node tree (the root node holds all
-   * the points).  If that condition is not satisfied with the pre-built tree,
-   * then naive mode will not work.
-   *
-   * @param referenceSet Set of reference points.
-   * @param naive If true, O(n^2) naive search will be used (as opposed to
-   *      dual-tree search).  This overrides singleMode (if it is set to true).
-   * @param singleMode If true, single-tree search will be used (as opposed to
-   *      dual-tree search).
-   * @param leafSize Leaf size for tree construction (ignored if tree is given).
-   * @param metric An optional instance of the MetricType class.
-   */
-  NeighborSearch(const typename TreeType::Mat& referenceSet,
-                 const bool naive = false,
-                 const bool singleMode = false,
-                 const MetricType metric = MetricType());
-
-  /**
-   * Initialize the NeighborSearch object with the given datasets and
-   * pre-constructed trees.  It is assumed that the points in referenceSet and
-   * querySet correspond to the points in referenceTree and queryTree,
-   * respectively.  Optionally, choose to use single-tree mode.  Naive mode is
-   * not available as an option for this constructor; instead, to run naive
-   * computation, construct a tree with all of the points in one leaf (i.e.
-   * leafSize = number of points).  Additionally, an instantiated distance
-   * metric can be given, for cases where the distance metric holds data.
+   * Initialize the NeighborSearch object with the given pre-constructed
+   * reference tree (this is the tree built on the points that will be
+   * searched).  Optionally, choose to use single-tree mode.  Naive mode is not
+   * available as an option for this constructor.  Additionally, an instantiated
+   * distance metric can be given, for cases where the distance metric holds
+   * data.
    *
    * There is no copying of the data matrices in this constructor (because
    * tree-building is not necessary), so this is the constructor to use when
@@ -123,73 +92,71 @@ class NeighborSearch
    * @endnote
    *
    * @param referenceTree Pre-built tree for reference points.
-   * @param queryTree Pre-built tree for query points.
    * @param referenceSet Set of reference points corresponding to referenceTree.
-   * @param querySet Set of query points corresponding to queryTree.
    * @param singleMode Whether single-tree computation should be used (as
    *      opposed to dual-tree computation).
    * @param metric Instantiated distance metric.
    */
   NeighborSearch(TreeType* referenceTree,
-                 TreeType* queryTree,
-                 const typename TreeType::Mat& referenceSet,
-                 const typename TreeType::Mat& querySet,
                  const bool singleMode = false,
                  const MetricType metric = MetricType());
 
   /**
-   * Initialize the NeighborSearch object with the given reference dataset and
-   * pre-constructed tree.  It is assumed that the points in referenceSet
-   * correspond to the points in referenceTree.  Optionally, choose to use
-   * single-tree mode.  Naive mode is not available as an option for this
-   * constructor; instead, to run naive computation, construct a tree with all
-   * the points in one leaf (i.e. leafSize = number of points).  Additionally,
-   * an instantiated distance metric can be given, for the case where the
-   * distance metric holds data.
-   *
-   * There is no copying of the data matrices in this constructor (because
-   * tree-building is not necessary), so this is the constructor to use when
-   * copies absolutely must be avoided.
-   *
-   * @note
-   * Because tree-building (at least with BinarySpaceTree) modifies the ordering
-   * of a matrix, be sure you pass the modified matrix to this object!  In
-   * addition, mapping the points of the matrix back to their original indices
-   * is not done when this constructor is used.
-   * @endnote
-   *
-   * @param referenceTree Pre-built tree for reference points.
-   * @param referenceSet Set of reference points corresponding to referenceTree.
-   * @param singleMode Whether single-tree computation should be used (as
-   *      opposed to dual-tree computation).
-   * @param metric Instantiated distance metric.
-   */
-  NeighborSearch(TreeType* referenceTree,
-                 const typename TreeType::Mat& referenceSet,
-                 const bool singleMode = false,
-                 const MetricType metric = MetricType());
-
-
-  /**
    * Delete the NeighborSearch object. The tree is the only member we are
    * responsible for deleting.  The others will take care of themselves.
    */
   ~NeighborSearch();
 
   /**
-   * Compute the nearest neighbors and store the output in the given matrices.
-   * The matrices will be set to the size of n columns by k rows, where n is the
+   * For each point in the query set, compute the nearest neighbors and store
+   * the output in the given matrices.  The matrices will be set to the size of
+   * n columns by k rows, where n is the number of points in the query dataset
+   * and k is the number of neighbors being searched for.
+   *
+   * @param querySet Set of query points (can be just one point).
+   * @param k Number of neighbors to search for.
+   * @param neighbors Matrix storing lists of neighbors for each query point.
+   * @param distances Matrix storing distances of neighbors for each query
+   *     point.
+   */
+  void Search(const typename TreeType::Mat& querySet,
+              const size_t k,
+              arma::Mat<size_t>& neighbors,
+              arma::mat& distances);
+
+  /**
+   * Given a pre-built query tree, search for the nearest neighbors of each
+   * point in the query tree, storing the output in the given matrices.  The
+   * matrices will be set to the size of n columns by k rows, where n is the
    * number of points in the query dataset and k is the number of neighbors
    * being searched for.
    *
+   * @param queryTree Tree built on query points.
    * @param k Number of neighbors to search for.
-   * @param resultingNeighbors Matrix storing lists of neighbors for each query
-   *     point.
+   * @param neighbors Matrix storing lists of neighbors for each query point.
    * @param distances Matrix storing distances of neighbors for each query
-   *     point.
+   *      point.
+   */
+  void Search(TreeType* queryTree,
+              const size_t k,
+              arma::Mat<size_t>& neighbors,
+              arma::mat& distances);
+
+  /**
+   * Search for the nearest neighbors of every point in the reference set.  This
+   * is basically equivalent to calling any other overload of Search() with the
+   * reference set as the query set; so, this lets you do all-k-nearest-neighbors
+   * search.  The results are stored in the given matrices.  The matrices will
+   * be set to the size of n columns by k rows, where n is the number of points
+   * in the reference set and k is the number of neighbors being searched for.
+   *
+   * @param k Number of neighbors to search for.
+   * @param neighbors Matrix storing lists of neighbors for each query point.
+   * @param distances Matrix storing distances of neighbors for each query
+   *      point.
    */
   void Search(const size_t k,
-              arma::Mat<size_t>& resultingNeighbors,
+              arma::Mat<size_t>& neighbors,
               arma::mat& distances);
 
   //! Returns a string representation of this object.
@@ -210,18 +177,12 @@ class NeighborSearch
   //! Copy of reference dataset (if we need it, because tree building modifies
   //! it).
   typename TreeType::Mat referenceCopy;
-  //! Copy of query dataset (if we need it, because tree building modifies it).
-  typename TreeType::Mat queryCopy;
-
   //! Reference dataset.
   const typename TreeType::Mat& referenceSet;
-  //! Query dataset (may not be given).
-  const typename TreeType::Mat& querySet;
-
   //! Pointer to the root of the reference tree.
   TreeType* referenceTree;
-  //! Pointer to the root of the query tree (might not exist).
-  TreeType* queryTree;
+  //! Permutations of reference points during tree building.
+  std::vector<size_t> oldFromNewReferences;
 
   //! If true, this object created the trees and is responsible for them.
   bool treeOwner;
@@ -236,11 +197,6 @@ class NeighborSearch
   //! Instantiation of metric.
   MetricType metric;
 
-  //! Permutations of reference points during tree building.
-  std::vector<size_t> oldFromNewReferences;
-  //! Permutations of query points during tree building.
-  std::vector<size_t> oldFromNewQueries;
-
   //! The total number of base cases.
   size_t baseCases;
   //! The total number of scores (applicable for non-naive search).
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 06fbcb2..5a1c3c6 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -43,96 +43,34 @@ TreeType* BuildTree(
 template<typename SortPolicy, typename MetricType, typename TreeType>
 NeighborSearch<SortPolicy, MetricType, TreeType>::
 NeighborSearch(const typename TreeType::Mat& referenceSetIn,
-               const typename TreeType::Mat& querySetIn,
                const bool naive,
                const bool singleMode,
                const MetricType metric) :
-    referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
-        : referenceSetIn),
-    querySet(tree::TreeTraits<TreeType>::RearrangesDataset ? queryCopy
-        : querySetIn),
+    referenceSet((tree::TreeTraits<TreeType>::RearrangesDataset && !naive)
+        ? referenceCopy : referenceSetIn),
     referenceTree(NULL),
-    queryTree(NULL),
     treeOwner(!naive), // False if a tree was passed.  If naive, then no trees.
-    hasQuerySet(true),
     naive(naive),
     singleMode(!naive && singleMode), // No single mode if naive.
     metric(metric),
     baseCases(0),
     scores(0)
 {
-  // C++11 will allow us to call out to other constructors so we can avoid this
-  // copypasta problem.
-
-  // We'll time tree building, but only if we are building trees.
+  // Build the tree.
   Timer::Start("tree_building");
 
-  // Copy the datasets, if they will be modified during tree building.
-  if (tree::TreeTraits<TreeType>::RearrangesDataset)
-  {
-    referenceCopy = referenceSetIn;
-    queryCopy = querySetIn;
-  }
-
-  // If not in naive mode, then we need to build trees.
   if (!naive)
   {
-    // The const_cast is safe; if RearrangesDataset == false, then it'll be
-    // casted back to const anyway, and if not, referenceSet points to
-    // referenceCopy, which isn't const.
-    referenceTree = BuildTree<TreeType>(
-        const_cast<typename TreeType::Mat&>(referenceSet),
-        oldFromNewReferences);
-
-    if (!singleMode)
-      queryTree = BuildTree<TreeType>(
-          const_cast<typename TreeType::Mat&>(querySet), oldFromNewQueries);
-  }
-
-  // Stop the timer we started above (if we need to).
-  Timer::Stop("tree_building");
-}
-
-// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::
-NeighborSearch(const typename TreeType::Mat& referenceSetIn,
-               const bool naive,
-               const bool singleMode,
-               const MetricType metric) :
-    referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
-        : referenceSetIn),
-    querySet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
-        : referenceSetIn),
-    referenceTree(NULL),
-    queryTree(NULL),
-    treeOwner(!naive), // If naive, then we are not building any trees.
-    hasQuerySet(false),
-    naive(naive),
-    singleMode(!naive && singleMode), // No single mode if naive.
-    metric(metric),
-    baseCases(0),
-    scores(0)
-{
-  // We'll time tree building, but only if we are building trees.
-  Timer::Start("tree_building");
-
-  // Copy the dataset, if it will be modified during tree building.
-  if (tree::TreeTraits<TreeType>::RearrangesDataset)
-    referenceCopy = referenceSetIn;
+    // Copy the dataset, if it will be modified during tree building.
+    if (tree::TreeTraits<TreeType>::RearrangesDataset)
+      referenceCopy = referenceSetIn;
 
-  // If not in naive mode, then we may need to construct trees.
-  if (!naive)
-  {
     // The const_cast is safe; if RearrangesDataset == false, then it'll be
     // casted back to const anyway, and if not, referenceSet points to
     // referenceCopy, which isn't const.
     referenceTree = BuildTree<TreeType>(
         const_cast<typename TreeType::Mat&>(referenceSet),
         oldFromNewReferences);
-
-    if (!singleMode)
-      queryTree = new TreeType(*referenceTree);
   }
 
   // Stop the timer we started above.
@@ -143,17 +81,11 @@ NeighborSearch(const typename TreeType::Mat& referenceSetIn,
 template<typename SortPolicy, typename MetricType, typename TreeType>
 NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
     TreeType* referenceTree,
-    TreeType* queryTree,
-    const typename TreeType::Mat& referenceSet,
-    const typename TreeType::Mat& querySet,
     const bool singleMode,
     const MetricType metric) :
-    referenceSet(referenceSet),
-    querySet(querySet),
+    referenceSet(referenceTree->Dataset()),
     referenceTree(referenceTree),
-    queryTree(queryTree),
     treeOwner(false),
-    hasQuerySet(true),
     naive(false),
     singleMode(singleMode),
     metric(metric),
@@ -163,53 +95,12 @@ NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
   // Nothing else to initialize.
 }
 
-// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
-    TreeType* referenceTree,
-    const typename TreeType::Mat& referenceSet,
-    const bool singleMode,
-    const MetricType metric) :
-    referenceSet(referenceSet),
-    querySet(referenceSet),
-    referenceTree(referenceTree),
-    queryTree(NULL),
-    treeOwner(false),
-    hasQuerySet(false), // In this case we will own a tree, if singleMode.
-    naive(false),
-    singleMode(singleMode),
-    metric(metric),
-    baseCases(0),
-    scores(0)
-{
-  Timer::Start("tree_building");
-
-  // The query tree cannot be the same as the reference tree.
-  if (referenceTree && !singleMode)
-    queryTree = new TreeType(*referenceTree);
-
-  Timer::Stop("tree_building");
-}
-
-/**
- * The tree is the only member we may be responsible for deleting.  The others
- * will take care of themselves.
- */
+// Clean memory.
 template<typename SortPolicy, typename MetricType, typename TreeType>
 NeighborSearch<SortPolicy, MetricType, TreeType>::~NeighborSearch()
 {
-  if (treeOwner)
-  {
-    if (referenceTree)
-      delete referenceTree;
-    if (queryTree)
-      delete queryTree;
-  }
-  else if (!treeOwner && !hasQuerySet && !(singleMode || naive))
-  {
-    // We replicated the reference tree to create a query tree.
-    delete queryTree;
-  }
+  if (treeOwner && referenceTree)
+    delete referenceTree;
 }
 
 /**
@@ -218,23 +109,27 @@ NeighborSearch<SortPolicy, MetricType, TreeType>::~NeighborSearch()
  */
 template<typename SortPolicy, typename MetricType, typename TreeType>
 void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
+    const typename TreeType::Mat& querySet,
     const size_t k,
-    arma::Mat<size_t>& resultingNeighbors,
+    arma::Mat<size_t>& neighbors,
     arma::mat& distances)
 {
   Timer::Start("computing_neighbors");
 
+  // This will hold mappings for query points, if necessary.
+  std::vector<size_t> oldFromNewQueries;
+
   // If we have built the trees ourselves, then we will have to map all the
   // indices back to their original indices when this computation is finished.
   // To avoid an extra copy, we will store the neighbors and distances in a
   // separate matrix.
-  arma::Mat<size_t>* neighborPtr = &resultingNeighbors;
+  arma::Mat<size_t>* neighborPtr = &neighbors;
   arma::mat* distancePtr = &distances;
 
   // Mapping is only necessary if the tree rearranges points.
   if (tree::TreeTraits<TreeType>::RearrangesDataset)
   {
-    if (treeOwner && !(singleMode && hasQuerySet))
+    if (!singleMode && !naive)
       distancePtr = new arma::mat; // Query indices need to be mapped.
 
     if (treeOwner)
@@ -247,9 +142,20 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
   distancePtr->set_size(k, querySet.n_cols);
   distancePtr->fill(SortPolicy::WorstDistance());
 
+  // If we will be building a tree and it will modify the query set, make a copy
+  // of the dataset.
+  typename TreeType::Mat queryCopy;
+  const bool needsCopy = (!naive && !singleMode &&
+      tree::TreeTraits<TreeType>::RearrangesDataset);
+  if (needsCopy)
+    queryCopy = querySet;
+
+  const typename TreeType::Mat& querySetRef = (needsCopy) ? queryCopy :
+      querySet;
+
   // Create the helper object for the tree traversal.
   typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
-  RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+  RuleType rules(referenceSet, querySetRef, *neighborPtr, *distancePtr, metric);
 
   if (naive)
   {
@@ -262,10 +168,6 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
   }
   else if (singleMode)
   {
-    // The search doesn't work if the root node is also a leaf node.
-    // If this is the case, it is suggested that you use the naive method.
-    Log::Assert(!(referenceTree->IsLeaf()));
-
     // Create the traverser.
     typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
 
@@ -281,6 +183,14 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
   }
   else // Dual-tree recursion.
   {
+    // Build the query tree.
+    Timer::Stop("computing_neighbors");
+    Timer::Start("tree_building");
+    TreeType* queryTree = BuildTree<TreeType>(
+        const_cast<typename TreeType::Mat&>(querySetRef), oldFromNewQueries);
+    Timer::Stop("tree_building");
+    Timer::Start("computing_neighbors");
+
     // Create the traverser.
     typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
 
@@ -295,92 +205,217 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
 
   Timer::Stop("computing_neighbors");
 
-  // Now, do we need to do mapping of indices?
-  if (!treeOwner || !tree::TreeTraits<TreeType>::RearrangesDataset)
-  {
-    // No mapping needed.  We are done.
-    return;
-  }
-  else if (treeOwner && hasQuerySet && !singleMode) // Map both sets.
+  // Map points back to original indices, if necessary.
+  if (tree::TreeTraits<TreeType>::RearrangesDataset)
   {
-    // Set size of output matrices correctly.
-    resultingNeighbors.set_size(k, querySet.n_cols);
-    distances.set_size(k, querySet.n_cols);
-
-    for (size_t i = 0; i < distances.n_cols; i++)
+    if (!singleMode && !naive && treeOwner)
     {
-      // Map distances (copy a column).
-      distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
+      // We must map both query and reference indices.
+      neighbors.set_size(k, querySet.n_cols);
+      distances.set_size(k, querySet.n_cols);
 
-      // Map indices of neighbors.
-      for (size_t j = 0; j < distances.n_rows; j++)
+      for (size_t i = 0; i < distances.n_cols; i++)
       {
-        resultingNeighbors(j, oldFromNewQueries[i]) =
-            oldFromNewReferences[(*neighborPtr)(j, i)];
+        // Map distances (copy a column).
+        distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
+
+        // Map indices of neighbors.
+        for (size_t j = 0; j < distances.n_rows; j++)
+        {
+          neighbors(j, oldFromNewQueries[i]) =
+              oldFromNewReferences[(*neighborPtr)(j, i)];
+        }
       }
+
+      // Finished with temporary matrices.
+      delete neighborPtr;
+      delete distancePtr;
     }
+    else if (!singleMode && !naive)
+    {
+      // We must map query indices only.
+      neighbors.set_size(k, querySet.n_cols);
+      distances.set_size(k, querySet.n_cols);
 
-    // Finished with temporary matrices.
-    delete neighborPtr;
-    delete distancePtr;
-  }
-  else if (treeOwner && !hasQuerySet)
-  {
-    resultingNeighbors.set_size(k, querySet.n_cols);
-    distances.set_size(k, querySet.n_cols);
+      for (size_t i = 0; i < distances.n_cols; ++i)
+      {
+        // Map distances and neighbors (copy a column of each).
+        const size_t queryMapping = oldFromNewQueries[i];
+        distances.col(queryMapping) = distancePtr->col(i);
+        neighbors.col(queryMapping) = neighborPtr->col(i);
+      }
 
-    for (size_t i = 0; i < distances.n_cols; i++)
+      // Finished with temporary matrices.
+      delete neighborPtr;
+      delete distancePtr;
+    }
+    else if (treeOwner)
     {
-      // Map distances (copy a column).
-      distances.col(oldFromNewReferences[i]) = distancePtr->col(i);
+      // We must map reference indices only.
+      neighbors.set_size(k, querySet.n_cols);
 
       // Map indices of neighbors.
-      for (size_t j = 0; j < distances.n_rows; j++)
-      {
-        resultingNeighbors(j, oldFromNewReferences[i]) =
-            oldFromNewReferences[(*neighborPtr)(j, i)];
-      }
+      for (size_t i = 0; i < neighbors.n_cols; i++)
+        for (size_t j = 0; j < neighbors.n_rows; j++)
+          neighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
+
+      // Finished with temporary matrix.
+      delete neighborPtr;
     }
+  }
+} // Search()
 
-    // Finished with temporary matrices.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
+    TreeType* queryTree,
+    const size_t k,
+    arma::Mat<size_t>& neighbors,
+    arma::mat& distances)
+{
+  Timer::Start("computing_neighbors");
+
+  // Get a reference to the query set.
+  const typename TreeType::Mat& querySet = queryTree->Dataset();
+
+  // Make sure we are in dual-tree mode.
+  if (singleMode || naive)
+    throw std::invalid_argument("cannot call NeighborSearch::Search() with a "
+        "query tree when naive or singleMode are set to true");
+
+  // We won't need to map query indices, but will we need to map neighbors?
+  arma::Mat<size_t>* neighborPtr = &neighbors;
+
+  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+    neighborPtr = new arma::Mat<size_t>;
+
+  neighborPtr->set_size(k, querySet.n_cols);
+  neighborPtr->fill(size_t() - 1);
+  distances.set_size(k, querySet.n_cols);
+  distances.fill(SortPolicy::WorstDistance());
+
+  // Create the helper object for the traversal.
+  typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
+  RuleType rules(referenceSet, querySet, *neighborPtr, distances, metric);
+
+  // Create the traverser.
+  typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+  traverser.Traverse(*queryTree, *referenceTree);
+
+  Timer::Stop("computing_neighbors");
+
+  // Do we need to map indices?
+  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+  {
+    // We must map reference indices only.
+    neighbors.set_size(k, querySet.n_cols);
+
+    // Map indices of neighbors.
+    for (size_t i = 0; i < neighbors.n_cols; i++)
+      for (size_t j = 0; j < neighbors.n_rows; j++)
+        neighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
+
+    // Finished with temporary matrix.
     delete neighborPtr;
-    delete distancePtr;
   }
-  else if (treeOwner && hasQuerySet && singleMode) // Map only references.
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
+    const size_t k,
+    arma::Mat<size_t>& neighbors,
+    arma::mat& distances)
+{
+  Timer::Start("computing_neighbors");
+
+  arma::Mat<size_t>* neighborPtr = &neighbors;
+  arma::mat* distancePtr = &distances;
+
+  if (tree::TreeTraits<TreeType>::RearrangesDataset && treeOwner)
   {
-    // Set size of neighbor indices matrix correctly.
-    resultingNeighbors.set_size(k, querySet.n_cols);
+    // We will always need to rearrange in this case.
+    distancePtr = new arma::mat;
+    neighborPtr = new arma::Mat<size_t>;
+  }
 
-    // Map indices of neighbors.
-    for (size_t i = 0; i < resultingNeighbors.n_cols; i++)
+  // Initialize results.
+  neighborPtr->set_size(k, referenceSet.n_cols);
+  neighborPtr->fill(size_t() - 1);
+  distancePtr->set_size(k, referenceSet.n_cols);
+  distancePtr->fill(SortPolicy::WorstDistance());
+
+  // Create the helper object for the traversal.
+  typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
+  RuleType rules(referenceSet, referenceSet, *neighborPtr, *distancePtr,
+      metric, true /* don't return the same point as nearest neighbor */);
+
+  if (naive)
+  {
+    // The naive brute-force solution.
+    for (size_t i = 0; i < referenceSet.n_cols; ++i)
+      for (size_t j = 0; j < referenceSet.n_cols; ++j)
+        rules.BaseCase(i, j);
+
+    baseCases += referenceSet.n_cols * referenceSet.n_cols;
+  }
+  else if (singleMode)
+  {
+    // Create the traverser.
+    typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+
+    // Now have it traverse for each point.
+    for (size_t i = 0; i < referenceSet.n_cols; ++i)
+      traverser.Traverse(i, *referenceTree);
+
+    Log::Info << rules.Scores() << " node combinations were scored.\n";
+    Log::Info << rules.BaseCases() << " base cases were calculated.\n";
+  }
+  else
+  {
+    // Create the traverser.
+    typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+    traverser.Traverse(*referenceTree, *referenceTree);
+
+    Log::Info << rules.Scores() << " node combinations were scored.\n";
+    Log::Info << rules.BaseCases() << " base cases were calculated.\n";
+  }
+
+  Timer::Stop("computing_neighbors");
+
+  // Do we need to map the reference indices?
+  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+  {
+    neighbors.set_size(k, referenceSet.n_cols);
+    distances.set_size(k, referenceSet.n_cols);
+
+    for (size_t i = 0; i < distances.n_cols; ++i)
     {
-      for (size_t j = 0; j < resultingNeighbors.n_rows; j++)
-      {
-        resultingNeighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
-      }
+      // Map distances (copy a column).
+      const size_t refMapping = oldFromNewReferences[i];
+      distances.col(refMapping) = distancePtr->col(i);
+
+      // Map each neighbor's index.
+      for (size_t j = 0; j < distances.n_rows; ++j)
+        neighbors(j, refMapping) = oldFromNewReferences[(*neighborPtr)(j, i)];
     }
 
-    // Finished with temporary matrix.
+    // Finished with temporary matrices.
     delete neighborPtr;
+    delete distancePtr;
   }
-} // Search
-
+}
 
-//Return a String of the Object.
+// Return a String of the Object.
 template<typename SortPolicy, typename MetricType, typename TreeType>
 std::string NeighborSearch<SortPolicy, MetricType, TreeType>::ToString() const
 {
   std::ostringstream convert;
   convert << "NeighborSearch [" << this << "]" << std::endl;
-  convert << "  Reference Set: " << referenceSet.n_rows << "x" ;
-  convert <<  referenceSet.n_cols << std::endl;
-  if (&referenceSet != &querySet)
-    convert << "  QuerySet: " << querySet.n_rows << "x" << querySet.n_cols
-        << std::endl;
-  convert << "  Reference Tree: " << referenceTree << std::endl;
-  if (&referenceTree != &queryTree)
-    convert << "  QueryTree: " << queryTree << std::endl;
-  convert << "  Tree Owner: " << treeOwner << std::endl;
+  convert << "  Reference set: " << referenceSet.n_rows << "x" ;
+  convert << referenceSet.n_cols << std::endl;
+  if (referenceTree)
+    convert << "  Reference tree: " << referenceTree << std::endl;
+  convert << "  Tree owner: " << treeOwner << std::endl;
   convert << "  Naive: " << naive << std::endl;
   convert << "  Metric: " << std::endl;
   convert << mlpack::util::Indent(metric.ToString(),2);
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index e09df0c..0c27690 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -21,7 +21,8 @@ class NeighborSearchRules
                       const typename TreeType::Mat& querySet,
                       arma::Mat<size_t>& neighbors,
                       arma::mat& distances,
-                      MetricType& metric);
+                      MetricType& metric,
+                      const bool sameSet = false);
   /**
    * Get the distance from the query point to the reference point.
    * This will update the "neighbor" matrix with the new point if appropriate
@@ -116,6 +117,9 @@ class NeighborSearchRules
   //! The instantiated metric.
   MetricType& metric;
 
+  //! Denotes whether or not the reference and query sets are the same.
+  bool sameSet;
+
   //! The last query point BaseCase() was called with.
   size_t lastQueryIndex;
   //! The last reference point BaseCase() was called with.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index 131c132..50acfc2 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -19,12 +19,14 @@ NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
     const typename TreeType::Mat& querySet,
     arma::Mat<size_t>& neighbors,
     arma::mat& distances,
-    MetricType& metric) :
+    MetricType& metric,
+    const bool sameSet) :
     referenceSet(referenceSet),
     querySet(querySet),
     neighbors(neighbors),
     distances(distances),
     metric(metric),
+    sameSet(sameSet),
     lastQueryIndex(querySet.n_cols),
     lastReferenceIndex(referenceSet.n_cols),
     baseCases(0),
@@ -44,7 +46,7 @@ BaseCase(const size_t queryIndex, const size_t referenceIndex)
 {
   // If the datasets are the same, then this search is only using one dataset
   // and we should not return identical points.
-  if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
+  if (sameSet && (queryIndex == referenceIndex))
     return 0.0;
 
   // If we have already performed this base case, then do not perform it again.



More information about the mlpack-git mailing list