[mlpack-git] master: Take a reference instead of a pointer to the query tree. (8062be1)

gitdub at mlpack.org gitdub at mlpack.org
Mon Aug 29 18:00:07 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1148f1652e139c9037eb3813550090313d089a30...a8a8a1381b529a01420de6e792a4a1e7bd58a626

>---------------------------------------------------------------

commit 8062be1a1d92c7c6ef8ca1f0ad37066fcebbc04a
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Mon Aug 29 18:48:23 2016 -0300

    Take a reference instead of a pointer to the query tree.


>---------------------------------------------------------------

8062be1a1d92c7c6ef8ca1f0ad37066fcebbc04a
 .../methods/neighbor_search/neighbor_search.hpp    | 31 +++++++++++++++++++++-
 .../neighbor_search/neighbor_search_impl.hpp       | 23 ++++++++++++++--
 .../methods/neighbor_search/ns_model_impl.hpp      |  4 +--
 src/mlpack/tests/knn_test.cpp                      |  2 +-
 4 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 313de90..0f92339 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -390,6 +390,35 @@ class NeighborSearch
    * number of points in the query dataset and k is the number of neighbors
    * being searched for.
    *
+   * This method is deprecated and will be removed in mlpack 3.0.0! The Search()
+   * method taking a reference to the query tree is preferred.
+   *
+   * Note that if you are calling Search() multiple times with a single query
+   * tree, you need to reset the bounds in the statistic of each query node,
+   * otherwise the result may be wrong!  You can do this by calling
+   * TreeType::Stat()::Reset() on each node in the query tree.
+   *
+   * @param queryTree Tree built on query points.
+   * @param k Number of neighbors to search for.
+   * @param neighbors Matrix storing lists of neighbors for each query point.
+   * @param distances Matrix storing distances of neighbors for each query
+   *      point.
+   * @param sameSet Denotes whether or not the reference and query sets are the
+   *      same.
+   */
+  mlpack_deprecated void Search(Tree* queryTree,
+                                const size_t k,
+                                arma::Mat<size_t>& neighbors,
+                                arma::mat& distances,
+                                bool sameSet = false);
+
+  /**
+   * Given a pre-built query tree, search for the nearest neighbors of each
+   * point in the query tree, storing the output in the given matrices.  The
+   * matrices will be set to the size of n columns by k rows, where n is the
+   * number of points in the query dataset and k is the number of neighbors
+   * being searched for.
+   *
    * Note that if you are calling Search() multiple times with a single query
    * tree, you need to reset the bounds in the statistic of each query node,
    * otherwise the result may be wrong!  You can do this by calling
@@ -403,7 +432,7 @@ class NeighborSearch
    * @param sameSet Denotes whether or not the reference and query sets are the
    *      same.
    */
-  void Search(Tree* queryTree,
+  void Search(Tree& queryTree,
               const size_t k,
               arma::Mat<size_t>& neighbors,
               arma::mat& distances,
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 6c51d4c..fb099e3 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -834,6 +834,25 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
     arma::mat& distances,
     bool sameSet)
 {
+  Search(*queryTree, k, neighbors, distances, sameSet);
+}
+
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class DualTreeTraversalType,
+         template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+DualTreeTraversalType, SingleTreeTraversalType>::Search(
+    Tree& queryTree,
+    const size_t k,
+    arma::Mat<size_t>& neighbors,
+    arma::mat& distances,
+    bool sameSet)
+{
   // Update searchMode.
   UpdateSearchMode();
 
@@ -856,7 +875,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
   scores = 0;
 
   // Get a reference to the query set.
-  const MatType& querySet = queryTree->Dataset();
+  const MatType& querySet = queryTree.Dataset();
 
   // We won't need to map query indices, but will we need to map distances?
   arma::Mat<size_t>* neighborPtr = &neighbors;
@@ -874,7 +893,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
 
   // Create the traverser.
   DualTreeTraversalType<RuleType> traverser(rules);
-  traverser.Traverse(*queryTree, *referenceTree);
+  traverser.Traverse(queryTree, *referenceTree);
 
   scores += rules.Scores();
   baseCases += rules.BaseCases();
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 5c16bca..c31cdc9 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -87,7 +87,7 @@ void BiSearchVisitor<SortPolicy>::operator()(SpillKNN* ns) const
       // non overlapping (tau = 0).
       typename SpillKNN::Tree queryTree(std::move(querySet), 0 /* tau*/,
           leafSize, rho);
-      ns->Search(&queryTree, k, neighbors, distances);
+      ns->Search(queryTree, k, neighbors, distances);
     }
     else
       ns->Search(querySet, k, neighbors, distances);
@@ -109,7 +109,7 @@ void BiSearchVisitor<SortPolicy>::SearchLeaf(NSType* ns) const
 
     arma::Mat<size_t> neighborsOut;
     arma::mat distancesOut;
-    ns->Search(&queryTree, k, neighborsOut, distancesOut);
+    ns->Search(queryTree, k, neighborsOut, distancesOut);
 
     // Unmap the query points.
     distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
diff --git a/src/mlpack/tests/knn_test.cpp b/src/mlpack/tests/knn_test.cpp
index 7939337..8587a20 100644
--- a/src/mlpack/tests/knn_test.cpp
+++ b/src/mlpack/tests/knn_test.cpp
@@ -178,7 +178,7 @@ BOOST_AUTO_TEST_CASE(EmptySearchTest)
       std::invalid_argument);
   BOOST_REQUIRE_THROW(empty.Search(5, neighbors, distances),
       std::invalid_argument);
-  BOOST_REQUIRE_THROW(empty.Search(&queryTree, 5, neighbors, distances),
+  BOOST_REQUIRE_THROW(empty.Search(queryTree, 5, neighbors, distances),
       std::invalid_argument);
 }
 




More information about the mlpack-git mailing list