[mlpack-git] master: Modify NSModel to use boost variant. (86a9852)

gitdub at mlpack.org gitdub at mlpack.org
Mon Jun 13 14:29:07 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/a7e8d3bac60d6aa2717f27e0d4f6c53ff20607f5...779556fed748819a18cc898d9a6f69900740ef23

>---------------------------------------------------------------

commit 86a9852f19daf58e4ae5a3d3ca745deee43d8d16
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Mon Jun 13 10:31:58 2016 -0300

    Modify NSModel to use boost variant.


>---------------------------------------------------------------

86a9852f19daf58e4ae5a3d3ca745deee43d8d16
 .../methods/neighbor_search/neighbor_search.hpp    |   5 +-
 src/mlpack/methods/neighbor_search/ns_model.hpp    | 153 ++++++-
 .../methods/neighbor_search/ns_model_impl.hpp      | 506 +++++++++------------
 3 files changed, 345 insertions(+), 319 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index e84b74c..bf62ae7 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -27,8 +27,7 @@ namespace neighbor /** Neighbor-search routines.  These include
                     * searches. */ {
 
 // Forward declaration.
-template<typename SortPolicy>
-class NSModel;
+class TrainVisitor;
 
 /**
  * The NeighborSearch class is a template class for performing distance-based
@@ -308,7 +307,7 @@ class NeighborSearch
   bool treeNeedsReset;
 
   //! The NSModel class should have access to internal members.
-  friend class NSModel<SortPolicy>;
+  friend class TrainVisitor;
 }; // class NeighborSearch
 
 } // namespace neighbor
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index 9c16199..458bf97 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -13,12 +13,24 @@
 #include <mlpack/core/tree/binary_space_tree.hpp>
 #include <mlpack/core/tree/cover_tree.hpp>
 #include <mlpack/core/tree/rectangle_tree.hpp>
-
+#include <boost/variant.hpp>
 #include "neighbor_search.hpp"
 
 namespace mlpack {
 namespace neighbor {
 
+/**
+ * Convenience alias for the NeighborSearch instantiation used by NSModel:
+ * Euclidean distance over arma::mat data, the given sort policy and tree
+ * type, and that tree type's dual-tree traverser.
+ */
+template<typename SortPolicy,
+         template<typename MetricT,
+                  typename StatT,
+                  typename MatT> class TreeType>
+using NSType = NeighborSearch<
+    SortPolicy,
+    metric::EuclideanDistance,
+    arma::mat,
+    TreeType,
+    TreeType<metric::EuclideanDistance,
+             NeighborSearchStat<SortPolicy>,
+             arma::mat>::template DualTreeTraverser>;
+
 template<typename SortPolicy>
 struct NSModelName
 {
@@ -37,6 +49,121 @@ struct NSModelName<FurthestNeighborSort>
   static const std::string Name() { return "furthest_neighbor_search_model"; }
 };
 
+/**
+ * Visitor that calls Search(k, neighbors, distances) on whichever
+ * NeighborSearch object the model's variant holds.  No query set is passed;
+ * results are written into the referenced output matrices.
+ */
+class SearchKVisitor : public boost::static_visitor<void>
+{
+ public:
+  //! Store k and references to the output matrices.
+  SearchKVisitor(const size_t k,
+                 arma::Mat<size_t>& neighbors,
+                 arma::mat& distances);
+
+  //! Run the search on the given model; throws if the pointer is null.
+  template<typename ModelType>
+  void operator()(ModelType *ns) const;
+
+ private:
+  //! Number of neighbors to find.
+  const size_t k;
+  //! Output matrix of neighbor indices.
+  arma::Mat<size_t>& neighbors;
+  //! Output matrix of neighbor distances.
+  arma::mat& distances;
+};
+
+/**
+ * Visitor that searches for the k neighbors of every point in an explicit
+ * query set, writing results into the referenced output matrices.
+ *
+ * kd-tree and ball tree models have dedicated overloads that go through
+ * SearchLeaf(): unless the model is naive or in single-tree mode, a query
+ * tree is built and the results are mapped back to the original query order.
+ *
+ * NOTE(review): the query set is held by const reference, so a caller that
+ * owns an rvalue query set (NSModel::Search(arma::mat&&, ...)) cannot move
+ * it into the query tree; the tree is built from a copy.
+ */
+class SearchVisitor : public boost::static_visitor<void>
+{
+ public:
+  //! Store the query set, k, the leaf size, and the output matrices.
+  SearchVisitor(const arma::mat& querySet,
+                const size_t k,
+                arma::Mat<size_t>& neighbors,
+                arma::mat& distances,
+                const size_t leafSize);
+
+  //! Any other tree type: search the query set directly.
+  template<typename SortPolicy,
+           template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+  void operator()(NSType<SortPolicy, TreeType> *ns) const;
+
+  //! kd-tree models: dispatch to SearchLeaf().
+  template<typename SortPolicy>
+  void operator()(NSType<SortPolicy, tree::KDTree> *ns) const;
+
+  //! Ball tree models: dispatch to SearchLeaf().
+  template<typename SortPolicy>
+  void operator()(NSType<SortPolicy, tree::BallTree> *ns) const;
+
+ private:
+  //! Points whose neighbors are searched for.
+  const arma::mat& querySet;
+  //! Number of neighbors to find for each query point.
+  const size_t k;
+  //! Output matrix of neighbor indices.
+  arma::Mat<size_t>& neighbors;
+  //! Output matrix of neighbor distances.
+  arma::mat& distances;
+  //! Leaf size passed to the query tree constructor.
+  const size_t leafSize;
+
+  //! Shared search logic for trees whose construction reorders the points.
+  template<typename ModelType>
+  void SearchLeaf(ModelType *ns) const;
+};
+
+/**
+ * Visitor that trains the held NeighborSearch model on a reference set.
+ *
+ * For non-naive kd-tree and ball tree models the reference tree is built
+ * here, and ownership of it (and of the point mapping) is handed to the
+ * model.  That is why TrainVisitor is a friend of NeighborSearch.
+ */
+class TrainVisitor : public boost::static_visitor<void>
+{
+ public:
+  //! Keep a reference to the (to-be-moved) reference set and the leaf size.
+  TrainVisitor(arma::mat&& referenceSet, const size_t leafSize);
+
+  //! Any other tree type: pass the reference set to Train().
+  template<typename SortPolicy,
+           template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+  void operator()(NSType<SortPolicy, TreeType> *ns) const;
+
+  //! kd-tree models: dispatch to TrainLeaf().
+  template<typename SortPolicy>
+  void operator()(NSType<SortPolicy, tree::KDTree> *ns) const;
+
+  //! Ball tree models: dispatch to TrainLeaf().
+  template<typename SortPolicy>
+  void operator()(NSType<SortPolicy, tree::BallTree> *ns) const;
+
+ private:
+  //! Reference set; it is moved from during training.
+  arma::mat&& referenceSet;
+  //! Leaf size passed to the reference tree constructor.
+  size_t leafSize;
+
+  //! Shared training logic for trees whose construction reorders the points.
+  template<typename ModelType>
+  void TrainLeaf(ModelType* ns) const;
+};
+
+/**
+ * Visitor returning a mutable reference to the single-tree mode flag of the
+ * held NeighborSearch model.
+ */
+class SingleModeVisitor : public boost::static_visitor<bool&>
+{
+ public:
+  //! Return ns->SingleMode(); throws std::runtime_error if ns is null.
+  template<typename ModelType>
+  bool& operator()(ModelType *ns) const;
+};
+
+/**
+ * Visitor returning a mutable reference to the naive (brute-force) flag of
+ * the held NeighborSearch model.
+ */
+class NaiveVisitor : public boost::static_visitor<bool&>
+{
+ public:
+  //! Return ns->Naive(); throws std::runtime_error if ns is null.
+  template<typename ModelType>
+  bool& operator()(ModelType *ns) const;
+};
+
+/**
+ * Visitor returning a const reference to the reference set of the held
+ * NeighborSearch model.
+ */
+class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
+{
+ public:
+  //! Return ns->ReferenceSet(); throws std::runtime_error if ns is null.
+  template<typename ModelType>
+  const arma::mat& operator()(ModelType *ns) const;
+};
+
+/**
+ * Visitor that deletes the held NeighborSearch model, if any.
+ *
+ * The pointer is received by value, so the variant keeps the (now dangling)
+ * pointer until it is reassigned.
+ */
+class DeleteVisitor : public boost::static_visitor<void>
+{
+ public:
+  //! Delete the model pointed to by ns (a null pointer is ignored).
+  template<typename ModelType>
+  void operator()(ModelType *ns) const;
+};
+
+/**
+ * Visitor that serializes the held NeighborSearch model under a given name.
+ *
+ * NOTE(review): operator() receives a copy of the pointer stored in the
+ * variant.  When loading, anything the archive assigns to that pointer is not
+ * written back into the variant -- confirm, and consider taking a pointer
+ * reference instead (declaration and definition must change together).
+ */
+template<typename Archive>
+class SerializeVisitor : public boost::static_visitor<void>
+{
+ public:
+  //! Store references to the archive and to the object name.
+  SerializeVisitor(Archive& ar, const std::string& name);
+
+  //! Serialize the model pointed to by ns.
+  template<typename ModelType>
+  void operator()(ModelType *ns) const;
+
+ private:
+  //! Archive being written to or read from.
+  Archive& ar;
+  //! Name under which the model is stored in the archive.
+  const std::string& name;
+};
+
 template<typename SortPolicy>
 class NSModel
 {
@@ -59,24 +186,12 @@ class NSModel
   bool randomBasis;
   arma::mat q;
 
-  template<template<typename TreeMetricType,
-                    typename TreeStatType,
-                    typename TreeMatType> class TreeType>
-  using NSType = NeighborSearch<SortPolicy,
-                                metric::EuclideanDistance,
-                                arma::mat,
-                                TreeType,
-                                TreeType<metric::EuclideanDistance,
-                                    NeighborSearchStat<SortPolicy>,
-                                    arma::mat>::template DualTreeTraverser>;
-
-  // Only one of these pointers will be non-NULL.
-  NSType<tree::KDTree>* kdTreeNS;
-  NSType<tree::StandardCoverTree>* coverTreeNS;
-  NSType<tree::RTree>* rTreeNS;
-  NSType<tree::RStarTree>* rStarTreeNS;
-  NSType<tree::BallTree>* ballTreeNS;
-  NSType<tree::XTree>* xTreeNS;
+  boost::variant<NSType<SortPolicy, tree::KDTree>*,
+                 NSType<SortPolicy, tree::StandardCoverTree>*,
+                 NSType<SortPolicy, tree::RTree>*,
+                 NSType<SortPolicy, tree::RStarTree>*,
+                 NSType<SortPolicy, tree::BallTree>*,
+                 NSType<SortPolicy, tree::XTree>*> nSearch;
 
  public:
   /**
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index e4aa7c1..ea2206e 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -16,6 +16,190 @@
 namespace mlpack {
 namespace neighbor {
 
+/**
+ * Construct a SearchKVisitor that stores k and references to the output
+ * matrices.  This non-template function is defined in a header, so it must
+ * be inline; otherwise every translation unit that includes the header emits
+ * its own definition and linking fails with multiple definitions.
+ */
+inline SearchKVisitor::SearchKVisitor(const size_t k,
+                                      arma::Mat<size_t>& neighbors,
+                                      arma::mat& distances) :
+    k(k),
+    neighbors(neighbors),
+    distances(distances)
+{}
+
+//! Run Search(k, neighbors, distances) on the held model.
+template<typename NSType>
+void SearchKVisitor::operator()(NSType *ns) const
+{
+  // A null pointer means the variant holds no model.
+  if (!ns)
+    throw std::runtime_error("no neighbor search model initialized");
+
+  ns->Search(k, neighbors, distances);
+}
+
+
+/**
+ * Construct a SearchVisitor that stores the query set, k, the output matrices
+ * and the leaf size used for query trees.  This non-template function is
+ * defined in a header, so it must be inline to avoid multiple-definition link
+ * errors when the header is included from several translation units.
+ */
+inline SearchVisitor::SearchVisitor(const arma::mat& querySet,
+                                    const size_t k,
+                                    arma::Mat<size_t>& neighbors,
+                                    arma::mat& distances,
+                                    const size_t leafSize) :
+    querySet(querySet),
+    k(k),
+    neighbors(neighbors),
+    distances(distances),
+    leafSize(leafSize)
+{}
+
+//! Tree types without a dedicated overload search the query set directly.
+template<typename SortPolicy,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+void SearchVisitor::operator()(NSType<SortPolicy, TreeType> *ns) const
+{
+  if (!ns)
+    throw std::runtime_error("no neighbor search model initialized");
+
+  ns->Search(querySet, k, neighbors, distances);
+}
+
+//! kd-tree models may need a query tree; delegate to SearchLeaf().
+template<typename SortPolicy>
+void SearchVisitor::operator()(NSType<SortPolicy, tree::KDTree> *ns) const
+{
+  if (!ns)
+    throw std::runtime_error("no neighbor search model initialized");
+
+  SearchLeaf(ns);
+}
+
+//! Ball tree models may need a query tree; delegate to SearchLeaf().
+template<typename SortPolicy>
+void SearchVisitor::operator()(NSType<SortPolicy, tree::BallTree> *ns) const
+{
+  if (!ns)
+    throw std::runtime_error("no neighbor search model initialized");
+
+  SearchLeaf(ns);
+}
+
+/**
+ * Search for trees whose construction reorders the points (kd-tree, ball
+ * tree).  Naive and single-tree models search the query set directly.
+ * Otherwise a query tree is built, a dual-tree search is run, and each
+ * result column is moved back to the position of its original query point.
+ *
+ * Tree building is timed under "tree_building" and logged, as the
+ * pre-visitor implementation of NSModel::Search() did.
+ */
+template<typename NSType>
+void SearchVisitor::SearchLeaf(NSType *ns) const
+{
+  if (ns->Naive() || ns->SingleMode())
+  {
+    // Search without building a second tree.
+    ns->Search(querySet, k, neighbors, distances);
+    return;
+  }
+
+  // Build a query tree.  querySet is a const reference, so the tree is built
+  // from a copy (std::move() on a const reference would not change that).
+  Timer::Start("tree_building");
+  Log::Info << "Building query tree..." << std::endl;
+  std::vector<size_t> oldFromNewQueries;
+  typename NSType::Tree queryTree(querySet, oldFromNewQueries, leafSize);
+  Log::Info << "Tree built." << std::endl;
+  Timer::Stop("tree_building");
+
+  arma::Mat<size_t> neighborsOut;
+  arma::mat distancesOut;
+  ns->Search(&queryTree, k, neighborsOut, distancesOut);
+
+  // Unmap the query points: column i of the output belongs to original query
+  // point oldFromNewQueries[i].
+  distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
+  neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
+  for (size_t i = 0; i < neighborsOut.n_cols; ++i)
+  {
+    neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
+    distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
+  }
+}
+
+
+/**
+ * Construct a TrainVisitor that binds to the caller's reference set (which
+ * is moved from during training) and stores the leaf size.  This non-template
+ * function is defined in a header, so it must be inline to avoid
+ * multiple-definition link errors across translation units.
+ */
+inline TrainVisitor::TrainVisitor(arma::mat&& referenceSet,
+                                  const size_t leafSize) :
+    referenceSet(std::move(referenceSet)),
+    leafSize(leafSize)
+{}
+
+//! Tree types without a dedicated overload are trained on the raw data.
+template<typename SortPolicy,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+void TrainVisitor::operator()(NSType<SortPolicy, TreeType> *ns) const
+{
+  if (!ns)
+    throw std::runtime_error("no neighbor search model initialized");
+
+  ns->Train(std::move(referenceSet));
+}
+
+//! kd-tree models: delegate to TrainLeaf().
+template<typename SortPolicy>
+void TrainVisitor::operator()(NSType<SortPolicy, tree::KDTree> *ns) const
+{
+  if (!ns)
+    throw std::runtime_error("no neighbor search model initialized");
+
+  TrainLeaf(ns);
+}
+
+//! Ball tree models: delegate to TrainLeaf().
+template<typename SortPolicy>
+void TrainVisitor::operator()(NSType<SortPolicy, tree::BallTree> *ns) const
+{
+  if (!ns)
+    throw std::runtime_error("no neighbor search model initialized");
+
+  TrainLeaf(ns);
+}
+
+/**
+ * Train a model whose tree construction reorders the points.  Naive models
+ * receive the reference set directly.  Otherwise the reference tree is built
+ * here, and the model takes ownership of the tree and of the
+ * old-from-new point mapping.
+ */
+template<typename NSType>
+void TrainVisitor::TrainLeaf(NSType* ns) const
+{
+  if (ns->Naive())
+  {
+    ns->Train(std::move(referenceSet));
+    return;
+  }
+
+  typedef typename NSType::Tree ReferenceTree;
+  std::vector<size_t> oldFromNewReferences;
+  ReferenceTree* referenceTree = new ReferenceTree(std::move(referenceSet),
+      oldFromNewReferences, leafSize);
+  ns->Train(referenceTree);
+
+  // Give the model ownership of the tree and the mappings.
+  ns->treeOwner = true;
+  ns->oldFromNewReferences = std::move(oldFromNewReferences);
+}
+
+
+//! Return a mutable reference to the held model's single-tree mode flag.
+template<typename NSType>
+bool& SingleModeVisitor::operator()(NSType *ns) const
+{
+  if (!ns)
+    throw std::runtime_error("no neighbor search model initialized");
+
+  return ns->SingleMode();
+}
+
+
+//! Return a mutable reference to the held model's naive-search flag.
+template<typename NSType>
+bool& NaiveVisitor::operator()(NSType *ns) const
+{
+  if (!ns)
+    throw std::runtime_error("no neighbor search model initialized");
+
+  return ns->Naive();
+}
+
+
+//! Return the held model's reference set.
+template<typename NSType>
+const arma::mat& ReferenceSetVisitor::operator()(NSType *ns) const
+{
+  if (!ns)
+    throw std::runtime_error("no neighbor search model initialized");
+
+  return ns->ReferenceSet();
+}
+
+
+//! Delete the held model.  `ns` is a copy of the pointer in the variant, so
+//! the variant still holds the old (dangling) value until it is reassigned.
+template<typename NSType>
+void DeleteVisitor::operator()(NSType *ns) const
+{
+  // Deleting a null pointer is a no-op, so no explicit check is needed.
+  delete ns;
+}
+
+
+//! Store references to the archive and to the name used for the model.
+template<typename Archive>
+SerializeVisitor<Archive>::SerializeVisitor(Archive& archive,
+                                            const std::string& objectName) :
+    ar(archive),
+    name(objectName)
+{ }
+
+//! Serialize the held model to/from the archive under the stored name.
+template<typename Archive>
+template<typename NSType>
+void SerializeVisitor<Archive>::operator()(NSType *ns) const
+{
+  // NOTE(review): `ns` is a by-value copy of the pointer stored in the
+  // variant.  If loading assigns a newly allocated object to it, that
+  // assignment never reaches the variant -- confirm, and consider taking
+  // `NSType*& ns` (the declaration in SerializeVisitor must change too).
+  ar & data::CreateNVP(ns, name);
+}
+
 /**
  * Initialize the NSModel with the given type and whether or not a random
  * basis should be used.
@@ -23,13 +207,7 @@ namespace neighbor {
 template<typename SortPolicy>
 NSModel<SortPolicy>::NSModel(TreeTypes treeType, bool randomBasis) :
     treeType(treeType),
-    randomBasis(randomBasis),
-    kdTreeNS(NULL),
-    coverTreeNS(NULL),
-    rTreeNS(NULL),
-    rStarTreeNS(NULL),
-    ballTreeNS(NULL),
-    xTreeNS(NULL)
+    randomBasis(randomBasis)
 {
   // Nothing to do.
 }
@@ -38,18 +216,7 @@ NSModel<SortPolicy>::NSModel(TreeTypes treeType, bool randomBasis) :
 template<typename SortPolicy>
 NSModel<SortPolicy>::~NSModel()
 {
-  if (kdTreeNS)
-    delete kdTreeNS;
-  if (coverTreeNS)
-    delete coverTreeNS;
-  if (rTreeNS)
-    delete rTreeNS;
-  if (rStarTreeNS)
-    delete rStarTreeNS;
-  if (ballTreeNS)
-    delete ballTreeNS;
-  if (xTreeNS)
-    delete xTreeNS;
+  boost::apply_visitor(DeleteVisitor(), nSearch);
 }
 
 //! Serialize the kNN model.
@@ -64,148 +231,43 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
 
   // This should never happen, but just in case, be clean with memory.
   if (Archive::is_loading::value)
-  {
-    if (kdTreeNS)
-      delete kdTreeNS;
-    if (coverTreeNS)
-      delete coverTreeNS;
-    if (rTreeNS)
-      delete rTreeNS;
-    if (rStarTreeNS)
-      delete rStarTreeNS;
-    if (ballTreeNS)
-      delete ballTreeNS;
-    if (xTreeNS)
-      delete xTreeNS;
-
-    // Set all the pointers to NULL.
-    kdTreeNS = NULL;
-    coverTreeNS = NULL;
-    rTreeNS = NULL;
-    rStarTreeNS = NULL;
-    ballTreeNS = NULL;
-    xTreeNS = NULL;
-  }
+    boost::apply_visitor(DeleteVisitor(), nSearch);
 
   // We'll only need to serialize one of the kNN objects, based on the type.
   const std::string& name = NSModelName<SortPolicy>::Name();
-  switch (treeType)
-  {
-    case KD_TREE:
-      ar & data::CreateNVP(kdTreeNS, name);
-      break;
-    case COVER_TREE:
-      ar & data::CreateNVP(coverTreeNS, name);
-      break;
-    case R_TREE:
-      ar & data::CreateNVP(rTreeNS, name);
-      break;
-    case R_STAR_TREE:
-      ar & data::CreateNVP(rStarTreeNS, name);
-      break;
-    case BALL_TREE:
-      ar & data::CreateNVP(ballTreeNS, name);
-      break;
-    case X_TREE:
-      ar & data::CreateNVP(xTreeNS, name);
-      break;
-  }
+  SerializeVisitor<Archive> s(ar, name);
+  boost::apply_visitor(s, nSearch);
 }
 
 template<typename SortPolicy>
 const arma::mat& NSModel<SortPolicy>::Dataset() const
 {
-  if (kdTreeNS)
-    return kdTreeNS->ReferenceSet();
-  else if (coverTreeNS)
-    return coverTreeNS->ReferenceSet();
-  else if (rTreeNS)
-    return rTreeNS->ReferenceSet();
-  else if (rStarTreeNS)
-    return rStarTreeNS->ReferenceSet();
-  else if (ballTreeNS)
-    return ballTreeNS->ReferenceSet();
-  else if (xTreeNS)
-    return xTreeNS->ReferenceSet();
-
-  throw std::runtime_error("no neighbor search model initialized");
+  return boost::apply_visitor(ReferenceSetVisitor(), nSearch);
 }
 
 //! Expose singleMode.
 template<typename SortPolicy>
 bool NSModel<SortPolicy>::SingleMode() const
 {
-  if (kdTreeNS)
-    return kdTreeNS->SingleMode();
-  else if (coverTreeNS)
-    return coverTreeNS->SingleMode();
-  else if (rTreeNS)
-    return rTreeNS->SingleMode();
-  else if (rStarTreeNS)
-    return rStarTreeNS->SingleMode();
-  else if (ballTreeNS)
-    return ballTreeNS->SingleMode();
-  else if (xTreeNS)
-    return xTreeNS->SingleMode();
-
-  throw std::runtime_error("no neighbor search model initialized");
+  return boost::apply_visitor(SingleModeVisitor(), nSearch);
 }
 
 template<typename SortPolicy>
 bool& NSModel<SortPolicy>::SingleMode()
 {
-  if (kdTreeNS)
-    return kdTreeNS->SingleMode();
-  else if (coverTreeNS)
-    return coverTreeNS->SingleMode();
-  else if (rTreeNS)
-    return rTreeNS->SingleMode();
-  else if (rStarTreeNS)
-    return rStarTreeNS->SingleMode();
-  else if (ballTreeNS)
-    return ballTreeNS->SingleMode();
-  else if (xTreeNS)
-    return xTreeNS->SingleMode();
-
-  throw std::runtime_error("no neighbor search model initialized");
+  return boost::apply_visitor(SingleModeVisitor(), nSearch);
 }
 
 template<typename SortPolicy>
 bool NSModel<SortPolicy>::Naive() const
 {
-  if (kdTreeNS)
-    return kdTreeNS->Naive();
-  else if (coverTreeNS)
-    return coverTreeNS->Naive();
-  else if (rTreeNS)
-    return rTreeNS->Naive();
-  else if (rStarTreeNS)
-    return rStarTreeNS->Naive();
-  else if (ballTreeNS)
-    return ballTreeNS->Naive();
-  else if (xTreeNS)
-    return xTreeNS->Naive();
-
-  throw std::runtime_error("no neighbor search model initialized");
+  return boost::apply_visitor(NaiveVisitor(), nSearch);
 }
 
 template<typename SortPolicy>
 bool& NSModel<SortPolicy>::Naive()
 {
-  if (kdTreeNS)
-    return kdTreeNS->Naive();
-  else if (coverTreeNS)
-    return coverTreeNS->Naive();
-  else if (rTreeNS)
-    return rTreeNS->Naive();
-  else if (rStarTreeNS)
-    return rStarTreeNS->Naive();
-  else if (ballTreeNS)
-    return ballTreeNS->Naive();
-  else if (xTreeNS)
-    return xTreeNS->Naive();
-
-  throw std::runtime_error("no neighbor search model initialized");
+  return boost::apply_visitor(NaiveVisitor(), nSearch);
 }
 
 //! Build the reference tree.
@@ -248,18 +310,7 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
   }
 
   // Clean memory, if necessary.
-  if (kdTreeNS)
-    delete kdTreeNS;
-  if (coverTreeNS)
-    delete coverTreeNS;
-  if (rTreeNS)
-    delete rTreeNS;
-  if (rStarTreeNS)
-    delete rStarTreeNS;
-  if (ballTreeNS)
-    delete ballTreeNS;
-  if (xTreeNS)
-    delete xTreeNS;
+  boost::apply_visitor(DeleteVisitor(), nSearch);
 
   // Do we need to modify the reference set?
   if (randomBasis)
@@ -274,69 +325,29 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
   switch (treeType)
   {
     case KD_TREE:
-      // If necessary, build the kd-tree.
-      if (naive)
-      {
-        kdTreeNS = new NSType<tree::KDTree>(std::move(referenceSet), naive,
-            singleMode);
-      }
-      else
-      {
-        std::vector<size_t> oldFromNewReferences;
-        typename NSType<tree::KDTree>::Tree* kdTree =
-            new typename NSType<tree::KDTree>::Tree(std::move(referenceSet),
-            oldFromNewReferences, leafSize);
-        kdTreeNS = new NSType<tree::KDTree>(kdTree, singleMode);
-
-        // Give the model ownership of the tree and the mappings.
-        kdTreeNS->treeOwner = true;
-        kdTreeNS->oldFromNewReferences = std::move(oldFromNewReferences);
-      }
-
+      nSearch = new NSType<SortPolicy, tree::KDTree>(naive, singleMode);
       break;
     case COVER_TREE:
-      // If necessary, build the cover tree.
-      coverTreeNS = new NSType<tree::StandardCoverTree>(std::move(referenceSet),
-          naive, singleMode);
+      nSearch = new NSType<SortPolicy, tree::StandardCoverTree>(naive,
+          singleMode);
       break;
     case R_TREE:
-      // If necessary, build the R tree.
-      rTreeNS = new NSType<tree::RTree>(std::move(referenceSet), naive,
-          singleMode);
+      nSearch = new NSType<SortPolicy, tree::RTree>(naive, singleMode);
       break;
     case R_STAR_TREE:
-      // If necessary, build the R* tree.
-      rStarTreeNS = new NSType<tree::RStarTree>(std::move(referenceSet), naive,
-          singleMode);
+      nSearch = new NSType<SortPolicy, tree::RStarTree>(naive, singleMode);
       break;
     case BALL_TREE:
-      // If necessary, build the ball tree.
-      if (naive)
-      {
-        ballTreeNS = new NSType<tree::BallTree>(std::move(referenceSet), naive,
-            singleMode);
-      }
-      else
-      {
-        std::vector<size_t> oldFromNewReferences;
-        typename NSType<tree::BallTree>::Tree* ballTree =
-            new typename NSType<tree::BallTree>::Tree(std::move(referenceSet),
-            oldFromNewReferences, leafSize);
-        ballTreeNS = new NSType<tree::BallTree>(ballTree, singleMode);
-
-        // Give the model ownership of the tree and the mappings.
-        ballTreeNS->treeOwner = true;
-        ballTreeNS->oldFromNewReferences = std::move(oldFromNewReferences);
-      }
-
+      nSearch = new NSType<SortPolicy, tree::BallTree>(naive, singleMode);
       break;
     case X_TREE:
-      // If necessary, build the X tree.
-      xTreeNS = new NSType<tree::XTree>(std::move(referenceSet), naive,
-          singleMode);
+      nSearch = new NSType<SortPolicy, tree::XTree>(naive, singleMode);
       break;
   }
 
+  TrainVisitor tn(std::move(referenceSet),leafSize);
+  boost::apply_visitor(tn, nSearch);
+
   if (!naive)
   {
     Timer::Stop("tree_building");
@@ -363,88 +374,8 @@ void NSModel<SortPolicy>::Search(arma::mat&& querySet,
   else
     Log::Info << "brute-force (naive) search..." << std::endl;
 
-  switch (treeType)
-  {
-    case KD_TREE:
-      if (!kdTreeNS->Naive() && !kdTreeNS->SingleMode())
-      {
-        // Build a second tree and search.
-        Timer::Start("tree_building");
-        Log::Info << "Building query tree..." << std::endl;
-        std::vector<size_t> oldFromNewQueries;
-        typename NSType<tree::KDTree>::Tree queryTree(std::move(querySet),
-            oldFromNewQueries, leafSize);
-        Log::Info << "Tree built." << std::endl;
-        Timer::Stop("tree_building");
-
-        arma::Mat<size_t> neighborsOut;
-        arma::mat distancesOut;
-        kdTreeNS->Search(&queryTree, k, neighborsOut, distancesOut);
-
-        // Unmap the query points.
-        distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
-        neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
-        for (size_t i = 0; i < neighborsOut.n_cols; ++i)
-        {
-          neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
-          distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
-        }
-      }
-      else
-      {
-        // Search without building a second tree.
-        kdTreeNS->Search(querySet, k, neighbors, distances);
-      }
-      break;
-    case COVER_TREE:
-      // No mapping necessary.
-      coverTreeNS->Search(querySet, k, neighbors, distances);
-      break;
-    case R_TREE:
-      // No mapping necessary.
-      rTreeNS->Search(querySet, k, neighbors, distances);
-      break;
-    case R_STAR_TREE:
-      // No mapping necessary.
-      rStarTreeNS->Search(querySet, k, neighbors, distances);
-      break;
-    case BALL_TREE:
-      if (!ballTreeNS->Naive() && !ballTreeNS->SingleMode())
-      {
-        // Build a second tree and search.
-        Timer::Start("tree_building");
-        Log::Info << "Building query tree..." << std::endl;
-        std::vector<size_t> oldFromNewQueries;
-        typename NSType<tree::BallTree>::Tree queryTree(std::move(querySet),
-            oldFromNewQueries, leafSize);
-        Log::Info << "Tree built." << std::endl;
-        Timer::Stop("tree_building");
-
-        arma::Mat<size_t> neighborsOut;
-        arma::mat distancesOut;
-        ballTreeNS->Search(&queryTree, k, neighborsOut, distancesOut);
-
-        // Unmap the query points.
-        distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
-        neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
-        for (size_t i = 0; i < neighborsOut.n_cols; ++i)
-        {
-          neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
-          distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
-        }
-      }
-      else
-      {
-        // Search without building a second tree.
-        ballTreeNS->Search(querySet, k, neighbors, distances);
-      }
-
-      break;
-    case X_TREE:
-      // No mapping necessary.
-      xTreeNS->Search(querySet, k, neighbors, distances);
-      break;
-  }
+  SearchVisitor search(querySet, k, neighbors, distances, leafSize);
+  boost::apply_visitor(search, nSearch);
 }
 
 //! Perform neighbor search.
@@ -461,27 +392,8 @@ void NSModel<SortPolicy>::Search(const size_t k,
   else
     Log::Info << "brute-force (naive) search..." << std::endl;
 
-  switch (treeType)
-  {
-    case KD_TREE:
-      kdTreeNS->Search(k, neighbors, distances);
-      break;
-    case COVER_TREE:
-      coverTreeNS->Search(k, neighbors, distances);
-      break;
-    case R_TREE:
-      rTreeNS->Search(k, neighbors, distances);
-      break;
-    case R_STAR_TREE:
-      rStarTreeNS->Search(k, neighbors, distances);
-      break;
-    case BALL_TREE:
-      ballTreeNS->Search(k, neighbors, distances);
-      break;
-    case X_TREE:
-      xTreeNS->Search(k, neighbors, distances);
-      break;
-  }
+  SearchKVisitor search(k, neighbors, distances);
+  boost::apply_visitor(search, nSearch);
 }
 
 //! Get the name of the tree type.




More information about the mlpack-git mailing list