[mlpack-git] master: Use enum type to define the different search modes (Closes #750). (6666d66)

gitdub at mlpack.org gitdub at mlpack.org
Sat Aug 20 14:56:06 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/3274b05fcc545c3b36f783316fea2e22f79c3d03...1c77230c7d3b9c45fb102cd3c632d9c7248e085e

>---------------------------------------------------------------

commit 6666d667ab29c2a94cc92574a1f034cde586259f
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Tue Aug 16 00:49:45 2016 -0300

    Use enum type to define the different search modes (Closes #750).


>---------------------------------------------------------------

6666d667ab29c2a94cc92574a1f034cde586259f
 .../methods/neighbor_search/neighbor_search.hpp    | 114 +++++
 .../neighbor_search/neighbor_search_impl.hpp       | 472 ++++++++++++++++-----
 2 files changed, 471 insertions(+), 115 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 1c2cace..ec088e5 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -73,6 +73,100 @@ class NeighborSearch
   //! Convenience typedef.
   typedef TreeType<MetricType, NeighborSearchStat<SortPolicy>, MatType> Tree;
 
+  //! SearchMode represents the different neighbor search modes available.
+  enum SearchMode
+  {
+    NAIVE_MODE,       //!< Brute-force O(n^2) scan; no trees are used.
+    SINGLE_TREE_MODE, //!< Traverse the reference tree for each query point.
+    DUAL_TREE_MODE    //!< Traverse a query tree and the reference tree.
+  };
+
+  /**
+   * Initialize the NeighborSearch object, passing a reference dataset (this is
+   * the dataset which is searched).  Optionally, perform the computation in
+   * a different mode.  An initialized distance metric can be given, for cases
+   * where the metric has internal data (e.g. the distance::MahalanobisDistance
+   * class).
+   *
+   * Outside naive mode, the matrix is copied internally and the copy may be
+   * rearranged during tree-building (avoid the copy by passing a pre-built
+   * tree or an rvalue reference instead).  In naive mode only a pointer to
+   * referenceSet is kept, so it must remain valid while this object uses it.
+   *
+   * @param referenceSet Set of reference points.
+   * @param mode Neighbor search mode.
+   * @param epsilon Relative approximate error (non-negative).
+   * @param metric An optional instance of the MetricType class.
+   */
+  NeighborSearch(const MatType& referenceSet,
+                 const SearchMode mode = DUAL_TREE_MODE,
+                 const double epsilon = 0,
+                 const MetricType metric = MetricType());
+
+  /**
+   * Initialize the NeighborSearch object, taking ownership of the reference
+   * dataset (this is the dataset which is searched).  Optionally, perform the
+   * computation in a different mode.  An initialized distance metric can be
+   * given, for cases where the metric has internal data (e.g. the
+   * distance::MahalanobisDistance class).
+   *
+   * This method will not copy the data matrix, but will take ownership of it,
+   * and depending on the type of tree used, may rearrange the points.  If you
+   * would rather a copy be made, consider using the constructor that takes a
+   * const reference to the data instead.
+   *
+   * @param referenceSet Set of reference points; it is moved from.
+   * @param mode Neighbor search mode.
+   * @param epsilon Relative approximate error (non-negative).
+   * @param metric An optional instance of the MetricType class.
+   */
+  NeighborSearch(MatType&& referenceSet,
+                 const SearchMode mode = DUAL_TREE_MODE,
+                 const double epsilon = 0,
+                 const MetricType metric = MetricType());
+
+  /**
+   * Initialize the NeighborSearch object with the given pre-constructed
+   * reference tree (this is the tree built on the points that will be
+   * searched).  Optionally, perform the computation in a different mode.
+   * Naive mode is not supported here (std::invalid_argument is thrown).
+   * Additionally, an instantiated distance metric can be given, for cases where
+   * the distance metric holds data.
+   *
+   * There is no copying of the data matrices in this constructor (because
+   * tree-building is not necessary), so this is the constructor to use when
+   * copies absolutely must be avoided.
+   *
+   * @note
+   * Mapping the points of the matrix back to their original indices is not done
+   * when this constructor is used, so if the tree type you are using maps
+   * points (like BinarySpaceTree), then you will have to perform the re-mapping
+   * manually.
+   * @endnote
+   *
+   * @param referenceTree Pre-built tree for reference points (not owned).
+   * @param mode Neighbor search mode.
+   * @param epsilon Relative approximate error (non-negative).
+   * @param metric Instantiated distance metric.
+   */
+  NeighborSearch(Tree* referenceTree,
+                 const SearchMode mode = DUAL_TREE_MODE,
+                 const double epsilon = 0,
+                 const MetricType metric = MetricType());
+
+  /**
+   * Create a NeighborSearch object without any reference data.  If Search() is
+   * called before a reference set is given with Train(), an exception will be
+   * thrown.
+   *
+   * @param mode Neighbor search mode.
+   * @param epsilon Relative approximate error (non-negative).
+   * @param metric Instantiated distance metric.
+   */
+  NeighborSearch(const SearchMode mode = DUAL_TREE_MODE,
+                 const double epsilon = 0,
+                 const MetricType metric = MetricType());
+
   /**
    * Initialize the NeighborSearch object, passing a reference dataset (this is
    * the dataset which is searched).  Optionally, perform the computation in
@@ -80,6 +174,8 @@ class NeighborSearch
    * given, for cases where the metric has internal data (i.e. the
    * distance::MahalanobisDistance class).
    *
+   * Deprecated. Will be removed in mlpack 3.0.0.
+   *
    * This method will copy the matrices to internal copies, which are rearranged
    * during tree-building.  You can avoid this extra copy by pre-constructing
    * the trees and passing them using a different constructor, or by using the
@@ -106,6 +202,8 @@ class NeighborSearch
    * metric can be given, for cases where the metric has internal data (i.e. the
    * distance::MahalanobisDistance class).
    *
+   * Deprecated. Will be removed in mlpack 3.0.0.
+   *
    * This method will not copy the data matrix, but will take ownership of it,
    * and depending on the type of tree used, may rearrange the points.  If you
    * would rather a copy be made, consider using the constructor that takes a
@@ -133,6 +231,8 @@ class NeighborSearch
    * distance metric can be given, for cases where the distance metric holds
    * data.
    *
+   * Deprecated. Will be removed in mlpack 3.0.0.
+   *
    * There is no copying of the data matrices in this constructor (because
    * tree-building is not necessary), so this is the constructor to use when
    * copies absolutely must be avoided.
@@ -161,6 +261,8 @@ class NeighborSearch
    * called before a reference set is set with Train(), an exception will be
    * thrown.
    *
+   * Deprecated. Will be removed in mlpack 3.0.0.
+   *
    * @param naive Whether to use naive search.
    * @param singleMode Whether single-tree computation should be used (as
    *      opposed to dual-tree computation).
@@ -309,14 +411,19 @@ class NeighborSearch
   //! Return the number of node combination scores during the last search.
   size_t Scores() const { return scores; }
 
+  //! Modify the search mode; the deprecated flags are updated to match.
+  void SetSearchMode(const SearchMode mode);
+
   //! Access whether or not search is done in naive linear scan mode.
   bool Naive() const { return naive; }
   //! Modify whether or not search is done in naive linear scan mode.
+  //! Deprecated. Will be removed in mlpack 3.0.0.
   bool& Naive() { return naive; }
 
   //! Access whether or not search is done in single-tree mode.
   bool SingleMode() const { return singleMode; }
   //! Modify whether or not search is done in single-tree mode.
+  //! Deprecated. Will be removed in mlpack 3.0.0.
   bool& SingleMode() { return singleMode; }
 
   //! Access the relative error to be considered in approximate search.
@@ -344,6 +451,8 @@ class NeighborSearch
   //! If true, we own the reference set.
   bool setOwner;
 
+  //! Indicates the neighbor search mode.
+  SearchMode searchMode;
   //! Indicates if O(n^2) naive search is being used.
   bool naive;
   //! Indicates if single-tree search is being used (as opposed to dual-tree).
@@ -363,6 +472,11 @@ class NeighborSearch
   //! Search() without a query set.
   bool treeNeedsReset;
 
+  //! Recompute searchMode from the deprecated naive and singleMode flags,
+  //! which may have been modified through Naive() and SingleMode().  Only
+  //! necessary until those modifiers are removed in mlpack 3.0.0.
+  void UpdateSearchMode();
+
   //! The NSModel class should have access to internal members.
   template<typename SortPol>
   friend class TrainVisitor;
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index d95a992..f769de1 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -75,6 +75,143 @@ template<typename SortPolicy,
          template<typename> class SingleTreeTraversalType>
 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
 SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn,
+                                         const SearchMode mode,
+                                         const double epsilon,
+                                         const MetricType metric) :
+    referenceTree(mode == NAIVE_MODE ? NULL :
+        BuildTree<MatType, Tree>(referenceSetIn, oldFromNewReferences)),
+    referenceSet(mode == NAIVE_MODE ? &referenceSetIn :
+        &referenceTree->Dataset()),
+    treeOwner(mode != NAIVE_MODE),
+    setOwner(false),
+    searchMode(mode),
+    naive(mode == NAIVE_MODE),
+    singleMode(mode == SINGLE_TREE_MODE),
+    epsilon(epsilon),
+    metric(metric),
+    baseCases(0),
+    scores(0),
+    treeNeedsReset(false)
+{
+  if (epsilon < 0)
+    throw std::invalid_argument("epsilon must be non-negative");
+}
+
+// Construct the object, taking ownership of the reference set.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class DualTreeTraversalType,
+         template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
+SingleTreeTraversalType>::NeighborSearch(MatType&& referenceSetIn,
+                                         const SearchMode mode,
+                                         const double epsilon,
+                                         const MetricType metric) :
+    referenceTree(mode == NAIVE_MODE ? NULL :
+        BuildTree<MatType, Tree>(std::move(referenceSetIn),
+                                 oldFromNewReferences)),
+    referenceSet(mode == NAIVE_MODE ? new MatType(std::move(referenceSetIn)) :
+        &referenceTree->Dataset()),
+    treeOwner(mode != NAIVE_MODE),
+    setOwner(mode == NAIVE_MODE), // We allocate the matrix in naive mode.
+    searchMode(mode),
+    naive(mode == NAIVE_MODE),
+    singleMode(mode == SINGLE_TREE_MODE),
+    epsilon(epsilon),
+    metric(metric),
+    baseCases(0),
+    scores(0),
+    treeNeedsReset(false)
+{
+  if (epsilon < 0) // NOTE(review): anything allocated above leaks on throw.
+    throw std::invalid_argument("epsilon must be non-negative");
+}
+
+// Construct the object from a pre-built reference tree, which is not owned.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class DualTreeTraversalType,
+         template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
+SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
+                                         const SearchMode mode,
+                                         const double epsilon,
+                                         const MetricType metric) :
+    referenceTree(referenceTree),
+    referenceSet(&referenceTree->Dataset()), // Uses the parameter.
+    treeOwner(false),
+    setOwner(false),
+    searchMode(mode),
+    naive(mode == NAIVE_MODE),
+    singleMode(mode == SINGLE_TREE_MODE),
+    epsilon(epsilon),
+    metric(metric),
+    baseCases(0),
+    scores(0),
+    treeNeedsReset(false)
+{
+  if (mode == NAIVE_MODE) // Naive mode would not use the given tree.
+    throw std::invalid_argument("invalid constructor for naive mode");
+  if (epsilon < 0)
+    throw std::invalid_argument("epsilon must be non-negative");
+}
+
+// Construct the object without a reference dataset.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class DualTreeTraversalType,
+         template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
+SingleTreeTraversalType>::NeighborSearch(const SearchMode mode,
+                                         const double epsilon,
+                                         const MetricType metric) :
+    referenceTree(NULL),
+    referenceSet(new MatType()), // Empty matrix.
+    treeOwner(false),
+    setOwner(true),
+    searchMode(mode),
+    naive(mode == NAIVE_MODE),
+    singleMode(mode == SINGLE_TREE_MODE),
+    epsilon(epsilon),
+    metric(metric),
+    baseCases(0),
+    scores(0),
+    treeNeedsReset(false)
+{
+  if (epsilon < 0)
+  {
+    delete referenceSet; // No destructor runs when a constructor throws.
+    throw std::invalid_argument("epsilon must be non-negative");
+  }
+  treeOwner = (mode != NAIVE_MODE); // Build the (empty) tree unless naive.
+  if (treeOwner)
+    referenceTree = BuildTree<MatType, Tree>(*referenceSet,
+        oldFromNewReferences);
+}
+
+// Construct the object.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class DualTreeTraversalType,
+         template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
+SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn,
                                          const bool naive,
                                          const bool singleMode,
                                          const double epsilon,
@@ -92,6 +229,10 @@ SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn,
     scores(0),
     treeNeedsReset(false)
 {
+  // Update Search Mode according to naive and singleMode flags.
+  searchMode = NAIVE_MODE;
+  UpdateSearchMode();
+
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
 }
@@ -126,6 +267,10 @@ SingleTreeTraversalType>::NeighborSearch(MatType&& referenceSetIn,
     scores(0),
     treeNeedsReset(false)
 {
+  // Update Search Mode according to naive and singleMode flags.
+  searchMode = NAIVE_MODE;
+  UpdateSearchMode();
+
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
 }
@@ -156,6 +301,10 @@ SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
     scores(0),
     treeNeedsReset(false)
 {
+  // Update Search Mode according to naive and singleMode flags.
+  searchMode = NAIVE_MODE;
+  UpdateSearchMode();
+
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
 }
@@ -186,6 +335,10 @@ SingleTreeTraversalType>::NeighborSearch(const bool naive,
     scores(0),
     treeNeedsReset(false)
 {
+  // Update Search Mode according to naive and singleMode flags.
+  searchMode = NAIVE_MODE;
+  UpdateSearchMode();
+
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
   // Build the tree on the empty dataset, if necessary.
@@ -227,12 +380,15 @@ void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
 DualTreeTraversalType, SingleTreeTraversalType>::Train(
     const MatType& referenceSet)
 {
+  // Update Search Mode.
+  UpdateSearchMode();
+
   // Clean up the old tree, if we built one.
   if (treeOwner && referenceTree)
     delete referenceTree;
 
   // We may need to rebuild the tree.
-  if (!naive)
+  if (searchMode != NAIVE_MODE)
   {
     referenceTree = BuildTree<MatType, Tree>(referenceSet,
         oldFromNewReferences);
@@ -247,7 +403,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Train(
   if (setOwner && this->referenceSet)
     delete this->referenceSet;
 
-  if (!naive)
+  if (searchMode != NAIVE_MODE)
     this->referenceSet = &referenceTree->Dataset();
   else
     this->referenceSet = &referenceSet;
@@ -265,12 +421,15 @@ template<typename SortPolicy,
 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
 DualTreeTraversalType, SingleTreeTraversalType>::Train(MatType&& referenceSetIn)
 {
+  // Update Search Mode.
+  UpdateSearchMode();
+
   // Clean up the old tree, if we built one.
   if (treeOwner && referenceTree)
     delete referenceTree;
 
   // We may need to rebuild the tree.
-  if (!naive)
+  if (searchMode != NAIVE_MODE)
   {
     referenceTree = BuildTree<MatType, Tree>(std::move(referenceSetIn),
         oldFromNewReferences);
@@ -285,7 +444,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Train(MatType&& referenceSetIn)
   if (setOwner && referenceSet)
     delete referenceSet;
 
-  if (!naive)
+  if (searchMode != NAIVE_MODE)
   {
     referenceSet = &referenceTree->Dataset();
     setOwner = false;
@@ -308,7 +467,10 @@ template<typename SortPolicy,
 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
 DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree* referenceTree)
 {
-  if (naive)
+  // Update Search Mode.
+  UpdateSearchMode();
+
+  if (searchMode == NAIVE_MODE)
     throw std::invalid_argument("cannot train on given reference tree when "
         "naive search (without trees) is desired");
 
@@ -342,6 +504,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
     arma::Mat<size_t>& neighbors,
     arma::mat& distances)
 {
+  // Update Search Mode.
+  UpdateSearchMode();
+
   if (k > referenceSet->n_cols)
   {
     std::stringstream ss;
@@ -368,7 +533,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
   // Mapping is only necessary if the tree rearranges points.
   if (tree::TreeTraits<Tree>::RearrangesDataset)
   {
-    if (!singleMode && !naive)
+    if (searchMode == DUAL_TREE_MODE)
     {
       distancePtr = new arma::mat; // Query indices need to be mapped.
       neighborPtr = new arma::Mat<size_t>;
@@ -383,70 +548,76 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
 
   typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
 
-  if (naive)
+  switch(searchMode)
   {
-    // Create the helper object for the tree traversal.
-    RuleType rules(*referenceSet, querySet, k, metric, epsilon);
+    case NAIVE_MODE:
+    {
+      // Create the helper object for the tree traversal.
+      RuleType rules(*referenceSet, querySet, k, metric, epsilon);
 
-    // The naive brute-force traversal.
-    for (size_t i = 0; i < querySet.n_cols; ++i)
-      for (size_t j = 0; j < referenceSet->n_cols; ++j)
-        rules.BaseCase(i, j);
+      // The naive brute-force traversal.
+      for (size_t i = 0; i < querySet.n_cols; ++i)
+        for (size_t j = 0; j < referenceSet->n_cols; ++j)
+          rules.BaseCase(i, j);
 
-    baseCases += querySet.n_cols * referenceSet->n_cols;
+      baseCases += querySet.n_cols * referenceSet->n_cols;
 
-    rules.GetResults(*neighborPtr, *distancePtr);
-  }
-  else if (singleMode)
-  {
-    // Create the helper object for the tree traversal.
-    RuleType rules(*referenceSet, querySet, k, metric, epsilon);
+      rules.GetResults(*neighborPtr, *distancePtr);
+      break;
+    }
+    case SINGLE_TREE_MODE:
+    {
+      // Create the helper object for the tree traversal.
+      RuleType rules(*referenceSet, querySet, k, metric, epsilon);
 
-    // Create the traverser.
-    SingleTreeTraversalType<RuleType> traverser(rules);
+      // Create the traverser.
+      SingleTreeTraversalType<RuleType> traverser(rules);
 
-    // Now have it traverse for each point.
-    for (size_t i = 0; i < querySet.n_cols; ++i)
-      traverser.Traverse(i, *referenceTree);
+      // Now have it traverse for each point.
+      for (size_t i = 0; i < querySet.n_cols; ++i)
+        traverser.Traverse(i, *referenceTree);
 
-    scores += rules.Scores();
-    baseCases += rules.BaseCases();
+      scores += rules.Scores();
+      baseCases += rules.BaseCases();
 
-    Log::Info << rules.Scores() << " node combinations were scored."
-        << std::endl;
-    Log::Info << rules.BaseCases() << " base cases were calculated."
-        << std::endl;
+      Log::Info << rules.Scores() << " node combinations were scored."
+          << std::endl;
+      Log::Info << rules.BaseCases() << " base cases were calculated."
+          << std::endl;
 
-    rules.GetResults(*neighborPtr, *distancePtr);
-  }
-  else // Dual-tree recursion.
-  {
-    // Build the query tree.
-    Timer::Stop("computing_neighbors");
-    Timer::Start("tree_building");
-    Tree* queryTree = BuildTree<MatType, Tree>(querySet, oldFromNewQueries);
-    Timer::Stop("tree_building");
-    Timer::Start("computing_neighbors");
+      rules.GetResults(*neighborPtr, *distancePtr);
+      break;
+    }
+    case DUAL_TREE_MODE:
+    {
+      // Build the query tree.
+      Timer::Stop("computing_neighbors");
+      Timer::Start("tree_building");
+      Tree* queryTree = BuildTree<MatType, Tree>(querySet, oldFromNewQueries);
+      Timer::Stop("tree_building");
+      Timer::Start("computing_neighbors");
 
-    // Create the helper object for the tree traversal.
-    RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, epsilon);
+      // Create the helper object for the tree traversal.
+      RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, epsilon);
 
-    // Create the traverser.
-    DualTreeTraversalType<RuleType> traverser(rules);
+      // Create the traverser.
+      DualTreeTraversalType<RuleType> traverser(rules);
 
-    traverser.Traverse(*queryTree, *referenceTree);
+      traverser.Traverse(*queryTree, *referenceTree);
 
-    scores += rules.Scores();
-    baseCases += rules.BaseCases();
+      scores += rules.Scores();
+      baseCases += rules.BaseCases();
 
-    Log::Info << rules.Scores() << " node combinations were scored."
-        << std::endl;
-    Log::Info << rules.BaseCases() << " base cases were calculated."
-        << std::endl;
+      Log::Info << rules.Scores() << " node combinations were scored."
+          << std::endl;
+      Log::Info << rules.BaseCases() << " base cases were calculated."
+          << std::endl;
 
-    rules.GetResults(*neighborPtr, *distancePtr);
+      rules.GetResults(*neighborPtr, *distancePtr);
 
-    delete queryTree;
+      delete queryTree;
+      break;
+    }
   }
 
   Timer::Stop("computing_neighbors");
@@ -454,7 +625,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
   // Map points back to original indices, if necessary.
   if (tree::TreeTraits<Tree>::RearrangesDataset)
   {
-    if (!singleMode && !naive && treeOwner)
+    if (searchMode == DUAL_TREE_MODE && treeOwner)
     {
       // We must map both query and reference indices.
       neighbors.set_size(k, querySet.n_cols);
@@ -477,7 +648,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
       delete neighborPtr;
       delete distancePtr;
     }
-    else if (!singleMode && !naive)
+    else if (searchMode == DUAL_TREE_MODE)
     {
       // We must map query indices only.
       neighbors.set_size(k, querySet.n_cols);
@@ -527,6 +698,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
     arma::mat& distances,
     bool sameSet)
 {
+  // Update Search Mode.
+  UpdateSearchMode();
+
   if (k > referenceSet->n_cols)
   {
     std::stringstream ss;
@@ -536,7 +710,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
   }
 
   // Make sure we are in dual-tree mode.
-  if (singleMode || naive)
+  if (searchMode != DUAL_TREE_MODE)
     throw std::invalid_argument("cannot call NeighborSearch::Search() with a "
         "query tree when naive or singleMode are set to true");
 
@@ -608,6 +782,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
     arma::Mat<size_t>& neighbors,
     arma::mat& distances)
 {
+  // Update Search Mode.
+  UpdateSearchMode();
+
   if (k > referenceSet->n_cols)
   {
     std::stringstream ss;
@@ -640,78 +817,87 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
   RuleType rules(*referenceSet, *referenceSet, k, metric, epsilon,
       true /* don't return the same point as nearest neighbor */);
 
-  if (naive)
+  switch (searchMode)
   {
-    // The naive brute-force solution.
-    for (size_t i = 0; i < referenceSet->n_cols; ++i)
-      for (size_t j = 0; j < referenceSet->n_cols; ++j)
-        rules.BaseCase(i, j);
+    case NAIVE_MODE:
+    {
+      // The naive brute-force solution.
+      for (size_t i = 0; i < referenceSet->n_cols; ++i)
+        for (size_t j = 0; j < referenceSet->n_cols; ++j)
+          rules.BaseCase(i, j);
 
-    baseCases += referenceSet->n_cols * referenceSet->n_cols;
-  }
-  else if (singleMode)
-  {
-    // Create the traverser.
-    SingleTreeTraversalType<RuleType> traverser(rules);
+      baseCases += referenceSet->n_cols * referenceSet->n_cols;
+      break;
+    }
+    case SINGLE_TREE_MODE:
+    {
+      // Create the traverser.
+      SingleTreeTraversalType<RuleType> traverser(rules);
 
-    // Now have it traverse for each point.
-    for (size_t i = 0; i < referenceSet->n_cols; ++i)
-      traverser.Traverse(i, *referenceTree);
+      // Now have it traverse for each point.
+      for (size_t i = 0; i < referenceSet->n_cols; ++i)
+        traverser.Traverse(i, *referenceTree);
 
-    scores += rules.Scores();
-    baseCases += rules.BaseCases();
+      scores += rules.Scores();
+      baseCases += rules.BaseCases();
 
-    Log::Info << rules.Scores() << " node combinations were scored."
-        << std::endl;
-    Log::Info << rules.BaseCases() << " base cases were calculated."
-        << std::endl;
-  }
-  else
-  {
-    // The dual-tree monochromatic search case may require resetting the bounds
-    // in the tree.
-    if (treeNeedsReset)
+      Log::Info << rules.Scores() << " node combinations were scored."
+          << std::endl;
+      Log::Info << rules.BaseCases() << " base cases were calculated."
+          << std::endl;
+      break;
+    }
+    case DUAL_TREE_MODE:
     {
-      std::stack<Tree*> nodes;
-      nodes.push(referenceTree);
-      while (!nodes.empty())
+      // The dual-tree monochromatic search case may require resetting the
+      // bounds in the tree.
+      if (treeNeedsReset)
       {
-        Tree* node = nodes.top();
-        nodes.pop();
+        std::stack<Tree*> nodes;
+        nodes.push(referenceTree);
+        while (!nodes.empty())
+        {
+          Tree* node = nodes.top();
+          nodes.pop();
 
-        // Reset bounds of this node.
-        node->Stat().Reset();
+          // Reset bounds of this node.
+          node->Stat().Reset();
 
-        // Then add the children.
-        for (size_t i = 0; i < node->NumChildren(); ++i)
-          nodes.push(&node->Child(i));
+          // Then add the children.
+          for (size_t i = 0; i < node->NumChildren(); ++i)
+            nodes.push(&node->Child(i));
+        }
       }
-    }
 
-    // Create the traverser.
-    DualTreeTraversalType<RuleType> traverser(rules);
+      // Create the traverser.
+      DualTreeTraversalType<RuleType> traverser(rules);
+
+      if (tree::IsSpillTree<Tree>::value)
+      {
+        // For Dual Tree Search on SpillTree, the queryTree must be built with
+        // non overlapping (tau = 0).
+        Tree queryTree(*referenceSet);
+        traverser.Traverse(queryTree, *referenceTree);
+      }
+      else
+      {
+        traverser.Traverse(*referenceTree, *referenceTree);
+        // Next time we perform this search, we'll need to reset the tree.
+        treeNeedsReset = true;
+      }
+
+      scores += rules.Scores();
+      baseCases += rules.BaseCases();
+
+      Log::Info << rules.Scores() << " node combinations were scored."
+          << std::endl;
+      Log::Info << rules.BaseCases() << " base cases were calculated."
+          << std::endl;
 
-    if (tree::IsSpillTree<Tree>::value)
-    {
-      // For Dual Tree Search on SpillTree, the queryTree must be built with non
-      // overlapping (tau = 0).
-      Tree queryTree(*referenceSet);
-      traverser.Traverse(queryTree, *referenceTree);
-    }
-    else
-    {
-      traverser.Traverse(*referenceTree, *referenceTree);
       // Next time we perform this search, we'll need to reset the tree.
       treeNeedsReset = true;
+      break;
     }
-
-    scores += rules.Scores();
-    baseCases += rules.BaseCases();
-
-    Log::Info << rules.Scores() << " node combinations were scored."
-        << std::endl;
-    Log::Info << rules.BaseCases() << " base cases were calculated."
-        << std::endl;
   }
 
   rules.GetResults(*neighborPtr, *distancePtr);
@@ -827,14 +1013,18 @@ DualTreeTraversalType, SingleTreeTraversalType>::Serialize(
 {
   using data::CreateNVP;
 
+  // Update Search Mode.
+  UpdateSearchMode();
+
   // Serialize preferences for search.
+  ar & CreateNVP(searchMode, "searchMode");
   ar & CreateNVP(naive, "naive");
   ar & CreateNVP(singleMode, "singleMode");
   ar & CreateNVP(treeNeedsReset, "treeNeedsReset");
 
   // If we are doing naive search, we serialize the dataset.  Otherwise we
   // serialize the tree.
-  if (naive)
+  if (searchMode == NAIVE_MODE)
   {
     // Delete the current reference set, if necessary and if we are loading.
     if (Archive::is_loading::value)
@@ -895,6 +1085,58 @@ DualTreeTraversalType, SingleTreeTraversalType>::Serialize(
   }
 }
 
+//! Set the search mode, keeping the deprecated naive/singleMode flags in sync.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class DualTreeTraversalType,
+         template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+DualTreeTraversalType, SingleTreeTraversalType>::SetSearchMode(
+    const SearchMode mode)
+{
+  // Record the requested mode.
+  searchMode = mode;
+
+  // Derive both deprecated flags from the new mode, exactly as the SearchMode
+  // constructors initialize them.  In particular, naive mode clears
+  // singleMode, so SingleMode() never reports a stale value and clearing the
+  // deprecated Naive() flag later yields the default dual-tree mode rather
+  // than a leftover single-tree mode.
+  naive = (mode == NAIVE_MODE);
+  singleMode = (mode == SINGLE_TREE_MODE);
+
+  // NOTE(review): no tree is built or destroyed here.  If the object was
+  // constructed in naive mode, referenceTree is NULL, so switching to a
+  // tree-based mode presumably requires calling Train() again before
+  // Search() -- TODO confirm and consider enforcing this.
+}
+
+//! Recompute searchMode from the deprecated naive and singleMode flags, which
+//! callers may change via Naive()/SingleMode(); naive takes precedence.  Only
+//! needed until those modifiers are removed in mlpack 3.0.0.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class DualTreeTraversalType,
+         template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+DualTreeTraversalType, SingleTreeTraversalType>::UpdateSearchMode()
+{
+  // naive has priority over singleMode; with neither flag set, the default
+  // dual-tree mode is selected.
+  if (naive)
+    searchMode = NAIVE_MODE;
+  else
+    searchMode = singleMode ? SINGLE_TREE_MODE : DUAL_TREE_MODE;
+}
+
 } // namespace neighbor
 } // namespace mlpack
 




More information about the mlpack-git mailing list