[mlpack-git] master: Use enum type to define the different search modes (Closes #750). (6666d66)
gitdub at mlpack.org
gitdub at mlpack.org
Sat Aug 20 14:56:06 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/3274b05fcc545c3b36f783316fea2e22f79c3d03...1c77230c7d3b9c45fb102cd3c632d9c7248e085e
>---------------------------------------------------------------
commit 6666d667ab29c2a94cc92574a1f034cde586259f
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Tue Aug 16 00:49:45 2016 -0300
Use enum type to define the different search modes (Closes #750).
>---------------------------------------------------------------
6666d667ab29c2a94cc92574a1f034cde586259f
.../methods/neighbor_search/neighbor_search.hpp | 114 +++++
.../neighbor_search/neighbor_search_impl.hpp | 472 ++++++++++++++++-----
2 files changed, 471 insertions(+), 115 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 1c2cace..ec088e5 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -73,6 +73,100 @@ class NeighborSearch
//! Convenience typedef.
typedef TreeType<MetricType, NeighborSearchStat<SortPolicy>, MatType> Tree;
+ //! SearchMode represents the different neighbor search modes available.
+ enum SearchMode
+ {
+ NAIVE_MODE, //!< No reference tree is built; Search() evaluates every query/reference pair.
+ SINGLE_TREE_MODE, //!< The reference tree is traversed once for each query point.
+ DUAL_TREE_MODE //!< A query tree is traversed against the reference tree; default mode of the constructors.
+ };
+
+ /**
+ * Initialize the NeighborSearch object, passing a reference dataset (this is
+ * the dataset which is searched). Optionally, perform the computation in
+ * a different mode. An initialized distance metric can be given, for cases
+ * where the metric has internal data (e.g. the distance::MahalanobisDistance
+ * class).
+ *
+ * This method will copy the matrices to internal copies, which are rearranged
+ * during tree-building. You can avoid this extra copy by pre-constructing
+ * the trees and passing them using a different constructor, or by using the
+ * constructor that takes an rvalue reference to the dataset.
+ *
+ * @param referenceSet Set of reference points.
+ * @param mode Neighbor search mode.
+ * @param epsilon Relative approximate error (non-negative).
+ * @param metric An optional instance of the MetricType class.
+ */
+ NeighborSearch(const MatType& referenceSet,
+ const SearchMode mode = DUAL_TREE_MODE,
+ const double epsilon = 0,
+ const MetricType metric = MetricType());
+
+ /**
+ * Initialize the NeighborSearch object, taking ownership of the reference
+ * dataset (this is the dataset which is searched). Optionally, perform the
+ * computation in a different mode. An initialized distance metric can be
+ * given, for cases where the metric has internal data (e.g. the
+ * distance::MahalanobisDistance class).
+ *
+ * This method will not copy the data matrix, but will take ownership of it,
+ * and depending on the type of tree used, may rearrange the points. If you
+ * would rather a copy be made, consider using the constructor that takes a
+ * const reference to the data instead.
+ *
+ * @param referenceSet Set of reference points.
+ * @param mode Neighbor search mode.
+ * @param epsilon Relative approximate error (non-negative).
+ * @param metric An optional instance of the MetricType class.
+ */
+ NeighborSearch(MatType&& referenceSet,
+ const SearchMode mode = DUAL_TREE_MODE,
+ const double epsilon = 0,
+ const MetricType metric = MetricType());
+
+ /**
+ * Initialize the NeighborSearch object with the given pre-constructed
+ * reference tree (this is the tree built on the points that will be
+ * searched). Optionally, perform the computation in a different mode.
+ * Naive mode is not available as an option for this constructor.
+ * Additionally, an instantiated distance metric can be given, for cases where
+ * the distance metric holds data.
+ *
+ * There is no copying of the data matrices in this constructor (because
+ * tree-building is not necessary), so this is the constructor to use when
+ * copies absolutely must be avoided.
+ *
+ * @note
+ * Mapping the points of the matrix back to their original indices is not done
+ * when this constructor is used, so if the tree type you are using maps
+ * points (like BinarySpaceTree), then you will have to perform the re-mapping
+ * manually.
+ * @endnote
+ *
+ * @param referenceTree Pre-built tree for reference points.
+ * @param mode Neighbor search mode.
+ * @param epsilon Relative approximate error (non-negative).
+ * @param metric Instantiated distance metric.
+ */
+ NeighborSearch(Tree* referenceTree,
+ const SearchMode mode = DUAL_TREE_MODE,
+ const double epsilon = 0,
+ const MetricType metric = MetricType());
+
+ /**
+ * Create a NeighborSearch object without any reference data. If Search() is
+ * called before a reference set is set with Train(), an exception will be
+ * thrown.
+ *
+ * @param mode Neighbor search mode.
+ * @param epsilon Relative approximate error (non-negative).
+ * @param metric Instantiated metric.
+ */
+ NeighborSearch(const SearchMode mode = DUAL_TREE_MODE,
+ const double epsilon = 0,
+ const MetricType metric = MetricType());
+
/**
* Initialize the NeighborSearch object, passing a reference dataset (this is
* the dataset which is searched). Optionally, perform the computation in
@@ -80,6 +174,8 @@ class NeighborSearch
* given, for cases where the metric has internal data (i.e. the
* distance::MahalanobisDistance class).
*
+ * Deprecated. Will be removed in mlpack 3.0.0.
+ *
* This method will copy the matrices to internal copies, which are rearranged
* during tree-building. You can avoid this extra copy by pre-constructing
* the trees and passing them using a different constructor, or by using the
@@ -106,6 +202,8 @@ class NeighborSearch
* metric can be given, for cases where the metric has internal data (i.e. the
* distance::MahalanobisDistance class).
*
+ * Deprecated. Will be removed in mlpack 3.0.0.
+ *
* This method will not copy the data matrix, but will take ownership of it,
* and depending on the type of tree used, may rearrange the points. If you
* would rather a copy be made, consider using the constructor that takes a
@@ -133,6 +231,8 @@ class NeighborSearch
* distance metric can be given, for cases where the distance metric holds
* data.
*
+ * Deprecated. Will be removed in mlpack 3.0.0.
+ *
* There is no copying of the data matrices in this constructor (because
* tree-building is not necessary), so this is the constructor to use when
* copies absolutely must be avoided.
@@ -161,6 +261,8 @@ class NeighborSearch
* called before a reference set is set with Train(), an exception will be
* thrown.
*
+ * Deprecated. Will be removed in mlpack 3.0.0.
+ *
* @param naive Whether to use naive search.
* @param singleMode Whether single-tree computation should be used (as
* opposed to dual-tree computation).
@@ -309,14 +411,19 @@ class NeighborSearch
//! Return the number of node combination scores during the last search.
size_t Scores() const { return scores; }
+ //! Modify the search mode.
+ void SetSearchMode(const SearchMode mode);
+
//! Access whether or not search is done in naive linear scan mode.
bool Naive() const { return naive; }
//! Modify whether or not search is done in naive linear scan mode.
+ //! Deprecated. Will be removed in mlpack 3.0.0.
bool& Naive() { return naive; }
//! Access whether or not search is done in single-tree mode.
bool SingleMode() const { return singleMode; }
//! Modify whether or not search is done in single-tree mode.
+ //! Deprecated. Will be removed in mlpack 3.0.0.
bool& SingleMode() { return singleMode; }
//! Access the relative error to be considered in approximate search.
@@ -344,6 +451,8 @@ class NeighborSearch
//! If true, we own the reference set.
bool setOwner;
+ //! Indicates the neighbor search mode.
+ SearchMode searchMode;
//! Indicates if O(n^2) naive search is being used.
bool naive;
//! Indicates if single-tree search is being used (as opposed to dual-tree).
@@ -363,6 +472,11 @@ class NeighborSearch
//! Search() without a query set.
bool treeNeedsReset;
+ //! Updates searchMode to be according to naive and singleMode booleans.
+ //! This is only necessary until the modifiers Naive() and SingleMode() are
+ //! removed in mlpack 3.0.0.
+ void UpdateSearchMode();
+
//! The NSModel class should have access to internal members.
template<typename SortPol>
friend class TrainVisitor;
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index d95a992..f769de1 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -75,6 +75,143 @@ template<typename SortPolicy,
template<typename> class SingleTreeTraversalType>
NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn,
+ const SearchMode mode,
+ const double epsilon,
+ const MetricType metric) :
+ referenceTree(mode == NAIVE_MODE ? NULL :
+ BuildTree<MatType, Tree>(referenceSetIn, oldFromNewReferences)),
+ referenceSet(mode == NAIVE_MODE ? &referenceSetIn :
+ &referenceTree->Dataset()),
+ treeOwner(mode != NAIVE_MODE),
+ setOwner(false),
+ searchMode(mode),
+ naive(mode == NAIVE_MODE),
+ singleMode(mode == SINGLE_TREE_MODE),
+ epsilon(epsilon),
+ metric(metric),
+ baseCases(0),
+ scores(0),
+ treeNeedsReset(false)
+{
+ if (epsilon < 0)
+ throw std::invalid_argument("epsilon must be non-negative");
+}
+
+// Construct the object, taking ownership of the given reference set.
+//
+// In NAIVE_MODE the matrix is moved into a heap copy owned by this object; in
+// the tree modes it is moved into BuildTree() and referenceSet points at the
+// resulting tree's dataset.
+//
+// The epsilon check deliberately runs before anything is allocated or moved
+// from: an exception thrown from a constructor skips the destructor, so a tree
+// or matrix created in the initializer list would leak, and the caller's
+// dataset would already have been consumed.
+//
+// Throws std::invalid_argument if epsilon is negative.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class DualTreeTraversalType,
+         template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
+SingleTreeTraversalType>::NeighborSearch(MatType&& referenceSetIn,
+                                         const SearchMode mode,
+                                         const double epsilon,
+                                         const MetricType metric) :
+    referenceTree(NULL),
+    referenceSet(NULL),
+    treeOwner(false),
+    setOwner(false),
+    searchMode(mode),
+    naive(mode == NAIVE_MODE),
+    singleMode(mode == SINGLE_TREE_MODE),
+    epsilon(epsilon),
+    metric(metric),
+    baseCases(0),
+    scores(0),
+    treeNeedsReset(false)
+{
+  // Validate first, so a bad argument leaves nothing to clean up.
+  if (epsilon < 0)
+    throw std::invalid_argument("epsilon must be non-negative");
+
+  if (mode == NAIVE_MODE)
+  {
+    // No tree is built in naive mode: hold the moved-in matrix ourselves.
+    referenceSet = new MatType(std::move(referenceSetIn));
+    setOwner = true;
+  }
+  else
+  {
+    // The matrix is moved into the tree; we only point at the tree's dataset.
+    referenceTree = BuildTree<MatType, Tree>(std::move(referenceSetIn),
+                                             oldFromNewReferences);
+    referenceSet = &referenceTree->Dataset();
+    treeOwner = true;
+  }
+}
+
+// Construct the object from a pre-built reference tree; neither the tree nor its dataset is owned by this object, and NAIVE_MODE is rejected.
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType,
+ template<typename> class DualTreeTraversalType,
+ template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
+SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
+ const SearchMode mode,
+ const double epsilon,
+ const MetricType metric) :
+ referenceTree(referenceTree),
+ referenceSet(&referenceTree->Dataset()), // Dereferences the tree before any validation; no null check is made here.
+ treeOwner(false), // The tree will not be deleted by this object.
+ setOwner(false), // The dataset will not be deleted by this object.
+ searchMode(mode),
+ naive(mode == NAIVE_MODE), // Deprecated flag kept consistent with mode.
+ singleMode(mode == SINGLE_TREE_MODE), // Deprecated flag kept consistent with mode.
+ epsilon(epsilon),
+ metric(metric),
+ baseCases(0),
+ scores(0),
+ treeNeedsReset(false)
+{
+ if (mode == NAIVE_MODE) // Checked first, so this message wins when both arguments are invalid.
+ throw std::invalid_argument("invalid constructor for naive mode");
+ if (epsilon < 0)
+ throw std::invalid_argument("epsilon must be non-negative");
+}
+
+// Construct the object without a reference dataset.
+//
+// An empty matrix owned by this object stands in for the reference set, and,
+// unless the mode is NAIVE_MODE, a tree owned by this object is built on it.
+//
+// The epsilon check deliberately runs before the empty matrix is allocated:
+// an exception thrown from a constructor skips the destructor, so a matrix
+// created in the initializer list would leak when the check fails.
+//
+// Throws std::invalid_argument if epsilon is negative.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class DualTreeTraversalType,
+         template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
+SingleTreeTraversalType>::NeighborSearch(const SearchMode mode,
+                                         const double epsilon,
+                                         const MetricType metric) :
+    referenceTree(NULL),
+    referenceSet(NULL),
+    treeOwner(false),
+    setOwner(false),
+    searchMode(mode),
+    naive(mode == NAIVE_MODE),
+    singleMode(mode == SINGLE_TREE_MODE),
+    epsilon(epsilon),
+    metric(metric),
+    baseCases(0),
+    scores(0),
+    treeNeedsReset(false)
+{
+  // Validate first, so a bad argument leaves nothing to clean up.
+  if (epsilon < 0)
+    throw std::invalid_argument("epsilon must be non-negative");
+
+  referenceSet = new MatType(); // Empty matrix.
+  setOwner = true;
+
+  // Build the tree on the empty dataset, if necessary.
+  if (mode != NAIVE_MODE)
+  {
+    referenceTree = BuildTree<MatType, Tree>(*referenceSet,
+                                             oldFromNewReferences);
+    treeOwner = true;
+  }
+}
+
+// Construct the object.
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType,
+ template<typename> class DualTreeTraversalType,
+ template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
+SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn,
const bool naive,
const bool singleMode,
const double epsilon,
@@ -92,6 +229,10 @@ SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn,
scores(0),
treeNeedsReset(false)
{
+ // Update Search Mode according to naive and singleMode flags.
+ searchMode = NAIVE_MODE;
+ UpdateSearchMode();
+
if (epsilon < 0)
throw std::invalid_argument("epsilon must be non-negative");
}
@@ -126,6 +267,10 @@ SingleTreeTraversalType>::NeighborSearch(MatType&& referenceSetIn,
scores(0),
treeNeedsReset(false)
{
+ // Update Search Mode according to naive and singleMode flags.
+ searchMode = NAIVE_MODE;
+ UpdateSearchMode();
+
if (epsilon < 0)
throw std::invalid_argument("epsilon must be non-negative");
}
@@ -156,6 +301,10 @@ SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
scores(0),
treeNeedsReset(false)
{
+ // Update Search Mode according to naive and singleMode flags.
+ searchMode = NAIVE_MODE;
+ UpdateSearchMode();
+
if (epsilon < 0)
throw std::invalid_argument("epsilon must be non-negative");
}
@@ -186,6 +335,10 @@ SingleTreeTraversalType>::NeighborSearch(const bool naive,
scores(0),
treeNeedsReset(false)
{
+ // Update Search Mode according to naive and singleMode flags.
+ searchMode = NAIVE_MODE;
+ UpdateSearchMode();
+
if (epsilon < 0)
throw std::invalid_argument("epsilon must be non-negative");
// Build the tree on the empty dataset, if necessary.
@@ -227,12 +380,15 @@ void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
DualTreeTraversalType, SingleTreeTraversalType>::Train(
const MatType& referenceSet)
{
+ // Update Search Mode.
+ UpdateSearchMode();
+
// Clean up the old tree, if we built one.
if (treeOwner && referenceTree)
delete referenceTree;
// We may need to rebuild the tree.
- if (!naive)
+ if (searchMode != NAIVE_MODE)
{
referenceTree = BuildTree<MatType, Tree>(referenceSet,
oldFromNewReferences);
@@ -247,7 +403,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Train(
if (setOwner && this->referenceSet)
delete this->referenceSet;
- if (!naive)
+ if (searchMode != NAIVE_MODE)
this->referenceSet = &referenceTree->Dataset();
else
this->referenceSet = &referenceSet;
@@ -265,12 +421,15 @@ template<typename SortPolicy,
void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
DualTreeTraversalType, SingleTreeTraversalType>::Train(MatType&& referenceSetIn)
{
+ // Update Search Mode.
+ UpdateSearchMode();
+
// Clean up the old tree, if we built one.
if (treeOwner && referenceTree)
delete referenceTree;
// We may need to rebuild the tree.
- if (!naive)
+ if (searchMode != NAIVE_MODE)
{
referenceTree = BuildTree<MatType, Tree>(std::move(referenceSetIn),
oldFromNewReferences);
@@ -285,7 +444,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Train(MatType&& referenceSetIn)
if (setOwner && referenceSet)
delete referenceSet;
- if (!naive)
+ if (searchMode != NAIVE_MODE)
{
referenceSet = &referenceTree->Dataset();
setOwner = false;
@@ -308,7 +467,10 @@ template<typename SortPolicy,
void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree* referenceTree)
{
- if (naive)
+ // Update Search Mode.
+ UpdateSearchMode();
+
+ if (searchMode == NAIVE_MODE)
throw std::invalid_argument("cannot train on given reference tree when "
"naive search (without trees) is desired");
@@ -342,6 +504,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
arma::Mat<size_t>& neighbors,
arma::mat& distances)
{
+ // Update Search Mode.
+ UpdateSearchMode();
+
if (k > referenceSet->n_cols)
{
std::stringstream ss;
@@ -368,7 +533,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
// Mapping is only necessary if the tree rearranges points.
if (tree::TreeTraits<Tree>::RearrangesDataset)
{
- if (!singleMode && !naive)
+ if (searchMode == DUAL_TREE_MODE)
{
distancePtr = new arma::mat; // Query indices need to be mapped.
neighborPtr = new arma::Mat<size_t>;
@@ -383,70 +548,76 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
- if (naive)
+ switch(searchMode)
{
- // Create the helper object for the tree traversal.
- RuleType rules(*referenceSet, querySet, k, metric, epsilon);
+ case NAIVE_MODE:
+ {
+ // Create the helper object for the tree traversal.
+ RuleType rules(*referenceSet, querySet, k, metric, epsilon);
- // The naive brute-force traversal.
- for (size_t i = 0; i < querySet.n_cols; ++i)
- for (size_t j = 0; j < referenceSet->n_cols; ++j)
- rules.BaseCase(i, j);
+ // The naive brute-force traversal.
+ for (size_t i = 0; i < querySet.n_cols; ++i)
+ for (size_t j = 0; j < referenceSet->n_cols; ++j)
+ rules.BaseCase(i, j);
- baseCases += querySet.n_cols * referenceSet->n_cols;
+ baseCases += querySet.n_cols * referenceSet->n_cols;
- rules.GetResults(*neighborPtr, *distancePtr);
- }
- else if (singleMode)
- {
- // Create the helper object for the tree traversal.
- RuleType rules(*referenceSet, querySet, k, metric, epsilon);
+ rules.GetResults(*neighborPtr, *distancePtr);
+ break;
+ }
+ case SINGLE_TREE_MODE:
+ {
+ // Create the helper object for the tree traversal.
+ RuleType rules(*referenceSet, querySet, k, metric, epsilon);
- // Create the traverser.
- SingleTreeTraversalType<RuleType> traverser(rules);
+ // Create the traverser.
+ SingleTreeTraversalType<RuleType> traverser(rules);
- // Now have it traverse for each point.
- for (size_t i = 0; i < querySet.n_cols; ++i)
- traverser.Traverse(i, *referenceTree);
+ // Now have it traverse for each point.
+ for (size_t i = 0; i < querySet.n_cols; ++i)
+ traverser.Traverse(i, *referenceTree);
- scores += rules.Scores();
- baseCases += rules.BaseCases();
+ scores += rules.Scores();
+ baseCases += rules.BaseCases();
- Log::Info << rules.Scores() << " node combinations were scored."
- << std::endl;
- Log::Info << rules.BaseCases() << " base cases were calculated."
- << std::endl;
+ Log::Info << rules.Scores() << " node combinations were scored."
+ << std::endl;
+ Log::Info << rules.BaseCases() << " base cases were calculated."
+ << std::endl;
- rules.GetResults(*neighborPtr, *distancePtr);
- }
- else // Dual-tree recursion.
- {
- // Build the query tree.
- Timer::Stop("computing_neighbors");
- Timer::Start("tree_building");
- Tree* queryTree = BuildTree<MatType, Tree>(querySet, oldFromNewQueries);
- Timer::Stop("tree_building");
- Timer::Start("computing_neighbors");
+ rules.GetResults(*neighborPtr, *distancePtr);
+ break;
+ }
+ case DUAL_TREE_MODE:
+ {
+ // Build the query tree.
+ Timer::Stop("computing_neighbors");
+ Timer::Start("tree_building");
+ Tree* queryTree = BuildTree<MatType, Tree>(querySet, oldFromNewQueries);
+ Timer::Stop("tree_building");
+ Timer::Start("computing_neighbors");
- // Create the helper object for the tree traversal.
- RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, epsilon);
+ // Create the helper object for the tree traversal.
+ RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, epsilon);
- // Create the traverser.
- DualTreeTraversalType<RuleType> traverser(rules);
+ // Create the traverser.
+ DualTreeTraversalType<RuleType> traverser(rules);
- traverser.Traverse(*queryTree, *referenceTree);
+ traverser.Traverse(*queryTree, *referenceTree);
- scores += rules.Scores();
- baseCases += rules.BaseCases();
+ scores += rules.Scores();
+ baseCases += rules.BaseCases();
- Log::Info << rules.Scores() << " node combinations were scored."
- << std::endl;
- Log::Info << rules.BaseCases() << " base cases were calculated."
- << std::endl;
+ Log::Info << rules.Scores() << " node combinations were scored."
+ << std::endl;
+ Log::Info << rules.BaseCases() << " base cases were calculated."
+ << std::endl;
- rules.GetResults(*neighborPtr, *distancePtr);
+ rules.GetResults(*neighborPtr, *distancePtr);
- delete queryTree;
+ delete queryTree;
+ break;
+ }
}
Timer::Stop("computing_neighbors");
@@ -454,7 +625,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
// Map points back to original indices, if necessary.
if (tree::TreeTraits<Tree>::RearrangesDataset)
{
- if (!singleMode && !naive && treeOwner)
+ if (searchMode == DUAL_TREE_MODE && treeOwner)
{
// We must map both query and reference indices.
neighbors.set_size(k, querySet.n_cols);
@@ -477,7 +648,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
delete neighborPtr;
delete distancePtr;
}
- else if (!singleMode && !naive)
+ else if (searchMode == DUAL_TREE_MODE)
{
// We must map query indices only.
neighbors.set_size(k, querySet.n_cols);
@@ -527,6 +698,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
arma::mat& distances,
bool sameSet)
{
+ // Update Search Mode.
+ UpdateSearchMode();
+
if (k > referenceSet->n_cols)
{
std::stringstream ss;
@@ -536,7 +710,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
}
// Make sure we are in dual-tree mode.
- if (singleMode || naive)
+ if (searchMode != DUAL_TREE_MODE)
throw std::invalid_argument("cannot call NeighborSearch::Search() with a "
"query tree when naive or singleMode are set to true");
@@ -608,6 +782,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
arma::Mat<size_t>& neighbors,
arma::mat& distances)
{
+ // Update Search Mode.
+ UpdateSearchMode();
+
if (k > referenceSet->n_cols)
{
std::stringstream ss;
@@ -640,78 +817,87 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
RuleType rules(*referenceSet, *referenceSet, k, metric, epsilon,
true /* don't return the same point as nearest neighbor */);
- if (naive)
+ switch (searchMode)
{
- // The naive brute-force solution.
- for (size_t i = 0; i < referenceSet->n_cols; ++i)
- for (size_t j = 0; j < referenceSet->n_cols; ++j)
- rules.BaseCase(i, j);
+ case NAIVE_MODE:
+ {
+ // The naive brute-force solution.
+ for (size_t i = 0; i < referenceSet->n_cols; ++i)
+ for (size_t j = 0; j < referenceSet->n_cols; ++j)
+ rules.BaseCase(i, j);
- baseCases += referenceSet->n_cols * referenceSet->n_cols;
- }
- else if (singleMode)
- {
- // Create the traverser.
- SingleTreeTraversalType<RuleType> traverser(rules);
+ baseCases += referenceSet->n_cols * referenceSet->n_cols;
+ break;
+ }
+ case SINGLE_TREE_MODE:
+ {
+ // Create the traverser.
+ SingleTreeTraversalType<RuleType> traverser(rules);
- // Now have it traverse for each point.
- for (size_t i = 0; i < referenceSet->n_cols; ++i)
- traverser.Traverse(i, *referenceTree);
+ // Now have it traverse for each point.
+ for (size_t i = 0; i < referenceSet->n_cols; ++i)
+ traverser.Traverse(i, *referenceTree);
- scores += rules.Scores();
- baseCases += rules.BaseCases();
+ scores += rules.Scores();
+ baseCases += rules.BaseCases();
- Log::Info << rules.Scores() << " node combinations were scored."
- << std::endl;
- Log::Info << rules.BaseCases() << " base cases were calculated."
- << std::endl;
- }
- else
- {
- // The dual-tree monochromatic search case may require resetting the bounds
- // in the tree.
- if (treeNeedsReset)
+ Log::Info << rules.Scores() << " node combinations were scored."
+ << std::endl;
+ Log::Info << rules.BaseCases() << " base cases were calculated."
+ << std::endl;
+ break;
+ }
+ case DUAL_TREE_MODE:
{
- std::stack<Tree*> nodes;
- nodes.push(referenceTree);
- while (!nodes.empty())
+ // The dual-tree monochromatic search case may require resetting the
+ // bounds in the tree.
+ if (treeNeedsReset)
{
- Tree* node = nodes.top();
- nodes.pop();
+ std::stack<Tree*> nodes;
+ nodes.push(referenceTree);
+ while (!nodes.empty())
+ {
+ Tree* node = nodes.top();
+ nodes.pop();
- // Reset bounds of this node.
- node->Stat().Reset();
+ // Reset bounds of this node.
+ node->Stat().Reset();
- // Then add the children.
- for (size_t i = 0; i < node->NumChildren(); ++i)
- nodes.push(&node->Child(i));
+ // Then add the children.
+ for (size_t i = 0; i < node->NumChildren(); ++i)
+ nodes.push(&node->Child(i));
+ }
}
- }
- // Create the traverser.
- DualTreeTraversalType<RuleType> traverser(rules);
+ // Create the traverser.
+ DualTreeTraversalType<RuleType> traverser(rules);
+
+ if (tree::IsSpillTree<Tree>::value)
+ {
+ // For Dual Tree Search on SpillTree, the queryTree must be built with
+ // non overlapping (tau = 0).
+ Tree queryTree(*referenceSet);
+ traverser.Traverse(queryTree, *referenceTree);
+ }
+ else
+ {
+ traverser.Traverse(*referenceTree, *referenceTree);
+ // Next time we perform this search, we'll need to reset the tree.
+ treeNeedsReset = true;
+ }
+
+ scores += rules.Scores();
+ baseCases += rules.BaseCases();
+
+ Log::Info << rules.Scores() << " node combinations were scored."
+ << std::endl;
+ Log::Info << rules.BaseCases() << " base cases were calculated."
+ << std::endl;
- if (tree::IsSpillTree<Tree>::value)
- {
- // For Dual Tree Search on SpillTree, the queryTree must be built with non
- // overlapping (tau = 0).
- Tree queryTree(*referenceSet);
- traverser.Traverse(queryTree, *referenceTree);
- }
- else
- {
- traverser.Traverse(*referenceTree, *referenceTree);
// Next time we perform this search, we'll need to reset the tree.
treeNeedsReset = true;
+ break;
}
-
- scores += rules.Scores();
- baseCases += rules.BaseCases();
-
- Log::Info << rules.Scores() << " node combinations were scored."
- << std::endl;
- Log::Info << rules.BaseCases() << " base cases were calculated."
- << std::endl;
}
rules.GetResults(*neighborPtr, *distancePtr);
@@ -827,14 +1013,18 @@ DualTreeTraversalType, SingleTreeTraversalType>::Serialize(
{
using data::CreateNVP;
+ // Update Search Mode.
+ UpdateSearchMode();
+
// Serialize preferences for search.
+ ar & CreateNVP(searchMode, "searchMode");
ar & CreateNVP(naive, "naive");
ar & CreateNVP(singleMode, "singleMode");
ar & CreateNVP(treeNeedsReset, "treeNeedsReset");
// If we are doing naive search, we serialize the dataset. Otherwise we
// serialize the tree.
- if (naive)
+ if (searchMode == NAIVE_MODE)
{
// Delete the current reference set, if necessary and if we are loading.
if (Archive::is_loading::value)
@@ -895,6 +1085,58 @@ DualTreeTraversalType, SingleTreeTraversalType>::Serialize(
}
}
+//! Set the Search Mode.
+//!
+//! Besides storing the mode, this keeps the deprecated naive / singleMode
+//! booleans consistent with it, because UpdateSearchMode() recomputes
+//! searchMode from those booleans.
+//!
+//! @param mode Neighbor search mode to use from now on.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class DualTreeTraversalType,
+         template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+DualTreeTraversalType, SingleTreeTraversalType>::SetSearchMode(
+    const SearchMode mode)
+{
+  searchMode = mode;
+
+  // Derive both legacy flags exactly as the SearchMode constructors do.  In
+  // particular singleMode is cleared for NAIVE_MODE, so that a stale 'true'
+  // cannot resurface as single-tree mode if naive is later switched off
+  // through the deprecated Naive() modifier.
+  naive = (mode == NAIVE_MODE);
+  singleMode = (mode == SINGLE_TREE_MODE);
+}
+
+//! Recomputes searchMode from the deprecated naive and singleMode booleans:
+//! naive takes precedence, then singleMode, otherwise dual-tree is selected.
+//! Needed only while the Naive() and SingleMode() modifiers still exist
+//! (they are to be removed in mlpack 3.0.0).
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class DualTreeTraversalType,
+         template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+DualTreeTraversalType, SingleTreeTraversalType>::UpdateSearchMode()
+{
+  // The tree modes are only considered when naive search is off.
+  const SearchMode treeMode = singleMode ? SINGLE_TREE_MODE : DUAL_TREE_MODE;
+  searchMode = naive ? NAIVE_MODE : treeMode;
+}
+
} // namespace neighbor
} // namespace mlpack
More information about the mlpack-git
mailing list