[mlpack-git] master: Remove SetSearchMode() and add Greedy() flags. (c1f7237)

gitdub at mlpack.org gitdub at mlpack.org
Sat Aug 20 15:37:19 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/3274b05fcc545c3b36f783316fea2e22f79c3d03...1c77230c7d3b9c45fb102cd3c632d9c7248e085e

>---------------------------------------------------------------

commit c1f723724aa17fd93d8ca47ef49f8ee450a30e59
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Sat Aug 20 16:16:47 2016 -0300

    Remove SetSearchMode() and add Greedy() flags.


>---------------------------------------------------------------

c1f723724aa17fd93d8ca47ef49f8ee450a30e59
 .../methods/neighbor_search/neighbor_search.hpp    | 24 ++++---
 .../neighbor_search/neighbor_search_impl.hpp       | 79 ++++++++++++++--------
 .../methods/neighbor_search/ns_model_impl.hpp      | 43 ++++++++++--
 3 files changed, 105 insertions(+), 41 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 6c799f6..bbb882c 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -412,11 +412,6 @@ class NeighborSearch
   //! Return the number of node combination scores during the last search.
   size_t Scores() const { return scores; }
 
-  //! Access the search mode.
-  NeighborSearchMode SearchMode() const { return searchMode; }
-  //! Modify the search mode.
-  void SetSearchMode(const NeighborSearchMode mode);
-
   //! Access whether or not search is done in naive linear scan mode.
   bool Naive() const { return naive; }
   //! Modify whether or not search is done in naive linear scan mode.
@@ -429,6 +424,12 @@ class NeighborSearch
   //! Deprecated. Will be removed in mlpack 3.0.0.
   bool& SingleMode() { return singleMode; }
 
+  //! Access whether or not search is done in greedy mode.
+  bool Greedy() const { return greedy; }
+  //! Modify whether or not search is done in greedy mode.
+  //! Deprecated. Will be removed in mlpack 3.0.0.
+  bool& Greedy() { return greedy; }
+
   //! Access the relative error to be considered in approximate search.
   double Epsilon() const { return epsilon; }
   //! Modify the relative error to be considered in approximate search.
@@ -460,6 +461,8 @@ class NeighborSearch
   bool naive;
   //! Indicates if single-tree search is being used (as opposed to dual-tree).
   bool singleMode;
+  //! Indicates if greedy search is being used.
+  bool greedy;
   //! Indicates the relative error to be considered in approximate search.
   double epsilon;
 
@@ -475,11 +478,16 @@ class NeighborSearch
   //! Search() without a query set.
   bool treeNeedsReset;
 
-  //! Updates searchMode to be according to naive and singleMode booleans.
-  //! This is only necessary until the modifiers Naive() and SingleMode() are
-  //! removed in mlpack 3.0.0.
+  //! Updates searchMode to match the naive, singleMode and greedy booleans.
+  //! This is only necessary until the modifiers Naive(), SingleMode() and
+  //! Greedy() are removed in mlpack 3.0.0.
   void UpdateSearchMode();
 
+  //! Updates naive, singleMode and greedy flags according to searchMode.  This
+  //! is only necessary until the modifiers Naive(), SingleMode() and Greedy()
+  //! are removed in mlpack 3.0.0.
+  void UpdateSearchModeFlags();
+
   //! The NSModel class should have access to internal members.
   template<typename SortPol>
   friend class TrainVisitor;
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 7c3de94..ac06418 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -84,13 +84,16 @@ SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn,
         &referenceTree->Dataset()),
     treeOwner(mode != NAIVE_MODE),
     setOwner(false),
+    searchMode(mode),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
-  SetSearchMode(mode);
+  // Update naive, singleMode and greedy flags according to searchMode.
+  UpdateSearchModeFlags();
+
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
 }
@@ -116,13 +119,16 @@ SingleTreeTraversalType>::NeighborSearch(MatType&& referenceSetIn,
         &referenceTree->Dataset()),
     treeOwner(mode != NAIVE_MODE),
     setOwner(mode == NAIVE_MODE),
+    searchMode(mode),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
-  SetSearchMode(mode);
+  // Update naive, singleMode and greedy flags according to searchMode.
+  UpdateSearchModeFlags();
+
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
 }
@@ -145,13 +151,16 @@ SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
     referenceSet(&referenceTree->Dataset()),
     treeOwner(false),
     setOwner(false),
+    searchMode(mode),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
-  SetSearchMode(mode);
+  // Update naive, singleMode and greedy flags according to searchMode.
+  UpdateSearchModeFlags();
+
   if (mode == NAIVE_MODE)
     throw std::invalid_argument("invalid constructor for naive mode");
   if (epsilon < 0)
@@ -175,15 +184,19 @@ SingleTreeTraversalType>::NeighborSearch(const NeighborSearchMode mode,
     referenceSet(new MatType()), // Empty matrix.
     treeOwner(false),
     setOwner(true),
+    searchMode(mode),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
-  SetSearchMode(mode);
+  // Update naive, singleMode and greedy flags according to searchMode.
+  UpdateSearchModeFlags();
+
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
+
   // Build the tree on the empty dataset, if necessary.
   if (mode != NAIVE_MODE)
   {
@@ -215,14 +228,14 @@ SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn,
     setOwner(false),
     naive(naive),
     singleMode(!naive && singleMode), // No single mode if naive.
+    greedy(false),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
-  // Update Search Mode according to naive and singleMode flags.
-  searchMode = NAIVE_MODE;
+  // Update searchMode according to naive, singleMode and greedy flags.
   UpdateSearchMode();
 
   if (epsilon < 0)
@@ -253,14 +266,14 @@ SingleTreeTraversalType>::NeighborSearch(MatType&& referenceSetIn,
     setOwner(naive),
     naive(naive),
     singleMode(!naive && singleMode),
+    greedy(false),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
-  // Update Search Mode according to naive and singleMode flags.
-  searchMode = NAIVE_MODE;
+  // Update searchMode according to naive, singleMode and greedy flags.
   UpdateSearchMode();
 
   if (epsilon < 0)
@@ -287,14 +300,14 @@ SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
     setOwner(false),
     naive(false),
     singleMode(singleMode),
+    greedy(false),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
-  // Update Search Mode according to naive and singleMode flags.
-  searchMode = NAIVE_MODE;
+  // Update searchMode according to naive, singleMode and greedy flags.
   UpdateSearchMode();
 
   if (epsilon < 0)
@@ -321,18 +334,19 @@ SingleTreeTraversalType>::NeighborSearch(const bool naive,
     setOwner(true),
     naive(naive),
     singleMode(singleMode),
+    greedy(false),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
-  // Update Search Mode according to naive and singleMode flags.
-  searchMode = NAIVE_MODE;
+  // Update searchMode according to naive, singleMode and greedy flags.
   UpdateSearchMode();
 
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
+
   // Build the tree on the empty dataset, if necessary.
   if (!naive)
   {
@@ -372,7 +386,7 @@ void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
 DualTreeTraversalType, SingleTreeTraversalType>::Train(
     const MatType& referenceSet)
 {
-  // Update Search Mode.
+  // Update searchMode.
   UpdateSearchMode();
 
   // Clean up the old tree, if we built one.
@@ -413,7 +427,7 @@ template<typename SortPolicy,
 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
 DualTreeTraversalType, SingleTreeTraversalType>::Train(MatType&& referenceSetIn)
 {
-  // Update Search Mode.
+  // Update searchMode.
   UpdateSearchMode();
 
   // Clean up the old tree, if we built one.
@@ -459,7 +473,7 @@ template<typename SortPolicy,
 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
 DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree* referenceTree)
 {
-  // Update Search Mode.
+  // Update searchMode.
   UpdateSearchMode();
 
   if (searchMode == NAIVE_MODE)
@@ -496,7 +510,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
     arma::Mat<size_t>& neighbors,
     arma::mat& distances)
 {
-  // Update Search Mode.
+  // Update searchMode.
   UpdateSearchMode();
 
   if (k > referenceSet->n_cols)
@@ -713,7 +727,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
     arma::mat& distances,
     bool sameSet)
 {
-  // Update Search Mode.
+  // Update searchMode.
   UpdateSearchMode();
 
   if (k > referenceSet->n_cols)
@@ -797,7 +811,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
     arma::Mat<size_t>& neighbors,
     arma::mat& distances)
 {
-  // Update Search Mode.
+  // Update searchMode.
   UpdateSearchMode();
 
   if (k > referenceSet->n_cols)
@@ -1046,7 +1060,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Serialize(
 {
   using data::CreateNVP;
 
-  // Update Search Mode.
+  // Update searchMode.
   UpdateSearchMode();
 
   // Serialize preferences for search.
@@ -1118,7 +1132,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::Serialize(
   }
 }
 
-//! Set the Search Mode.
+//! Updates naive, singleMode and greedy flags according to searchMode.  This is
+//! only necessary until the modifiers Naive(), SingleMode() and Greedy() are
+//! removed in mlpack 3.0.0.
 template<typename SortPolicy,
          typename MetricType,
          typename MatType,
@@ -1128,33 +1144,36 @@ template<typename SortPolicy,
          template<typename> class DualTreeTraversalType,
          template<typename> class SingleTreeTraversalType>
 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
-DualTreeTraversalType, SingleTreeTraversalType>::SetSearchMode(
-    const NeighborSearchMode mode)
+DualTreeTraversalType, SingleTreeTraversalType>::UpdateSearchModeFlags()
 {
-  searchMode = mode;
-  switch (mode)
+  switch (searchMode)
   {
     case NAIVE_MODE:
       naive = true;
+      singleMode = false;
+      greedy = false;
       break;
     case SINGLE_TREE_MODE:
       naive = false;
       singleMode = true;
+      greedy = false;
       break;
     case DUAL_TREE_MODE:
       naive = false;
       singleMode = false;
+      greedy = false;
       break;
     case GREEDY_SINGLE_TREE_MODE:
       naive = false;
       singleMode = true;
+      greedy = true;
       break;
   }
 }
 
-//! Updates searchMode to be according to naive and singleMode booleans.
-//! This is only necessary until the modifiers Naive() and SingleMode() are
-//! removed in mlpack 3.0.0.
+//! Updates searchMode to match the naive, singleMode and greedy booleans.
+//! This is only necessary until the modifiers Naive(), SingleMode() and
+//! Greedy() are removed in mlpack 3.0.0.
 template<typename SortPolicy,
          typename MetricType,
          typename MatType,
@@ -1168,9 +1187,11 @@ DualTreeTraversalType, SingleTreeTraversalType>::UpdateSearchMode()
 {
   if (naive)
     searchMode = NAIVE_MODE;
-  else if (singleMode && (searchMode != GREEDY_SINGLE_TREE_MODE))
+  else if (singleMode && greedy)
+    searchMode = GREEDY_SINGLE_TREE_MODE;
+  else if (singleMode)
     searchMode = SINGLE_TREE_MODE;
-  else if (!singleMode)
+  else
     searchMode = DUAL_TREE_MODE;
 }
 
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 50520db..1591730 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -213,8 +213,33 @@ template<typename NSType>
 void SetSearchModeVisitor::operator()(NSType* ns) const
 {
   if (ns)
-    return ns->SetSearchMode(searchMode);
-  throw std::runtime_error("no neighbor search model initialized");
+  {
+    switch (searchMode)
+    {
+      case NAIVE_MODE:
+        ns->Naive() = true;
+        ns->SingleMode() = false;
+        ns->Greedy() = false;
+        break;
+      case SINGLE_TREE_MODE:
+        ns->Naive() = false;
+        ns->SingleMode() = true;
+        ns->Greedy() = false;
+        break;
+      case DUAL_TREE_MODE:
+        ns->Naive() = false;
+        ns->SingleMode() = false;
+        ns->Greedy() = false;
+        break;
+      case GREEDY_SINGLE_TREE_MODE:
+        ns->Naive() = false;
+        ns->SingleMode() = true;
+        ns->Greedy() = true;
+        break;
+    }
+  }
+  else
+    throw std::runtime_error("no neighbor search model initialized");
 }
 
 //! Return the search mode.
@@ -222,8 +247,18 @@ template<typename NSType>
 NeighborSearchMode SearchModeVisitor::operator()(NSType* ns) const
 {
   if (ns)
-    return ns->SearchMode();
-  throw std::runtime_error("no neighbor search model initialized");
+  {
+    if (ns->Naive())
+      return NAIVE_MODE;
+    else if (ns->SingleMode() && ns->Greedy())
+      return GREEDY_SINGLE_TREE_MODE;
+    else if (ns->SingleMode())
+      return SINGLE_TREE_MODE;
+    else
+      return DUAL_TREE_MODE;
+  }
+  else
+    throw std::runtime_error("no neighbor search model initialized");
 }
 
 //! Expose the Epsilon method of the given NSType.




More information about the mlpack-git mailing list