[mlpack-git] master: Take a reference instead of a pointer to the query tree. (8062be1)
gitdub at mlpack.org
gitdub at mlpack.org
Mon Aug 29 18:00:07 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1148f1652e139c9037eb3813550090313d089a30...a8a8a1381b529a01420de6e792a4a1e7bd58a626
>---------------------------------------------------------------
commit 8062be1a1d92c7c6ef8ca1f0ad37066fcebbc04a
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Mon Aug 29 18:48:23 2016 -0300
Take a reference instead of a pointer to the query tree.
>---------------------------------------------------------------
8062be1a1d92c7c6ef8ca1f0ad37066fcebbc04a
.../methods/neighbor_search/neighbor_search.hpp | 31 +++++++++++++++++++++-
.../neighbor_search/neighbor_search_impl.hpp | 23 ++++++++++++++--
.../methods/neighbor_search/ns_model_impl.hpp | 4 +--
src/mlpack/tests/knn_test.cpp | 2 +-
4 files changed, 54 insertions(+), 6 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 313de90..0f92339 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -390,6 +390,35 @@ class NeighborSearch
* number of points in the query dataset and k is the number of neighbors
* being searched for.
*
+ * This method is deprecated and will be removed in mlpack 3.0.0! The Search()
+ * method taking a reference to the query tree is preferred.
+ *
+ * Note that if you are calling Search() multiple times with a single query
+ * tree, you need to reset the bounds in the statistic of each query node,
+ * otherwise the result may be wrong! You can do this by calling
+ * TreeType::Stat()::Reset() on each node in the query tree.
+ *
+ * @param queryTree Tree built on query points.
+ * @param k Number of neighbors to search for.
+ * @param neighbors Matrix storing lists of neighbors for each query point.
+ * @param distances Matrix storing distances of neighbors for each query
+ * point.
+ * @param sameSet Denotes whether or not the reference and query sets are the
+ * same.
+ */
+ mlpack_deprecated void Search(Tree* queryTree,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ bool sameSet = false);
+
+ /**
+ * Given a pre-built query tree, search for the nearest neighbors of each
+ * point in the query tree, storing the output in the given matrices. The
+ * matrices will be set to the size of n columns by k rows, where n is the
+ * number of points in the query dataset and k is the number of neighbors
+ * being searched for.
+ *
* Note that if you are calling Search() multiple times with a single query
* tree, you need to reset the bounds in the statistic of each query node,
* otherwise the result may be wrong! You can do this by calling
@@ -403,7 +432,7 @@ class NeighborSearch
* @param sameSet Denotes whether or not the reference and query sets are the
* same.
*/
- void Search(Tree* queryTree,
+ void Search(Tree& queryTree,
const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances,
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 6c51d4c..fb099e3 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -834,6 +834,25 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
arma::mat& distances,
bool sameSet)
{
+ Search(*queryTree, k, neighbors, distances, sameSet);
+}
+
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType,
+ template<typename> class DualTreeTraversalType,
+ template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+DualTreeTraversalType, SingleTreeTraversalType>::Search(
+ Tree& queryTree,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ bool sameSet)
+{
// Update searchMode.
UpdateSearchMode();
@@ -856,7 +875,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
scores = 0;
// Get a reference to the query set.
- const MatType& querySet = queryTree->Dataset();
+ const MatType& querySet = queryTree.Dataset();
// We won't need to map query indices, but will we need to map distances?
arma::Mat<size_t>* neighborPtr = &neighbors;
@@ -874,7 +893,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
// Create the traverser.
DualTreeTraversalType<RuleType> traverser(rules);
- traverser.Traverse(*queryTree, *referenceTree);
+ traverser.Traverse(queryTree, *referenceTree);
scores += rules.Scores();
baseCases += rules.BaseCases();
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 5c16bca..c31cdc9 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -87,7 +87,7 @@ void BiSearchVisitor<SortPolicy>::operator()(SpillKNN* ns) const
// non overlapping (tau = 0).
typename SpillKNN::Tree queryTree(std::move(querySet), 0 /* tau*/,
leafSize, rho);
- ns->Search(&queryTree, k, neighbors, distances);
+ ns->Search(queryTree, k, neighbors, distances);
}
else
ns->Search(querySet, k, neighbors, distances);
@@ -109,7 +109,7 @@ void BiSearchVisitor<SortPolicy>::SearchLeaf(NSType* ns) const
arma::Mat<size_t> neighborsOut;
arma::mat distancesOut;
- ns->Search(&queryTree, k, neighborsOut, distancesOut);
+ ns->Search(queryTree, k, neighborsOut, distancesOut);
// Unmap the query points.
distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
diff --git a/src/mlpack/tests/knn_test.cpp b/src/mlpack/tests/knn_test.cpp
index 7939337..8587a20 100644
--- a/src/mlpack/tests/knn_test.cpp
+++ b/src/mlpack/tests/knn_test.cpp
@@ -178,7 +178,7 @@ BOOST_AUTO_TEST_CASE(EmptySearchTest)
std::invalid_argument);
BOOST_REQUIRE_THROW(empty.Search(5, neighbors, distances),
std::invalid_argument);
- BOOST_REQUIRE_THROW(empty.Search(&queryTree, 5, neighbors, distances),
+ BOOST_REQUIRE_THROW(empty.Search(queryTree, 5, neighbors, distances),
std::invalid_argument);
}
More information about the mlpack-git
mailing list