[mlpack-git] master: Refactor NeighborSearch like RangeSearch. Now Search() takes the query set. (80650d4)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Apr 22 16:32:27 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/8f85309ae9be40e819b301b39c9a940aa28f3bb2...57d0567dddff01feea73b348f38cc040dc3cf8e3
>---------------------------------------------------------------
commit 80650d46316345367ec9b88fbd45df27beb854ee
Author: ryan <ryan at ratml.org>
Date: Wed Apr 22 14:18:21 2015 -0400
Refactor NeighborSearch like RangeSearch.
Now Search() takes the query set.
>---------------------------------------------------------------
80650d46316345367ec9b88fbd45df27beb854ee
.../methods/neighbor_search/neighbor_search.hpp | 154 +++-----
.../neighbor_search/neighbor_search_impl.hpp | 405 +++++++++++----------
.../neighbor_search/neighbor_search_rules.hpp | 6 +-
.../neighbor_search/neighbor_search_rules_impl.hpp | 6 +-
4 files changed, 284 insertions(+), 287 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 9ae8b6a..18111a3 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -50,66 +50,35 @@ class NeighborSearch
{
public:
/**
- * Initialize the NeighborSearch object, passing both a query and reference
- * dataset. Optionally, perform the computation in naive mode or single-tree
- * mode, and set the leaf size used for tree-building. An initialized
- * distance metric can be given, for cases where the metric has internal data
- * (i.e. the distance::MahalanobisDistance class).
+ * Initialize the NeighborSearch object, passing a reference dataset (this is
+ * the dataset which is searched). Optionally, perform the computation in
+ * naive mode or single-tree mode. An initialized distance metric can be
+ * given, for cases where the metric has internal data (i.e. the
+ * distance::MahalanobisDistance class).
*
* This method will copy the matrices to internal copies, which are rearranged
* during tree-building. You can avoid this extra copy by pre-constructing
* the trees and passing them using a diferent constructor.
*
* @param referenceSet Set of reference points.
- * @param querySet Set of query points.
* @param naive If true, O(n^2) naive search will be used (as opposed to
* dual-tree search). This overrides singleMode (if it is set to true).
* @param singleMode If true, single-tree search will be used (as opposed to
* dual-tree search).
- * @param leafSize Leaf size for tree construction (ignored if tree is given).
* @param metric An optional instance of the MetricType class.
*/
NeighborSearch(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
const bool naive = false,
const bool singleMode = false,
const MetricType metric = MetricType());
/**
- * Initialize the NeighborSearch object, passing only one dataset, which is
- * used as both the query and the reference dataset. Optionally, perform the
- * computation in naive mode or single-tree mode, and set the leaf size used
- * for tree-building. An initialized distance metric can be given, for cases
- * where the metric has internal data (i.e. the distance::MahalanobisDistance
- * class).
- *
- * If naive mode is being used and a pre-built tree is given, it may not work:
- * naive mode operates by building a one-node tree (the root node holds all
- * the points). If that condition is not satisfied with the pre-built tree,
- * then naive mode will not work.
- *
- * @param referenceSet Set of reference points.
- * @param naive If true, O(n^2) naive search will be used (as opposed to
- * dual-tree search). This overrides singleMode (if it is set to true).
- * @param singleMode If true, single-tree search will be used (as opposed to
- * dual-tree search).
- * @param leafSize Leaf size for tree construction (ignored if tree is given).
- * @param metric An optional instance of the MetricType class.
- */
- NeighborSearch(const typename TreeType::Mat& referenceSet,
- const bool naive = false,
- const bool singleMode = false,
- const MetricType metric = MetricType());
-
- /**
- * Initialize the NeighborSearch object with the given datasets and
- * pre-constructed trees. It is assumed that the points in referenceSet and
- * querySet correspond to the points in referenceTree and queryTree,
- * respectively. Optionally, choose to use single-tree mode. Naive mode is
- * not available as an option for this constructor; instead, to run naive
- * computation, construct a tree with all of the points in one leaf (i.e.
- * leafSize = number of points). Additionally, an instantiated distance
- * metric can be given, for cases where the distance metric holds data.
+ * Initialize the NeighborSearch object with the given pre-constructed
+ * reference tree (this is the tree built on the points that will be
+ * searched). Optionally, choose to use single-tree mode. Naive mode is not
+ * available as an option for this constructor. Additionally, an instantiated
+ * distance metric can be given, for cases where the distance metric holds
+ * data.
*
* There is no copying of the data matrices in this constructor (because
* tree-building is not necessary), so this is the constructor to use when
@@ -123,73 +92,71 @@ class NeighborSearch
* @endnote
*
* @param referenceTree Pre-built tree for reference points.
- * @param queryTree Pre-built tree for query points.
* @param referenceSet Set of reference points corresponding to referenceTree.
- * @param querySet Set of query points corresponding to queryTree.
* @param singleMode Whether single-tree computation should be used (as
* opposed to dual-tree computation).
* @param metric Instantiated distance metric.
*/
NeighborSearch(TreeType* referenceTree,
- TreeType* queryTree,
- const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
const bool singleMode = false,
const MetricType metric = MetricType());
/**
- * Initialize the NeighborSearch object with the given reference dataset and
- * pre-constructed tree. It is assumed that the points in referenceSet
- * correspond to the points in referenceTree. Optionally, choose to use
- * single-tree mode. Naive mode is not available as an option for this
- * constructor; instead, to run naive computation, construct a tree with all
- * the points in one leaf (i.e. leafSize = number of points). Additionally,
- * an instantiated distance metric can be given, for the case where the
- * distance metric holds data.
- *
- * There is no copying of the data matrices in this constructor (because
- * tree-building is not necessary), so this is the constructor to use when
- * copies absolutely must be avoided.
- *
- * @note
- * Because tree-building (at least with BinarySpaceTree) modifies the ordering
- * of a matrix, be sure you pass the modified matrix to this object! In
- * addition, mapping the points of the matrix back to their original indices
- * is not done when this constructor is used.
- * @endnote
- *
- * @param referenceTree Pre-built tree for reference points.
- * @param referenceSet Set of reference points corresponding to referenceTree.
- * @param singleMode Whether single-tree computation should be used (as
- * opposed to dual-tree computation).
- * @param metric Instantiated distance metric.
- */
- NeighborSearch(TreeType* referenceTree,
- const typename TreeType::Mat& referenceSet,
- const bool singleMode = false,
- const MetricType metric = MetricType());
-
-
- /**
* Delete the NeighborSearch object. The tree is the only member we are
* responsible for deleting. The others will take care of themselves.
*/
~NeighborSearch();
/**
- * Compute the nearest neighbors and store the output in the given matrices.
- * The matrices will be set to the size of n columns by k rows, where n is the
+ * For each point in the query set, compute the nearest neighbors and store
+ * the output in the given matrices. The matrices will be set to the size of
+ * n columns by k rows, where n is the number of points in the query dataset
+ * and k is the number of neighbors being searched for.
+ *
+ * @param querySet Set of query points (can be just one point).
+ * @param k Number of neighbors to search for.
+ * @param neighbors Matrix storing lists of neighbors for each query point.
+ * @param distances Matrix storing distances of neighbors for each query
+ * point.
+ */
+ void Search(const typename TreeType::Mat& querySet,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
+
+ /**
+ * Given a pre-built query tree, search for the nearest neighbors of each
+ * point in the query tree, storing the output in the given matrices. The
+ * matrices will be set to the size of n columns by k rows, where n is the
* number of points in the query dataset and k is the number of neighbors
* being searched for.
*
+ * @param queryTree Tree built on query points.
* @param k Number of neighbors to search for.
- * @param resultingNeighbors Matrix storing lists of neighbors for each query
- * point.
+ * @param neighbors Matrix storing lists of neighbors for each query point.
* @param distances Matrix storing distances of neighbors for each query
- * point.
+ * point.
+ */
+ void Search(TreeType* queryTree,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
+
+ /**
+ * Search for the nearest neighbors of every point in the reference set. This
+ * is basically equivalent to calling any other overload of Search() with the
+ * reference set as the query set; so, this lets you do all-k-nearest-neighbors
+ * search. The results are stored in the given matrices. The matrices will
+ * be set to the size of n columns by k rows, where n is the number of points
+ * in the reference set and k is the number of neighbors being searched for.
+ *
+ * @param k Number of neighbors to search for.
+ * @param neighbors Matrix storing lists of neighbors for each query point.
+ * @param distances Matrix storing distances of neighbors for each query
+ * point.
*/
void Search(const size_t k,
- arma::Mat<size_t>& resultingNeighbors,
+ arma::Mat<size_t>& neighbors,
arma::mat& distances);
//! Returns a string representation of this object.
@@ -210,18 +177,12 @@ class NeighborSearch
//! Copy of reference dataset (if we need it, because tree building modifies
//! it).
typename TreeType::Mat referenceCopy;
- //! Copy of query dataset (if we need it, because tree building modifies it).
- typename TreeType::Mat queryCopy;
-
//! Reference dataset.
const typename TreeType::Mat& referenceSet;
- //! Query dataset (may not be given).
- const typename TreeType::Mat& querySet;
-
//! Pointer to the root of the reference tree.
TreeType* referenceTree;
- //! Pointer to the root of the query tree (might not exist).
- TreeType* queryTree;
+ //! Permutations of reference points during tree building.
+ std::vector<size_t> oldFromNewReferences;
//! If true, this object created the trees and is responsible for them.
bool treeOwner;
@@ -236,11 +197,6 @@ class NeighborSearch
//! Instantiation of metric.
MetricType metric;
- //! Permutations of reference points during tree building.
- std::vector<size_t> oldFromNewReferences;
- //! Permutations of query points during tree building.
- std::vector<size_t> oldFromNewQueries;
-
//! The total number of base cases.
size_t baseCases;
//! The total number of scores (applicable for non-naive search).
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 06fbcb2..5a1c3c6 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -43,96 +43,34 @@ TreeType* BuildTree(
template<typename SortPolicy, typename MetricType, typename TreeType>
NeighborSearch<SortPolicy, MetricType, TreeType>::
NeighborSearch(const typename TreeType::Mat& referenceSetIn,
- const typename TreeType::Mat& querySetIn,
const bool naive,
const bool singleMode,
const MetricType metric) :
- referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
- : referenceSetIn),
- querySet(tree::TreeTraits<TreeType>::RearrangesDataset ? queryCopy
- : querySetIn),
+ referenceSet((tree::TreeTraits<TreeType>::RearrangesDataset && !naive)
+ ? referenceCopy : referenceSetIn),
referenceTree(NULL),
- queryTree(NULL),
treeOwner(!naive), // False if a tree was passed. If naive, then no trees.
- hasQuerySet(true),
naive(naive),
singleMode(!naive && singleMode), // No single mode if naive.
metric(metric),
baseCases(0),
scores(0)
{
- // C++11 will allow us to call out to other constructors so we can avoid this
- // copypasta problem.
-
- // We'll time tree building, but only if we are building trees.
+ // Build the tree.
Timer::Start("tree_building");
- // Copy the datasets, if they will be modified during tree building.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
- {
- referenceCopy = referenceSetIn;
- queryCopy = querySetIn;
- }
-
- // If not in naive mode, then we need to build trees.
if (!naive)
{
- // The const_cast is safe; if RearrangesDataset == false, then it'll be
- // casted back to const anyway, and if not, referenceSet points to
- // referenceCopy, which isn't const.
- referenceTree = BuildTree<TreeType>(
- const_cast<typename TreeType::Mat&>(referenceSet),
- oldFromNewReferences);
-
- if (!singleMode)
- queryTree = BuildTree<TreeType>(
- const_cast<typename TreeType::Mat&>(querySet), oldFromNewQueries);
- }
-
- // Stop the timer we started above (if we need to).
- Timer::Stop("tree_building");
-}
-
-// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::
-NeighborSearch(const typename TreeType::Mat& referenceSetIn,
- const bool naive,
- const bool singleMode,
- const MetricType metric) :
- referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
- : referenceSetIn),
- querySet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
- : referenceSetIn),
- referenceTree(NULL),
- queryTree(NULL),
- treeOwner(!naive), // If naive, then we are not building any trees.
- hasQuerySet(false),
- naive(naive),
- singleMode(!naive && singleMode), // No single mode if naive.
- metric(metric),
- baseCases(0),
- scores(0)
-{
- // We'll time tree building, but only if we are building trees.
- Timer::Start("tree_building");
-
- // Copy the dataset, if it will be modified during tree building.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
- referenceCopy = referenceSetIn;
+ // Copy the dataset, if it will be modified during tree building.
+ if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ referenceCopy = referenceSetIn;
- // If not in naive mode, then we may need to construct trees.
- if (!naive)
- {
// The const_cast is safe; if RearrangesDataset == false, then it'll be
// casted back to const anyway, and if not, referenceSet points to
// referenceCopy, which isn't const.
referenceTree = BuildTree<TreeType>(
const_cast<typename TreeType::Mat&>(referenceSet),
oldFromNewReferences);
-
- if (!singleMode)
- queryTree = new TreeType(*referenceTree);
}
// Stop the timer we started above.
@@ -143,17 +81,11 @@ NeighborSearch(const typename TreeType::Mat& referenceSetIn,
template<typename SortPolicy, typename MetricType, typename TreeType>
NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
TreeType* referenceTree,
- TreeType* queryTree,
- const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
const bool singleMode,
const MetricType metric) :
- referenceSet(referenceSet),
- querySet(querySet),
+ referenceSet(referenceTree->Dataset()),
referenceTree(referenceTree),
- queryTree(queryTree),
treeOwner(false),
- hasQuerySet(true),
naive(false),
singleMode(singleMode),
metric(metric),
@@ -163,53 +95,12 @@ NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
// Nothing else to initialize.
}
-// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
- TreeType* referenceTree,
- const typename TreeType::Mat& referenceSet,
- const bool singleMode,
- const MetricType metric) :
- referenceSet(referenceSet),
- querySet(referenceSet),
- referenceTree(referenceTree),
- queryTree(NULL),
- treeOwner(false),
- hasQuerySet(false), // In this case we will own a tree, if singleMode.
- naive(false),
- singleMode(singleMode),
- metric(metric),
- baseCases(0),
- scores(0)
-{
- Timer::Start("tree_building");
-
- // The query tree cannot be the same as the reference tree.
- if (referenceTree && !singleMode)
- queryTree = new TreeType(*referenceTree);
-
- Timer::Stop("tree_building");
-}
-
-/**
- * The tree is the only member we may be responsible for deleting. The others
- * will take care of themselves.
- */
+// Clean memory.
template<typename SortPolicy, typename MetricType, typename TreeType>
NeighborSearch<SortPolicy, MetricType, TreeType>::~NeighborSearch()
{
- if (treeOwner)
- {
- if (referenceTree)
- delete referenceTree;
- if (queryTree)
- delete queryTree;
- }
- else if (!treeOwner && !hasQuerySet && !(singleMode || naive))
- {
- // We replicated the reference tree to create a query tree.
- delete queryTree;
- }
+ if (treeOwner && referenceTree)
+ delete referenceTree;
}
/**
@@ -218,23 +109,27 @@ NeighborSearch<SortPolicy, MetricType, TreeType>::~NeighborSearch()
*/
template<typename SortPolicy, typename MetricType, typename TreeType>
void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
+ const typename TreeType::Mat& querySet,
const size_t k,
- arma::Mat<size_t>& resultingNeighbors,
+ arma::Mat<size_t>& neighbors,
arma::mat& distances)
{
Timer::Start("computing_neighbors");
+ // This will hold mappings for query points, if necessary.
+ std::vector<size_t> oldFromNewQueries;
+
// If we have built the trees ourselves, then we will have to map all the
// indices back to their original indices when this computation is finished.
// To avoid an extra copy, we will store the neighbors and distances in a
// separate matrix.
- arma::Mat<size_t>* neighborPtr = &resultingNeighbors;
+ arma::Mat<size_t>* neighborPtr = &neighbors;
arma::mat* distancePtr = &distances;
// Mapping is only necessary if the tree rearranges points.
if (tree::TreeTraits<TreeType>::RearrangesDataset)
{
- if (treeOwner && !(singleMode && hasQuerySet))
+ if (!singleMode && !naive)
distancePtr = new arma::mat; // Query indices need to be mapped.
if (treeOwner)
@@ -247,9 +142,20 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
distancePtr->set_size(k, querySet.n_cols);
distancePtr->fill(SortPolicy::WorstDistance());
+ // If we will be building a tree and it will modify the query set, make a copy
+ // of the dataset.
+ typename TreeType::Mat queryCopy;
+ const bool needsCopy = (!naive && !singleMode &&
+ tree::TreeTraits<TreeType>::RearrangesDataset);
+ if (needsCopy)
+ queryCopy = querySet;
+
+ const typename TreeType::Mat& querySetRef = (needsCopy) ? queryCopy :
+ querySet;
+
// Create the helper object for the tree traversal.
typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
- RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+ RuleType rules(referenceSet, querySetRef, *neighborPtr, *distancePtr, metric);
if (naive)
{
@@ -262,10 +168,6 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
}
else if (singleMode)
{
- // The search doesn't work if the root node is also a leaf node.
- // If this is the case, it is suggested that you use the naive method.
- Log::Assert(!(referenceTree->IsLeaf()));
-
// Create the traverser.
typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
@@ -281,6 +183,14 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
}
else // Dual-tree recursion.
{
+ // Build the query tree.
+ Timer::Stop("computing_neighbors");
+ Timer::Start("tree_building");
+ TreeType* queryTree = BuildTree<TreeType>(
+ const_cast<typename TreeType::Mat&>(querySetRef), oldFromNewQueries);
+ Timer::Stop("tree_building");
+ Timer::Start("computing_neighbors");
+
// Create the traverser.
typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
@@ -295,92 +205,217 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
Timer::Stop("computing_neighbors");
- // Now, do we need to do mapping of indices?
- if (!treeOwner || !tree::TreeTraits<TreeType>::RearrangesDataset)
- {
- // No mapping needed. We are done.
- return;
- }
- else if (treeOwner && hasQuerySet && !singleMode) // Map both sets.
+ // Map points back to original indices, if necessary.
+ if (tree::TreeTraits<TreeType>::RearrangesDataset)
{
- // Set size of output matrices correctly.
- resultingNeighbors.set_size(k, querySet.n_cols);
- distances.set_size(k, querySet.n_cols);
-
- for (size_t i = 0; i < distances.n_cols; i++)
+ if (!singleMode && !naive && treeOwner)
{
- // Map distances (copy a column).
- distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
+ // We must map both query and reference indices.
+ neighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
- // Map indices of neighbors.
- for (size_t j = 0; j < distances.n_rows; j++)
+ for (size_t i = 0; i < distances.n_cols; i++)
{
- resultingNeighbors(j, oldFromNewQueries[i]) =
- oldFromNewReferences[(*neighborPtr)(j, i)];
+ // Map distances (copy a column).
+ distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
+
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distances.n_rows; j++)
+ {
+ neighbors(j, oldFromNewQueries[i]) =
+ oldFromNewReferences[(*neighborPtr)(j, i)];
+ }
}
+
+ // Finished with temporary matrices.
+ delete neighborPtr;
+ delete distancePtr;
}
+ else if (!singleMode && !naive)
+ {
+ // We must map query indices only.
+ neighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
- // Finished with temporary matrices.
- delete neighborPtr;
- delete distancePtr;
- }
- else if (treeOwner && !hasQuerySet)
- {
- resultingNeighbors.set_size(k, querySet.n_cols);
- distances.set_size(k, querySet.n_cols);
+ for (size_t i = 0; i < distances.n_cols; ++i)
+ {
+ // Map distances (copy a column).
+ const size_t queryMapping = oldFromNewQueries[i];
+ distances.col(queryMapping) = distancePtr->col(i);
+ neighbors.col(queryMapping) = neighborPtr->col(i);
+ }
- for (size_t i = 0; i < distances.n_cols; i++)
+ // Finished with temporary matrices.
+ delete neighborPtr;
+ delete distancePtr;
+ }
+ else if (treeOwner)
{
- // Map distances (copy a column).
- distances.col(oldFromNewReferences[i]) = distancePtr->col(i);
+ // We must map reference indices only.
+ neighbors.set_size(k, querySet.n_cols);
// Map indices of neighbors.
- for (size_t j = 0; j < distances.n_rows; j++)
- {
- resultingNeighbors(j, oldFromNewReferences[i]) =
- oldFromNewReferences[(*neighborPtr)(j, i)];
- }
+ for (size_t i = 0; i < neighbors.n_cols; i++)
+ for (size_t j = 0; j < neighbors.n_rows; j++)
+ neighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
+
+ // Finished with temporary matrix.
+ delete neighborPtr;
}
+ }
+} // Search()
- // Finished with temporary matrices.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
+ TreeType* queryTree,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
+{
+ Timer::Start("computing_neighbors");
+
+ // Get a reference to the query set.
+ const typename TreeType::Mat& querySet = queryTree->Dataset();
+
+ // Make sure we are in dual-tree mode.
+ if (singleMode || naive)
+ throw std::invalid_argument("cannot call NeighborSearch::Search() with a "
+ "query tree when naive or singleMode are set to true");
+
+ // We won't need to map query indices, but will we need to map distances?
+ arma::Mat<size_t>* neighborPtr = &neighbors;
+
+ if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+ neighborPtr = new arma::Mat<size_t>;
+
+ neighborPtr->set_size(k, querySet.n_cols);
+ neighborPtr->fill(size_t() - 1);
+ distances.set_size(k, querySet.n_cols);
+ distances.fill(SortPolicy::WorstDistance());
+
+ // Create the helper object for the traversal.
+ typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
+ RuleType rules(referenceSet, querySet, *neighborPtr, distances, metric);
+
+ // Create the traverser.
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+ traverser.Traverse(*queryTree, *referenceTree);
+
+ Timer::Stop("computing_neighbors");
+
+ // Do we need to map indices?
+ if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+ {
+ // We must map reference indices only.
+ neighbors.set_size(k, querySet.n_cols);
+
+ // Map indices of neighbors.
+ for (size_t i = 0; i < neighbors.n_cols; i++)
+ for (size_t j = 0; j < neighbors.n_rows; j++)
+ neighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
+
+ // Finished with temporary matrix.
delete neighborPtr;
- delete distancePtr;
}
- else if (treeOwner && hasQuerySet && singleMode) // Map only references.
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
+{
+ Timer::Start("computing_neighbors");
+
+ arma::Mat<size_t>* neighborPtr = &neighbors;
+ arma::mat* distancePtr = &distances;
+
+ if (tree::TreeTraits<TreeType>::RearrangesDataset && treeOwner)
{
- // Set size of neighbor indices matrix correctly.
- resultingNeighbors.set_size(k, querySet.n_cols);
+ // We will always need to rearrange in this case.
+ distancePtr = new arma::mat;
+ neighborPtr = new arma::Mat<size_t>;
+ }
- // Map indices of neighbors.
- for (size_t i = 0; i < resultingNeighbors.n_cols; i++)
+ // Initialize results.
+ neighborPtr->set_size(k, referenceSet.n_cols);
+ neighborPtr->fill(size_t() - 1);
+ distancePtr->set_size(k, referenceSet.n_cols);
+ distancePtr->fill(SortPolicy::WorstDistance());
+
+ // Create the helper object for the traversal.
+ typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
+ RuleType rules(referenceSet, referenceSet, *neighborPtr, *distancePtr,
+ metric, true /* don't return the same point as nearest neighbor */);
+
+ if (naive)
+ {
+ // The naive brute-force solution.
+ for (size_t i = 0; i < referenceSet.n_cols; ++i)
+ for (size_t j = 0; j < referenceSet.n_cols; ++j)
+ rules.BaseCase(i, j);
+
+ baseCases += referenceSet.n_cols * referenceSet.n_cols;
+ }
+ else if (singleMode)
+ {
+ // Create the traverser.
+ typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+
+ // Now have it traverse for each point.
+ for (size_t i = 0; i < referenceSet.n_cols; ++i)
+ traverser.Traverse(i, *referenceTree);
+
+ Log::Info << rules.Scores() << " node combinations were scored.\n";
+ Log::Info << rules.BaseCases() << " base cases were calculated.\n";
+ }
+ else
+ {
+ // Create the traverser.
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+ traverser.Traverse(*referenceTree, *referenceTree);
+
+ Log::Info << rules.Scores() << " node combinations were scored.\n";
+ Log::Info << rules.BaseCases() << " base cases were calculated.\n";
+ }
+
+ Timer::Stop("computing_neighbors");
+
+ // Do we need to map the reference indices?
+ if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+ {
+ neighbors.set_size(k, referenceSet.n_cols);
+ distances.set_size(k, referenceSet.n_cols);
+
+ for (size_t i = 0; i < distances.n_cols; ++i)
{
- for (size_t j = 0; j < resultingNeighbors.n_rows; j++)
- {
- resultingNeighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
- }
+ // Map distances (copy a column).
+ const size_t refMapping = oldFromNewReferences[i];
+ distances.col(refMapping) = distancePtr->col(i);
+
+ // Map each neighbor's index.
+ for (size_t j = 0; j < distances.n_rows; ++j)
+ neighbors(j, refMapping) = oldFromNewReferences[(*neighborPtr)(j, i)];
}
- // Finished with temporary matrix.
+ // Finished with temporary matrices.
delete neighborPtr;
+ delete distancePtr;
}
-} // Search
-
+}
-//Return a String of the Object.
+// Return a String of the Object.
template<typename SortPolicy, typename MetricType, typename TreeType>
std::string NeighborSearch<SortPolicy, MetricType, TreeType>::ToString() const
{
std::ostringstream convert;
convert << "NeighborSearch [" << this << "]" << std::endl;
- convert << " Reference Set: " << referenceSet.n_rows << "x" ;
- convert << referenceSet.n_cols << std::endl;
- if (&referenceSet != &querySet)
- convert << " QuerySet: " << querySet.n_rows << "x" << querySet.n_cols
- << std::endl;
- convert << " Reference Tree: " << referenceTree << std::endl;
- if (&referenceTree != &queryTree)
- convert << " QueryTree: " << queryTree << std::endl;
- convert << " Tree Owner: " << treeOwner << std::endl;
+ convert << " Reference set: " << referenceSet.n_rows << "x" ;
+ convert << referenceSet.n_cols << std::endl;
+ if (referenceTree)
+ convert << " Reference tree: " << referenceTree << std::endl;
+ convert << " Tree owner: " << treeOwner << std::endl;
convert << " Naive: " << naive << std::endl;
convert << " Metric: " << std::endl;
convert << mlpack::util::Indent(metric.ToString(),2);
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index e09df0c..0c27690 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -21,7 +21,8 @@ class NeighborSearchRules
const typename TreeType::Mat& querySet,
arma::Mat<size_t>& neighbors,
arma::mat& distances,
- MetricType& metric);
+ MetricType& metric,
+ const bool sameSet = false);
/**
* Get the distance from the query point to the reference point.
* This will update the "neighbor" matrix with the new point if appropriate
@@ -116,6 +117,9 @@ class NeighborSearchRules
//! The instantiated metric.
MetricType& metric;
+ //! Denotes whether or not the reference and query sets are the same.
+ bool sameSet;
+
//! The last query point BaseCase() was called with.
size_t lastQueryIndex;
//! The last reference point BaseCase() was called with.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index 131c132..50acfc2 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -19,12 +19,14 @@ NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
const typename TreeType::Mat& querySet,
arma::Mat<size_t>& neighbors,
arma::mat& distances,
- MetricType& metric) :
+ MetricType& metric,
+ const bool sameSet) :
referenceSet(referenceSet),
querySet(querySet),
neighbors(neighbors),
distances(distances),
metric(metric),
+ sameSet(sameSet),
lastQueryIndex(querySet.n_cols),
lastReferenceIndex(referenceSet.n_cols),
baseCases(0),
@@ -44,7 +46,7 @@ BaseCase(const size_t queryIndex, const size_t referenceIndex)
{
// If the datasets are the same, then this search is only using one dataset
// and we should not return identical points.
- if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
+ if (sameSet && (queryIndex == referenceIndex))
return 0.0;
// If we have already performed this base case, then do not perform it again.
More information about the mlpack-git
mailing list