[mlpack-git] master: Preliminary refactoring of RangeSearch. (3673579)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Apr 8 18:30:40 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/aaa19c9dcaed6e859d6565b8c1bfb07f755f0d7e...367357985828e5a007b3b9ccf6c279e983cd7fad
>---------------------------------------------------------------
commit 367357985828e5a007b3b9ccf6c279e983cd7fad
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Apr 8 18:30:24 2015 -0400
Preliminary refactoring of RangeSearch.
>---------------------------------------------------------------
367357985828e5a007b3b9ccf6c279e983cd7fad
src/mlpack/methods/range_search/range_search.hpp | 159 ++++-----
.../methods/range_search/range_search_impl.hpp | 377 ++++++++++++---------
.../methods/range_search/range_search_main.cpp | 115 +++----
src/mlpack/tests/range_search_test.cpp | 167 ++++-----
4 files changed, 428 insertions(+), 390 deletions(-)
diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index 969166a..9504fd3 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -48,31 +48,6 @@ class RangeSearch
* @param metric Instantiated distance metric.
*/
RangeSearch(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool naive = false,
- const bool singleMode = false,
- const MetricType metric = MetricType());
-
- /**
- * Initialize the RangeSearch object with only a reference set, which will
- * also be used as a query set. Optionally, perform the computation in naive
- * mode or single-tree mode, and set the leaf size used for tree-building.
- * Additionally an instantiated metric can be given, for cases where the
- * distance metric holds data.
- *
- * This method will copy the reference matrix to an internal copy, which is
- * rearranged during tree-building. You can avoid this extra copy by
- * pre-constructing the reference tree and passing it using a different
- * constructor.
- *
- * @param referenceSet Reference dataset.
- * @param naive Whether the computation should be done in O(n^2) naive mode.
- * @param singleMode Whether single-tree computation should be used (as
- * opposed to dual-tree computation).
- * @param leafSize The leaf size to be used during tree construction.
- * @param metric Instantiated distance metric.
- */
- RangeSearch(const typename TreeType::Mat& referenceSet,
const bool naive = false,
const bool singleMode = false,
const MetricType metric = MetricType());
@@ -107,56 +82,97 @@ class RangeSearch
* @param metric Instantiated distance metric.
*/
RangeSearch(TreeType* referenceTree,
- TreeType* queryTree,
- const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
const bool singleMode = false,
const MetricType metric = MetricType());
/**
- * Initialize the RangeSearch object with the given reference dataset and
- * pre-constructed tree. It is assumed that the points in referenceSet
- * correspond to the points in referenceTree. Optionally, choose to use
- * single-tree mode. Naive mode is not available as an option for this
- * constructor; instead, to run naive computation, construct a tree with all
- * the points in one leaf (i.e. leafSize = number of points). Additionally,
- * an instantiated distance metric can be given, for the case where the
- * distance metric holds data.
+ * Destroy the RangeSearch object. If trees were created, they will be
+ * deleted.
+ */
+ ~RangeSearch();
+
+ /**
+ * Search for all reference points in the given range for each point in the
+ * query set, returning the results in the neighbors and distances objects.
+ * Each entry in the external vector corresponds to a query point. Each of
+ * these entries holds a vector which contains the indices and distances of
+ * the reference points falling into the given range.
*
- * There is no copying of the data matrices in this constructor (because
- * tree-building is not necessary), so this is the constructor to use when
- * copies absolutely must be avoided.
+ * That is:
*
- * @note
- * Because tree-building (at least with BinarySpaceTree) modifies the ordering
- * of a matrix, be sure you pass the modified matrix to this object! In
- * addition, mapping the points of the matrix back to their original indices
- * is not done when this constructor is used.
- * @endnote
+ * - neighbors.size() and distances.size() both equal the number of query
+ * points.
*
- * @param referenceTree Pre-built tree for reference points.
- * @param referenceSet Set of reference points corresponding to referenceTree.
- * @param singleMode Whether single-tree computation should be used (as
- * opposed to dual-tree computation).
- * @param metric Instantiated distance metric.
+ * - neighbors[i] contains the indices of all the points in the reference set
+ * which have distances inside the given range to query point i.
+ *
+ * - distances[i] contains all of the distances corresponding to the indices
+ * contained in neighbors[i].
+ *
+ * - neighbors[i] and distances[i] are not sorted in any particular order.
+ *
+ * @param querySet Set of query points to search with.
+ * @param range Range of distances in which to search.
+ * @param neighbors Object which will hold the list of neighbors for each
+ * point which fell into the given range, for each query point.
+ * @param distances Object which will hold the list of distances for each
+ * point which fell into the given range, for each query point.
*/
- RangeSearch(TreeType* referenceTree,
- const typename TreeType::Mat& referenceSet,
- const bool singleMode = false,
- const MetricType metric = MetricType());
+ void Search(const typename TreeType::Mat& querySet,
+ const math::Range& range,
+ std::vector<std::vector<size_t>>& neighbors,
+ std::vector<std::vector<double>>& distances);
/**
- * Destroy the RangeSearch object. If trees were created, they will be
- * deleted.
+ * Given a pre-built query tree, search for all reference points in the given
+ * range for each point in the query set, returning the results in the
+ * neighbors and distances objects.
+ *
+ * Each entry in the external vector corresponds to a query point. Each of
+ * these entries holds a vector which contains the indices and distances of
+ * the reference points falling into the given range.
+ *
+ * That is:
+ *
+ * - neighbors.size() and distances.size() both equal the number of query
+ * points.
+ *
+ * - neighbors[i] contains the indices of all the points in the reference set
+ * which have distances inside the given range to query point i.
+ *
+ * - distances[i] contains all of the distances corresponding to the indices
+ * contained in neighbors[i].
+ *
+ * - neighbors[i] and distances[i] are not sorted in any particular order.
+ *
+ * If either naive or singleMode is set to true, this will throw an
+ * invalid_argument exception; passing in a query tree implies dual-tree
+ * search.
+ *
+ * If you want to use the reference tree as the query tree, instead call the
+ * overload of Search() that does not take a query set.
+ *
+ * @param queryTree Tree built on query points.
+ * @param range Range of distances in which to search.
+ * @param neighbors Object which will hold the list of neighbors for each
+ * point which fell into the given range, for each query point.
+ * @param distances Object which will hold the list of distances for each
+ * point which fell into the given range, for each query point.
*/
- ~RangeSearch();
+ void Search(TreeType* queryTree,
+ const math::Range& range,
+ std::vector<std::vector<size_t>>& neighbors,
+ std::vector<std::vector<double>>& distances);
/**
- * Search for all points in the given range, returning the results in the
- * neighbors and distances objects. Each entry in the external vector
- * corresponds to a query point. Each of these entries holds a vector which
- * contains the indices and distances of the reference points falling into the
- * given range.
+ * Search for all points in the given range for each point in the reference
+ * set (which was passed to the constructor), returning the results in the
+ * neighbors and distances objects. This means that the query set and the
+ * reference set are the same.
+ *
+ * Each entry in the external vector corresponds to a query point. Each of
+ * these entries holds a vector which contains the indices and distances of
+ * the reference points falling into the given range.
*
* That is:
*
@@ -171,6 +187,7 @@ class RangeSearch
*
* - neighbors[i] and distances[i] are not sorted in any particular order.
*
+ * @note This overload takes no query set or query tree.
* @param range Range of distances in which to search.
* @param neighbors Object which will hold the list of neighbors for each
* point which fell into the given range, for each query point.
@@ -178,8 +195,8 @@ class RangeSearch
* point which fell into the given range, for each query point.
*/
void Search(const math::Range& range,
- std::vector<std::vector<size_t> >& neighbors,
- std::vector<std::vector<double> >& distances);
+ std::vector<std::vector<size_t>>& neighbors,
+ std::vector<std::vector<double>>& distances);
// Returns a string representation of this object.
std::string ToString() const;
@@ -187,23 +204,12 @@ class RangeSearch
private:
//! Copy of reference matrix; used when a tree is built internally.
typename TreeType::Mat referenceCopy;
- //! Copy of query matrix; used when a tree is built internally.
- typename TreeType::Mat queryCopy;
-
//! Reference set (data should be accessed using this).
const typename TreeType::Mat& referenceSet;
- //! Query set (data should be accessed using this).
- const typename TreeType::Mat& querySet;
-
//! Reference tree.
TreeType* referenceTree;
- //! Query tree (may be NULL).
- TreeType* queryTree;
-
//! Mappings to old reference indices (used when this object builds trees).
std::vector<size_t> oldFromNewReferences;
- //! Mappings to old query indices (used when this object builds trees).
- std::vector<size_t> oldFromNewQueries;
//! If true, this object is responsible for deleting the trees.
bool treeOwner;
@@ -218,9 +224,6 @@ class RangeSearch
//! Instantiated distance metric.
MetricType metric;
-
- //! The number of pruned nodes during computation.
- size_t numPrunes;
};
}; // namespace range
diff --git a/src/mlpack/methods/range_search/range_search_impl.hpp b/src/mlpack/methods/range_search/range_search_impl.hpp
index 38e76f1..ae443ac 100644
--- a/src/mlpack/methods/range_search/range_search_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_impl.hpp
@@ -42,179 +42,91 @@ TreeType* BuildTree(
template<typename MetricType, typename TreeType>
RangeSearch<MetricType, TreeType>::RangeSearch(
const typename TreeType::Mat& referenceSetIn,
- const typename TreeType::Mat& querySetIn,
const bool naive,
const bool singleMode,
const MetricType metric) :
- referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
- : referenceSetIn),
- querySet(tree::TreeTraits<TreeType>::RearrangesDataset ? queryCopy
- : querySetIn),
+ referenceSet((tree::TreeTraits<TreeType>::RearrangesDataset && !naive)
+ ? referenceCopy : referenceSetIn),
referenceTree(NULL),
- queryTree(NULL),
treeOwner(!naive), // If in naive mode, we are not building any trees.
- hasQuerySet(true),
naive(naive),
singleMode(!naive && singleMode), // Naive overrides single mode.
- metric(metric),
- numPrunes(0)
+ metric(metric)
{
// Build the trees.
Timer::Start("range_search/tree_building");
- // Copy the datasets, if they will be modified during tree building.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
- {
- referenceCopy = referenceSetIn;
- queryCopy = querySetIn;
- }
-
// If in naive mode, then we do not need to build trees.
if (!naive)
{
- // The const_cast is safe; if RearrangesDataset == false, then it'll be
- // casted back to const anyway, and if not, referenceSet points to
- // referenceCopy, which isn't const.
- referenceTree = BuildTree<TreeType>(
- const_cast<typename TreeType::Mat&>(referenceSet),
- oldFromNewReferences);
-
- if (!singleMode)
- queryTree = BuildTree<TreeType>(
- const_cast<typename TreeType::Mat&>(querySet), oldFromNewQueries);
- }
+ // Copy the dataset, if it will be modified during tree building.
+ if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ referenceCopy = referenceSetIn;
- Timer::Stop("range_search/tree_building");
-}
-
-template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(
- const typename TreeType::Mat& referenceSetIn,
- const bool naive,
- const bool singleMode,
- const MetricType metric) :
- referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
- : referenceSetIn),
- querySet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
- : referenceSetIn),
- referenceTree(NULL),
- queryTree(NULL),
- treeOwner(!naive), // If in naive mode, we are not building any trees.
- hasQuerySet(false),
- naive(naive),
- singleMode(!naive && singleMode), // Naive overrides single mode.
- metric(metric),
- numPrunes(0)
-{
- // Build the trees.
- Timer::Start("range_search/tree_building");
-
- // Copy the dataset, if it will be modified during tree building.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
- referenceCopy = referenceSetIn;
-
- // If in naive mode, then we do not need to build trees.
- if (!naive)
- {
// The const_cast is safe; if RearrangesDataset == false, then it'll be
// casted back to const anyway, and if not, referenceSet points to
// referenceCopy, which isn't const.
referenceTree = BuildTree<TreeType>(
const_cast<typename TreeType::Mat&>(referenceSet),
oldFromNewReferences);
-
- if (!singleMode)
- queryTree = new TreeType(*referenceTree);
}
+
Timer::Stop("range_search/tree_building");
}
template<typename MetricType, typename TreeType>
RangeSearch<MetricType, TreeType>::RangeSearch(
TreeType* referenceTree,
- TreeType* queryTree,
- const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
const bool singleMode,
const MetricType metric) :
- referenceSet(referenceSet),
- querySet(querySet),
+ referenceSet(referenceTree->Dataset()),
referenceTree(referenceTree),
- queryTree(queryTree),
treeOwner(false),
- hasQuerySet(true),
naive(false),
singleMode(singleMode),
- metric(metric),
- numPrunes(0)
+ metric(metric)
{
// Nothing else to initialize.
}
template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(
- TreeType* referenceTree,
- const typename TreeType::Mat& referenceSet,
- const bool singleMode,
- const MetricType metric) :
- referenceSet(referenceSet),
- querySet(referenceSet),
- referenceTree(referenceTree),
- queryTree(NULL),
- treeOwner(false),
- hasQuerySet(false),
- naive(false),
- singleMode(singleMode),
- metric(metric),
- numPrunes(0)
-{
- // If doing dual-tree range search, we must clone the reference tree.
- if (!singleMode)
- queryTree = new TreeType(*referenceTree);
-}
-
-template<typename MetricType, typename TreeType>
RangeSearch<MetricType, TreeType>::~RangeSearch()
{
- if (treeOwner)
- {
- if (referenceTree)
- delete referenceTree;
- if (queryTree)
- delete queryTree;
- }
-
- // If doing dual-tree search with one dataset, we cloned the reference tree.
- if (!treeOwner && !hasQuerySet && !(singleMode || naive))
- delete queryTree;
+ if (treeOwner && referenceTree)
+ delete referenceTree;
}
template<typename MetricType, typename TreeType>
void RangeSearch<MetricType, TreeType>::Search(
+ const typename TreeType::Mat& querySet,
const math::Range& range,
- std::vector<std::vector<size_t> >& neighbors,
- std::vector<std::vector<double> >& distances)
+ std::vector<std::vector<size_t>>& neighbors,
+ std::vector<std::vector<double>>& distances)
{
Timer::Start("range_search/computing_neighbors");
- // Set size of prunes to 0.
- numPrunes = 0;
+ // This will hold mappings for query points, if necessary.
+ std::vector<size_t> oldFromNewQueries;
// If we have built the trees ourselves, then we will have to map all the
// indices back to their original indices when this computation is finished.
// To avoid extra copies, we will store the unmapped neighbors and distances
// in a separate object.
- std::vector<std::vector<size_t> >* neighborPtr = &neighbors;
- std::vector<std::vector<double> >* distancePtr = &distances;
+ std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
+ std::vector<std::vector<double>>* distancePtr = &distances;
// Mapping is only necessary if the tree rearranges points.
if (tree::TreeTraits<TreeType>::RearrangesDataset)
{
- if (treeOwner && !(singleMode && hasQuerySet))
- distancePtr = new std::vector<std::vector<double> >; // Query indices need to be mapped.
+ // Query indices only need to be mapped if we are building the query tree
+ // ourselves.
+ if (!singleMode && !naive)
+ distancePtr = new std::vector<std::vector<double>>;
+ // Reference indices only need to be mapped if we built the reference tree
+ // ourselves.
if (treeOwner)
- neighborPtr = new std::vector<std::vector<size_t> >; // All indices need mapping.
+ neighborPtr = new std::vector<std::vector<size_t>>;
}
// Resize each vector.
@@ -243,63 +155,224 @@ void RangeSearch<MetricType, TreeType>::Search(
// Now have it traverse for each point.
for (size_t i = 0; i < querySet.n_cols; ++i)
traverser.Traverse(i, *referenceTree);
-
- numPrunes = traverser.NumPrunes();
}
else // Dual-tree recursion.
{
+ // Build the query tree.
+ Timer::Stop("range_search/computing_neighbors");
+ Timer::Start("range_search/tree_building");
+ typename TreeType::Mat queryCopy;
+ if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ queryCopy = querySet;
+
+ const typename TreeType::Mat& querySetRef =
+ (tree::TreeTraits<TreeType>::RearrangesDataset) ? querySet : queryCopy;
+ TreeType* queryTree = BuildTree<TreeType>(
+ const_cast<typename TreeType::Mat&>(querySetRef), oldFromNewQueries);
+ Timer::Stop("range_search/tree_building");
+ Timer::Start("range_search/computing_neighbors");
+
// Create the traverser.
typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
traverser.Traverse(*queryTree, *referenceTree);
- numPrunes = traverser.NumPrunes();
+ // Clean up tree memory.
+ delete queryTree;
}
Timer::Stop("range_search/computing_neighbors");
- // Output number of prunes.
- Log::Info << "Number of pruned nodes during computation: " << numPrunes
- << "." << std::endl;
-
// Map points back to original indices, if necessary.
-
- if (!treeOwner || !tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (tree::TreeTraits<TreeType>::RearrangesDataset)
{
- // No mapping needed. We are done.
- return;
+ if (!singleMode && !naive && treeOwner)
+ {
+ // We must map both query and reference indices.
+ neighbors.clear();
+ neighbors.resize(querySet.n_cols);
+ distances.clear();
+ distances.resize(querySet.n_cols);
+
+ for (size_t i = 0; i < distances.size(); i++)
+ {
+ // Map distances (copy a column).
+ const size_t queryMapping = oldFromNewQueries[i];
+ distances[queryMapping] = (*distancePtr)[i];
+
+ // Copy each neighbor individually, because we need to map it.
+ neighbors[queryMapping].resize(distances[queryMapping].size());
+ for (size_t j = 0; j < distances[queryMapping].size(); j++)
+ neighbors[queryMapping][j] =
+ oldFromNewReferences[(*neighborPtr)[i][j]];
+ }
+
+ // Finished with temporary objects.
+ delete neighborPtr;
+ delete distancePtr;
+ }
+ else if (!singleMode && !naive)
+ {
+ // We must map query indices only.
+ neighbors.clear();
+ neighbors.resize(querySet.n_cols);
+ distances.clear();
+ distances.resize(querySet.n_cols);
+
+ for (size_t i = 0; i < distances.size(); ++i)
+ {
+ // Map distances and neighbors (copy a column).
+ const size_t queryMapping = oldFromNewQueries[i];
+ distances[queryMapping] = (*distancePtr)[i];
+ neighbors[queryMapping] = (*neighborPtr)[i];
+ }
+
+ // Finished with temporary objects.
+ delete neighborPtr;
+ delete distancePtr;
+ }
+ else if (treeOwner)
+ {
+ // We must map reference indices only.
+ neighbors.clear();
+ neighbors.resize(querySet.n_cols);
+
+ for (size_t i = 0; i < neighbors.size(); i++)
+ {
+ neighbors[i].resize((*neighborPtr)[i].size());
+ for (size_t j = 0; j < neighbors[i].size(); j++)
+ neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
+ }
+
+ // Finished with temporary object.
+ delete neighborPtr;
+ }
}
- else if (treeOwner && hasQuerySet && !singleMode) // Map both sets.
+}
+
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::Search(
+ TreeType* queryTree,
+ const math::Range& range,
+ std::vector<std::vector<size_t>>& neighbors,
+ std::vector<std::vector<double>>& distances)
+{
+ Timer::Start("range_search/computing_neighbors");
+
+ // Get a reference to the query set.
+ const typename TreeType::Mat& querySet = queryTree->Dataset();
+
+ // Make sure we are in dual-tree mode.
+ if (singleMode || naive)
+ throw std::invalid_argument("cannot call RangeSearch::Search() with a "
+ "query tree when naive or singleMode are set to true");
+
+ // We won't need to map query indices, but will we need to map reference indices?
+ std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
+
+ if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+ neighborPtr = new std::vector<std::vector<size_t>>;
+
+ // Resize each vector.
+ neighborPtr->clear(); // Just in case there was anything in it.
+ neighborPtr->resize(querySet.n_cols);
+ distances.clear();
+ distances.resize(querySet.n_cols);
+
+ // Create the helper object for the traversal.
+ typedef RangeSearchRules<MetricType, TreeType> RuleType;
+ RuleType rules(referenceSet, queryTree->Dataset(), range, *neighborPtr,
+ distances, metric);
+
+ // Create the traverser.
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+ traverser.Traverse(*queryTree, *referenceTree);
+
+ Timer::Stop("range_search/computing_neighbors");
+
+ // Do we need to map indices?
+ if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
{
+ // We must map reference indices only.
neighbors.clear();
neighbors.resize(querySet.n_cols);
- distances.clear();
- distances.resize(querySet.n_cols);
- for (size_t i = 0; i < distances.size(); i++)
+ for (size_t i = 0; i < neighbors.size(); i++)
{
- // Map distances (copy a column).
- size_t queryMapping = oldFromNewQueries[i];
- distances[queryMapping] = (*distancePtr)[i];
-
- // Copy each neighbor individually, because we need to map it.
- neighbors[queryMapping].resize(distances[queryMapping].size());
- for (size_t j = 0; j < distances[queryMapping].size(); j++)
- {
- neighbors[queryMapping][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
- }
+ neighbors[i].resize((*neighborPtr)[i].size());
+ for (size_t j = 0; j < neighbors[i].size(); j++)
+ neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
}
- // Finished with temporary objects.
+ // Finished with temporary object.
delete neighborPtr;
- delete distancePtr;
}
- else if (treeOwner && !hasQuerySet)
+}
+
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::Search(
+ const math::Range& range,
+ std::vector<std::vector<size_t>>& neighbors,
+ std::vector<std::vector<double>>& distances)
+{
+ Timer::Start("range_search/computing_neighbors");
+
+ // Here, we will use the query set as the reference set.
+ std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
+ std::vector<std::vector<double>>* distancePtr = &distances;
+
+ if (tree::TreeTraits<TreeType>::RearrangesDataset && treeOwner)
+ {
+ // We will always need to rearrange in this case.
+ distancePtr = new std::vector<std::vector<double>>;
+ neighborPtr = new std::vector<std::vector<size_t>>;
+ }
+
+ // Resize each vector.
+ neighborPtr->clear(); // Just in case there was anything in it.
+ neighborPtr->resize(referenceSet.n_cols);
+ distancePtr->clear();
+ distancePtr->resize(referenceSet.n_cols);
+
+ // Create the helper object for the traversal.
+ typedef RangeSearchRules<MetricType, TreeType> RuleType;
+ RuleType rules(referenceSet, referenceSet, range, *neighborPtr, *distancePtr,
+ metric);
+
+ if (naive)
+ {
+ // The naive brute-force solution.
+ for (size_t i = 0; i < referenceSet.n_cols; ++i)
+ for (size_t j = 0; j < referenceSet.n_cols; ++j)
+ rules.BaseCase(i, j);
+ }
+ else if (singleMode)
+ {
+ // Create the traverser.
+ typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+
+ // Now have it traverse for each point.
+ for (size_t i = 0; i < referenceSet.n_cols; ++i)
+ traverser.Traverse(i, *referenceTree);
+ }
+ else // Dual-tree recursion.
+ {
+ // Create the traverser.
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+ traverser.Traverse(*referenceTree, *referenceTree);
+ }
+
+ Timer::Stop("range_search/computing_neighbors");
+
+ // Do we need to map the reference indices?
+ if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
{
neighbors.clear();
- neighbors.resize(querySet.n_cols);
+ neighbors.resize(referenceSet.n_cols);
distances.clear();
- distances.resize(querySet.n_cols);
+ distances.resize(referenceSet.n_cols);
for (size_t i = 0; i < distances.size(); i++)
{
@@ -319,24 +392,6 @@ void RangeSearch<MetricType, TreeType>::Search(
delete neighborPtr;
delete distancePtr;
}
- else if (treeOwner && hasQuerySet && singleMode) // Map only references.
- {
- neighbors.clear();
- neighbors.resize(querySet.n_cols);
-
- // Map indices of neighbors.
- for (size_t i = 0; i < neighbors.size(); i++)
- {
- neighbors[i].resize((*neighborPtr)[i].size());
- for (size_t j = 0; j < neighbors[i].size(); j++)
- {
- neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
- }
- }
-
- // Finished with temporary object.
- delete neighborPtr;
- }
}
template<typename MetricType, typename TreeType>
diff --git a/src/mlpack/methods/range_search/range_search_main.cpp b/src/mlpack/methods/range_search/range_search_main.cpp
index fe91b61..31eba17 100644
--- a/src/mlpack/methods/range_search/range_search_main.cpp
+++ b/src/mlpack/methods/range_search/range_search_main.cpp
@@ -103,6 +103,7 @@ int main(int argc, char *argv[])
Log::Fatal << "Invalid range: maximum (" << max << ") must be greater than "
<< "minimum (" << min << ")." << endl;
}
+ const math::Range r(min, max);
// Sanity check on leaf size.
if (lsInt < 0)
@@ -110,7 +111,7 @@ int main(int argc, char *argv[])
Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
"than or equal to 0." << endl;
}
- size_t leafSize = lsInt;
+ const size_t leafSize = lsInt;
// Naive mode overrides single mode.
if (singleMode && naive)
@@ -118,9 +119,6 @@ int main(int argc, char *argv[])
Log::Warn << "--single_mode ignored because --naive is present." << endl;
}
- if (naive)
- leafSize = referenceData.n_cols;
-
if (coverTree && naive)
{
Log::Warn << "--cover_tree ignored because --naive is present." << endl;
@@ -131,106 +129,92 @@ int main(int argc, char *argv[])
vector<vector<double> > distances;
// The cover tree implies different types, so we must split this section.
- if (coverTree)
+ if (naive)
+ {
+ Log::Info << "Performing naive search (no trees)." << endl;
+
+ // Trees don't matter.
+ RangeSearch<> rangeSearch(referenceData, singleMode, naive);
+ rangeSearch.Search(queryData, r, neighbors, distances);
+ }
+ else if (coverTree)
{
Log::Info << "Using cover trees." << endl;
// This is significantly simpler than kd-tree construction because the data
// matrix is not modified.
- RSCoverType* rangeSearch = NULL;
- CoverTreeType referenceTree(referenceData);
- CoverTreeType* queryTree = NULL;
+ RSCoverType rangeSearch(referenceData, singleMode);
if (CLI::GetParam<string>("query_file") == "")
{
// Single dataset.
- rangeSearch = new RSCoverType(&referenceTree, referenceData, singleMode);
+ rangeSearch.Search(r, neighbors, distances);
}
else
{
// Two datasets.
const string queryFile = CLI::GetParam<string>("query_file");
data::Load(queryFile, queryData, true);
- queryTree = new CoverTreeType(queryData);
- rangeSearch = new RSCoverType(&referenceTree, queryTree, referenceData,
- queryData, singleMode);
+ // Query tree is automatically built if needed.
+ rangeSearch.Search(queryData, r, neighbors, distances);
}
-
- Log::Info << "Trees built." << endl;
-
- const math::Range r(min, max);
- rangeSearch->Search(r, neighbors, distances);
-
- if (queryTree)
- delete queryTree;
- delete rangeSearch;
}
else
{
- // Because we may construct it differently, we need a pointer.
- RSType* rangeSearch = NULL;
+ typedef BinarySpaceTree<bound::HRectBound<2>, RangeSearchStat> TreeType;
- // Mappings for when we build the tree.
- vector<size_t> oldFromNewRefs;
-
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
+ // Track mappings.
Log::Info << "Building reference tree..." << endl;
Timer::Start("tree_building");
-
- BinarySpaceTree<bound::HRectBound<2>, RangeSearchStat>
- refTree(referenceData, oldFromNewRefs, leafSize);
- BinarySpaceTree<bound::HRectBound<2>, RangeSearchStat>*
- queryTree = NULL; // Empty for now.
-
+ vector<size_t> oldFromNewRefs;
+ vector<size_t> oldFromNewQueries; // Not used yet.
+ TreeType refTree(referenceData, oldFromNewRefs, leafSize);
Timer::Stop("tree_building");
- vector<size_t> oldFromNewQueries;
+ // Collect the results in these vectors before remapping.
+ vector<vector<double> > distancesOut;
+ vector<vector<size_t> > neighborsOut;
+
+ RSType rangeSearch(&refTree, singleMode);
if (CLI::GetParam<string>("query_file") != "")
{
const string queryFile = CLI::GetParam<string>("query_file");
data::Load(queryFile, queryData, true);
- if (naive && leafSize < queryData.n_cols)
- leafSize = queryData.n_cols;
-
Log::Info << "Loaded query data from '" << queryFile << "'." << endl;
- Log::Info << "Building query tree..." << endl;
-
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Timer::Start("tree_building");
-
- queryTree = new BinarySpaceTree<bound::HRectBound<2>,
- RangeSearchStat>(queryData, oldFromNewQueries, leafSize);
+ if (singleMode)
+ {
+ Log::Info << "Computing neighbors within range [" << min << ", " << max
+ << "]." << endl;
+ rangeSearch.Search(queryData, r, neighborsOut, distancesOut);
+ }
+ else
+ {
+ Log::Info << "Building query tree..." << endl;
- Timer::Stop("tree_building");
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // RangeSearch, it does not copy the matrix.
+ Timer::Start("tree_building");
+ TreeType queryTree(queryData, oldFromNewQueries, leafSize);
+ Timer::Stop("tree_building");
- rangeSearch = new RSType(&refTree, queryTree, referenceData, queryData,
- singleMode);
+ Log::Info << "Tree built." << endl;
- Log::Info << "Tree built." << endl;
+ Log::Info << "Computing neighbors within range [" << min << ", " << max
+ << "]." << endl;
+ rangeSearch.Search(&queryTree, r, neighborsOut, distancesOut);
+ }
}
else
{
- rangeSearch = new RSType(&refTree, referenceData, singleMode);
-
- Log::Info << "Trees built." << endl;
+ Log::Info << "Computing neighbors within range [" << min << ", " << max
+ << "]." << endl;
+ rangeSearch.Search(r, neighborsOut, distancesOut);
}
- Log::Info << "Computing neighbors within range [" << min << ", " << max
- << "]." << endl;
-
- // Collect the results in these vectors before remapping.
- vector<vector<double> > distancesOut;
- vector<vector<size_t> > neighborsOut;
-
- const math::Range r(min, max);
- rangeSearch->Search(r, neighborsOut, distancesOut);
-
Log::Info << "Neighbors computed." << endl;
// We have to map back to the original indices from before the tree
@@ -272,11 +256,6 @@ int main(int argc, char *argv[])
}
}
}
-
- // Clean up.
- if (queryTree)
- delete queryTree;
- delete rangeSearch;
}
// Save output. We have to do this by hand.
diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index b497d66..f1cd60d 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -21,9 +21,9 @@ BOOST_AUTO_TEST_SUITE(RangeSearchTest);
// Get our results into a sorted format, so we can actually then test for
// correctness.
-void SortResults(const vector<vector<size_t> >& neighbors,
- const vector<vector<double> >& distances,
- vector<vector<pair<double, size_t> > >& output)
+void SortResults(const vector<vector<size_t>>& neighbors,
+ const vector<vector<double>>& distances,
+ vector<vector<pair<double, size_t>>>& output)
{
output.resize(neighbors.size());
for (size_t i = 0; i < neighbors.size(); i++)
@@ -90,20 +90,20 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
rs = new RangeSearch<>(dataMutable, true);
break;
case 1: // Use the single-tree method.
- rs = new RangeSearch<>(tree, dataMutable, true);
+ rs = new RangeSearch<>(tree, true);
break;
case 2: // Use the dual-tree method.
- rs = new RangeSearch<>(tree, dataMutable);
+ rs = new RangeSearch<>(tree);
break;
}
// Now perform the first calculation. Points within 0.50.
- vector<vector<size_t> > neighbors;
- vector<vector<double> > distances;
+ vector<vector<size_t>> neighbors;
+ vector<vector<double>> distances;
rs->Search(Range(0.0, sqrt(0.5)), neighbors, distances);
// Now the exhaustive check for correctness. This will be long.
- vector<vector<pair<double, size_t> > > sortedOutput;
+ vector<vector<pair<double, size_t>>> sortedOutput;
SortResults(neighbors, distances, sortedOutput);
BOOST_REQUIRE(sortedOutput[newFromOld[0]].size() == 4);
@@ -460,20 +460,20 @@ BOOST_AUTO_TEST_CASE(DualTreeVsNaive1)
arma::mat naiveQuery(dataForTree);
arma::mat naiveReferences(dataForTree);
- RangeSearch<> rs(dualQuery, dualReferences);
+ RangeSearch<> rs(dualReferences);
- RangeSearch<> naive(naiveQuery, naiveReferences, true);
+ RangeSearch<> naive(naiveReferences, true);
- vector<vector<size_t> > neighborsTree;
- vector<vector<double> > distancesTree;
- rs.Search(Range(0.25, 1.05), neighborsTree, distancesTree);
- vector<vector<pair<double, size_t> > > sortedTree;
+ vector<vector<size_t>> neighborsTree;
+ vector<vector<double>> distancesTree;
+ rs.Search(dualQuery, Range(0.25, 1.05), neighborsTree, distancesTree);
+ vector<vector<pair<double, size_t>>> sortedTree;
SortResults(neighborsTree, distancesTree, sortedTree);
- vector<vector<size_t> > neighborsNaive;
- vector<vector<double> > distancesNaive;
- naive.Search(Range(0.25, 1.05), neighborsNaive, distancesNaive);
- vector<vector<pair<double, size_t> > > sortedNaive;
+ vector<vector<size_t>> neighborsNaive;
+ vector<vector<double>> distancesNaive;
+ naive.Search(naiveQuery, Range(0.25, 1.05), neighborsNaive, distancesNaive);
+ vector<vector<pair<double, size_t>>> sortedNaive;
SortResults(neighborsNaive, distancesNaive, sortedNaive);
for (size_t i = 0; i < sortedTree.size(); i++)
@@ -513,16 +513,16 @@ BOOST_AUTO_TEST_CASE(DualTreeVsNaive2)
// Set naive mode.
RangeSearch<> naive(naiveQuery, true);
- vector<vector<size_t> > neighborsTree;
- vector<vector<double> > distancesTree;
+ vector<vector<size_t>> neighborsTree;
+ vector<vector<double>> distancesTree;
rs.Search(Range(0.25, 1.05), neighborsTree, distancesTree);
- vector<vector<pair<double, size_t> > > sortedTree;
+ vector<vector<pair<double, size_t>>> sortedTree;
SortResults(neighborsTree, distancesTree, sortedTree);
- vector<vector<size_t> > neighborsNaive;
- vector<vector<double> > distancesNaive;
+ vector<vector<size_t>> neighborsNaive;
+ vector<vector<double>> distancesNaive;
naive.Search(Range(0.25, 1.05), neighborsNaive, distancesNaive);
- vector<vector<pair<double, size_t> > > sortedNaive;
+ vector<vector<pair<double, size_t>>> sortedNaive;
SortResults(neighborsNaive, distancesNaive, sortedNaive);
for (size_t i = 0; i < sortedTree.size(); i++)
@@ -562,16 +562,16 @@ BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
// Set up computation for naive mode.
RangeSearch<> naive(naiveQuery, true);
- vector<vector<size_t> > neighborsSingle;
- vector<vector<double> > distancesSingle;
+ vector<vector<size_t>> neighborsSingle;
+ vector<vector<double>> distancesSingle;
single.Search(Range(0.25, 1.05), neighborsSingle, distancesSingle);
- vector<vector<pair<double, size_t> > > sortedTree;
+ vector<vector<pair<double, size_t>>> sortedTree;
SortResults(neighborsSingle, distancesSingle, sortedTree);
- vector<vector<size_t> > neighborsNaive;
- vector<vector<double> > distancesNaive;
+ vector<vector<size_t>> neighborsNaive;
+ vector<vector<double>> distancesNaive;
naive.Search(Range(0.25, 1.05), neighborsNaive, distancesNaive);
- vector<vector<pair<double, size_t> > > sortedNaive;
+ vector<vector<pair<double, size_t>>> sortedNaive;
SortResults(neighborsNaive, distancesNaive, sortedNaive);
for (size_t i = 0; i < sortedTree.size(); i++)
@@ -600,8 +600,7 @@ BOOST_AUTO_TEST_CASE(CoverTreeTest)
typedef tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
RangeSearchStat> CoverTreeType;
CoverTreeType tree(data);
- RangeSearch<metric::EuclideanDistance, CoverTreeType> coversearch(&tree,
- data);
+ RangeSearch<metric::EuclideanDistance, CoverTreeType> coversearch(&tree);
// Four trials with different ranges.
for (size_t r = 0; r < 4; ++r)
@@ -631,12 +630,12 @@ BOOST_AUTO_TEST_CASE(CoverTreeTest)
}
// Results for kd-tree search.
- vector<vector<size_t> > kdNeighbors;
- vector<vector<double> > kdDistances;
+ vector<vector<size_t>> kdNeighbors;
+ vector<vector<double>> kdDistances;
// Results for cover tree search.
- vector<vector<size_t> > coverNeighbors;
- vector<vector<double> > coverDistances;
+ vector<vector<size_t>> coverNeighbors;
+ vector<vector<double>> coverDistances;
// Clean the tree statistics.
CleanTree(tree);
@@ -646,8 +645,8 @@ BOOST_AUTO_TEST_CASE(CoverTreeTest)
coversearch.Search(range, coverNeighbors, coverDistances);
// Sort before comparison.
- vector<vector<pair<double, size_t> > > kdSorted;
- vector<vector<pair<double, size_t> > > coverSorted;
+ vector<vector<pair<double, size_t>>> kdSorted;
+ vector<vector<pair<double, size_t>>> coverSorted;
SortResults(kdNeighbors, kdDistances, kdSorted);
SortResults(coverNeighbors, coverDistances, coverSorted);
@@ -680,16 +679,16 @@ BOOST_AUTO_TEST_CASE(CoverTreeTwoDatasetsTest)
typedef tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
RangeSearchStat> CoverTreeType;
CoverTreeType tree(data);
- CoverTreeType queryTree(queries);
+ CoverTreeType* queryTree = new CoverTreeType(queries);
RangeSearch<metric::EuclideanDistance, CoverTreeType>
- coversearch(&tree, &queryTree, data, queries);
+ coversearch(&tree);
// Four trials with different ranges.
for (size_t r = 0; r < 4; ++r)
{
// Set up kd-tree range search. We don't have an easy way to rebuild the
// tree, so we'll just reinstantiate it here each loop time.
- RangeSearch<> kdsearch(data, queries);
+ RangeSearch<> kdsearch(data);
Range range;
switch (r)
@@ -713,24 +712,25 @@ BOOST_AUTO_TEST_CASE(CoverTreeTwoDatasetsTest)
}
// Results for kd-tree search.
- vector<vector<size_t> > kdNeighbors;
- vector<vector<double> > kdDistances;
+ vector<vector<size_t>> kdNeighbors;
+ vector<vector<double>> kdDistances;
// Results for cover tree search.
- vector<vector<size_t> > coverNeighbors;
- vector<vector<double> > coverDistances;
+ vector<vector<size_t>> coverNeighbors;
+ vector<vector<double>> coverDistances;
// Clean the trees.
CleanTree(tree);
- CleanTree(queryTree);
+ delete queryTree;
+ queryTree = new CoverTreeType(queries);
// Run the searches.
- coversearch.Search(range, coverNeighbors, coverDistances);
- kdsearch.Search(range, kdNeighbors, kdDistances);
+ coversearch.Search(queryTree, range, coverNeighbors, coverDistances);
+ kdsearch.Search(queries, range, kdNeighbors, kdDistances);
// Sort before comparison.
- vector<vector<pair<double, size_t> > > kdSorted;
- vector<vector<pair<double, size_t> > > coverSorted;
+ vector<vector<pair<double, size_t>>> kdSorted;
+ vector<vector<pair<double, size_t>>> coverSorted;
SortResults(kdNeighbors, kdDistances, kdSorted);
SortResults(coverNeighbors, coverDistances, coverSorted);
@@ -746,6 +746,8 @@ BOOST_AUTO_TEST_CASE(CoverTreeTwoDatasetsTest)
BOOST_REQUIRE_EQUAL(kdSorted[i].size(), coverSorted[i].size());
}
}
+
+ delete queryTree;
}
/**
@@ -760,8 +762,8 @@ BOOST_AUTO_TEST_CASE(CoverTreeSingleTreeTest)
typedef tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
RangeSearchStat> CoverTreeType;
CoverTreeType tree(data);
- RangeSearch<metric::EuclideanDistance, CoverTreeType>
- coversearch(&tree, data, true);
+ RangeSearch<metric::EuclideanDistance, CoverTreeType> coversearch(&tree,
+ true);
// Four trials with different ranges.
for (size_t r = 0; r < 4; ++r)
@@ -791,12 +793,12 @@ BOOST_AUTO_TEST_CASE(CoverTreeSingleTreeTest)
}
// Results for kd-tree search.
- vector<vector<size_t> > kdNeighbors;
- vector<vector<double> > kdDistances;
+ vector<vector<size_t>> kdNeighbors;
+ vector<vector<double>> kdDistances;
// Results for cover tree search.
- vector<vector<size_t> > coverNeighbors;
- vector<vector<double> > coverDistances;
+ vector<vector<size_t>> coverNeighbors;
+ vector<vector<double>> coverDistances;
// Clean the tree statistics.
CleanTree(tree);
@@ -806,8 +808,8 @@ BOOST_AUTO_TEST_CASE(CoverTreeSingleTreeTest)
coversearch.Search(range, coverNeighbors, coverDistances);
// Sort before comparison.
- vector<vector<pair<double, size_t> > > kdSorted;
- vector<vector<pair<double, size_t> > > coverSorted;
+ vector<vector<pair<double, size_t>>> kdSorted;
+ vector<vector<pair<double, size_t>>> coverSorted;
SortResults(kdNeighbors, kdDistances, kdSorted);
SortResults(coverNeighbors, coverDistances, coverSorted);
@@ -837,7 +839,7 @@ BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
typedef BinarySpaceTree<BallBound<>, RangeSearchStat> TreeType;
TreeType tree(data);
RangeSearch<metric::EuclideanDistance, TreeType>
- ballsearch(&tree, data, true);
+ ballsearch(&tree, true);
// Four trials with different ranges.
for (size_t r = 0; r < 4; ++r)
@@ -867,12 +869,12 @@ BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
}
// Results for kd-tree search.
- vector<vector<size_t> > kdNeighbors;
- vector<vector<double> > kdDistances;
+ vector<vector<size_t>> kdNeighbors;
+ vector<vector<double>> kdDistances;
// Results for ball tree search.
- vector<vector<size_t> > ballNeighbors;
- vector<vector<double> > ballDistances;
+ vector<vector<size_t>> ballNeighbors;
+ vector<vector<double>> ballDistances;
// Clean the tree statistics.
CleanTree(tree);
@@ -882,8 +884,8 @@ BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
ballsearch.Search(range, ballNeighbors, ballDistances);
// Sort before comparison.
- vector<vector<pair<double, size_t> > > kdSorted;
- vector<vector<pair<double, size_t> > > ballSorted;
+ vector<vector<pair<double, size_t>>> kdSorted;
+ vector<vector<pair<double, size_t>>> ballSorted;
SortResults(kdNeighbors, kdDistances, kdSorted);
SortResults(ballNeighbors, ballDistances, ballSorted);
@@ -913,7 +915,7 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest)
// Set up ball tree range search.
typedef BinarySpaceTree<BallBound<>, RangeSearchStat> TreeType;
TreeType tree(data);
- RangeSearch<metric::EuclideanDistance, TreeType> ballsearch(&tree, data);
+ RangeSearch<metric::EuclideanDistance, TreeType> ballsearch(&tree);
// Four trials with different ranges.
for (size_t r = 0; r < 4; ++r)
@@ -943,12 +945,12 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest)
}
// Results for kd-tree search.
- vector<vector<size_t> > kdNeighbors;
- vector<vector<double> > kdDistances;
+ vector<vector<size_t>> kdNeighbors;
+ vector<vector<double>> kdDistances;
// Results for ball tree search.
- vector<vector<size_t> > ballNeighbors;
- vector<vector<double> > ballDistances;
+ vector<vector<size_t>> ballNeighbors;
+ vector<vector<double>> ballDistances;
// Clean the tree statistics.
CleanTree(tree);
@@ -958,8 +960,8 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest)
ballsearch.Search(range, ballNeighbors, ballDistances);
// Sort before comparison.
- vector<vector<pair<double, size_t> > > kdSorted;
- vector<vector<pair<double, size_t> > > ballSorted;
+ vector<vector<pair<double, size_t>>> kdSorted;
+ vector<vector<pair<double, size_t>>> ballSorted;
SortResults(kdNeighbors, kdDistances, kdSorted);
SortResults(ballNeighbors, ballDistances, ballSorted);
@@ -993,15 +995,14 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest2)
typedef BinarySpaceTree<BallBound<>, RangeSearchStat> TreeType;
TreeType tree(data);
TreeType queryTree(queries);
- RangeSearch<metric::EuclideanDistance, TreeType>
- ballsearch(&tree, &queryTree, data, queries);
+ RangeSearch<metric::EuclideanDistance, TreeType> ballsearch(&tree);
// Four trials with different ranges.
for (size_t r = 0; r < 4; ++r)
{
// Set up kd-tree range search. We don't have an easy way to rebuild the
// tree, so we'll just reinstantiate it here each loop time.
- RangeSearch<> kdsearch(data, queries);
+ RangeSearch<> kdsearch(data);
Range range;
switch (r)
@@ -1025,24 +1026,24 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest2)
}
// Results for kd-tree search.
- vector<vector<size_t> > kdNeighbors;
- vector<vector<double> > kdDistances;
+ vector<vector<size_t>> kdNeighbors;
+ vector<vector<double>> kdDistances;
// Results for ball tree search.
- vector<vector<size_t> > ballNeighbors;
- vector<vector<double> > ballDistances;
+ vector<vector<size_t>> ballNeighbors;
+ vector<vector<double>> ballDistances;
// Clean the trees.
CleanTree(tree);
CleanTree(queryTree);
// Run the searches.
- ballsearch.Search(range, ballNeighbors, ballDistances);
- kdsearch.Search(range, kdNeighbors, kdDistances);
+ ballsearch.Search(&queryTree, range, ballNeighbors, ballDistances);
+ kdsearch.Search(queries, range, kdNeighbors, kdDistances);
// Sort before comparison.
- vector<vector<pair<double, size_t> > > kdSorted;
- vector<vector<pair<double, size_t> > > ballSorted;
+ vector<vector<pair<double, size_t>>> kdSorted;
+ vector<vector<pair<double, size_t>>> ballSorted;
SortResults(kdNeighbors, kdDistances, kdSorted);
SortResults(ballNeighbors, ballDistances, ballSorted);
More information about the mlpack-git
mailing list