[mlpack-git] master: Set SortPolicy as template parameter of classes BiSearchVisitor and TrainVisitor. (ddb252c)

gitdub at mlpack.org gitdub at mlpack.org
Wed Jun 15 10:44:16 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/a7e8d3bac60d6aa2717f27e0d4f6c53ff20607f5...779556fed748819a18cc898d9a6f69900740ef23

>---------------------------------------------------------------

commit ddb252ce03520b699e8335f5261ecd08b32d5a60
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Wed Jun 15 11:44:16 2016 -0300

    Set SortPolicy as template parameter of classes BiSearchVisitor and TrainVisitor.
    
    Also, add a more specific alias template, NSTypeT, to avoid Visual C++ (MSVC) compiler errors.


>---------------------------------------------------------------

ddb252ce03520b699e8335f5261ecd08b32d5a60
 .../methods/neighbor_search/neighbor_search.hpp    |   3 +-
 src/mlpack/methods/neighbor_search/ns_model.hpp    | 124 +++++++++++----------
 .../methods/neighbor_search/ns_model_impl.hpp      |  47 ++++----
 3 files changed, 93 insertions(+), 81 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index bf62ae7..999f261 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -27,6 +27,7 @@ namespace neighbor /** Neighbor-search routines.  These include
                     * searches. */ {
 
 // Forward declaration.
+template<typename SortPolicy>
 class TrainVisitor;
 
 /**
@@ -307,7 +308,7 @@ class NeighborSearch
   bool treeNeedsReset;
 
   //! The NSModel class should have access to internal members.
-  friend class TrainVisitor;
+  friend class TrainVisitor<SortPolicy>;
 }; // class NeighborSearch
 
 } // namespace neighbor
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index f269e13..613ef45 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -52,116 +52,122 @@ struct NSModelName<FurthestNeighborSort>
 class MonoSearchVisitor : public boost::static_visitor<void>
 {
  private:
-   const size_t k;
-   arma::Mat<size_t>& neighbors;
-   arma::mat& distances;
+  const size_t k;
+  arma::Mat<size_t>& neighbors;
+  arma::mat& distances;
 
  public:
-   template<typename NSType>
-   void operator()(NSType *ns) const;
+  template<typename NSType>
+  void operator()(NSType *ns) const;
 
-   MonoSearchVisitor(const size_t k,
-                     arma::Mat<size_t>& neighbors,
-                     arma::mat& distances);
+  MonoSearchVisitor(const size_t k,
+                    arma::Mat<size_t>& neighbors,
+                    arma::mat& distances);
 };
 
+template<typename SortPolicy>
 class BiSearchVisitor : public boost::static_visitor<void>
 {
  private:
-   const arma::mat& querySet;
-   const size_t k;
-   arma::Mat<size_t>& neighbors;
-   arma::mat& distances;
-   const size_t leafSize;
+  const arma::mat& querySet;
+  const size_t k;
+  arma::Mat<size_t>& neighbors;
+  arma::mat& distances;
+  const size_t leafSize;
 
-   template<typename NSType>
-   void SearchLeaf(NSType *ns) const;
+  template<typename NSType>
+  void SearchLeaf(NSType* ns) const;
 
  public:
-   template<typename SortPolicy,
-            template<typename TreeMetricType,
-                     typename TreeStatType,
-                     typename TreeMatType> class TreeType>
-   void operator()(NSType<SortPolicy,TreeType> *ns) const;
-
-   template<typename SortPolicy>
-   void operator()(NSType<SortPolicy,tree::KDTree> *ns) const;
-
-   template<typename SortPolicy>
-   void operator()(NSType<SortPolicy,tree::BallTree> *ns) const;
-
-   BiSearchVisitor(const arma::mat& querySet,
-                   const size_t k,
-                   arma::Mat<size_t>& neighbors,
-                   arma::mat& distances,
-                   const size_t leafSize);
+  template<template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+  using NSTypeT = NSType<SortPolicy, TreeType>;
+
+  template<template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+  void operator()(NSTypeT<TreeType>* ns) const;
+
+  void operator()(NSTypeT<tree::KDTree>* ns) const;
+
+  void operator()(NSTypeT<tree::BallTree>* ns) const;
+
+  BiSearchVisitor(const arma::mat& querySet,
+                  const size_t k,
+                  arma::Mat<size_t>& neighbors,
+                  arma::mat& distances,
+                  const size_t leafSize);
 };
 
+template<typename SortPolicy>
 class TrainVisitor : public boost::static_visitor<void>
 {
  private:
-   arma::mat&& referenceSet;
-   size_t leafSize;
+  arma::mat&& referenceSet;
+  size_t leafSize;
 
-   template<typename NSType>
-   void TrainLeaf(NSType* ns) const;
+  template<typename NSType>
+  void TrainLeaf(NSType* ns) const;
 
  public:
-   template<typename SortPolicy,
-            template<typename TreeMetricType,
-                     typename TreeStatType,
-                     typename TreeMatType> class TreeType>
-   void operator()(NSType<SortPolicy,TreeType> *ns) const;
+  template<template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+  using NSTypeT = NSType<SortPolicy, TreeType>;
+
+  template<template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+  void operator()(NSTypeT<TreeType>* ns) const;
 
-   template<typename SortPolicy>
-   void operator()(NSType<SortPolicy,tree::KDTree> *ns) const;
+  void operator()(NSTypeT<tree::KDTree>* ns) const;
 
-   template<typename SortPolicy>
-   void operator()(NSType<SortPolicy,tree::BallTree> *ns) const;
+  void operator()(NSTypeT<tree::BallTree>* ns) const;
 
-   TrainVisitor(arma::mat&& referenceSet, const size_t leafSize);
+  TrainVisitor(arma::mat&& referenceSet, const size_t leafSize);
 };
 
 class SingleModeVisitor : public boost::static_visitor<bool&>
 {
  public:
-   template<typename NSType>
-   bool& operator()(NSType *ns) const;
+  template<typename NSType>
+  bool& operator()(NSType *ns) const;
 };
 
 class NaiveVisitor : public boost::static_visitor<bool&>
 {
  public:
-   template<typename NSType>
-   bool& operator()(NSType *ns) const;
+  template<typename NSType>
+  bool& operator()(NSType *ns) const;
 };
 
 class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
 {
  public:
-   template<typename NSType>
-   const arma::mat& operator()(NSType *ns) const;
+  template<typename NSType>
+  const arma::mat& operator()(NSType *ns) const;
 };
 
 class DeleteVisitor : public boost::static_visitor<void>
 {
  public:
-   template<typename NSType>
-   void operator()(NSType *ns) const;
+  template<typename NSType>
+  void operator()(NSType *ns) const;
 };
 
 template<typename Archive>
 class SerializeVisitor : public boost::static_visitor<void>
 {
  private:
-   Archive& ar;
-   const std::string& name;
+  Archive& ar;
+  const std::string& name;
 
  public:
-   template<typename NSType>
-   void operator()(NSType *ns) const;
+  template<typename NSType>
+  void operator()(NSType *ns) const;
 
-   SerializeVisitor(Archive& ar, const std::string& name);
+  SerializeVisitor(Archive& ar, const std::string& name);
 };
 
 template<typename SortPolicy>
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index ef67d18..bcded0e 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -32,12 +32,12 @@ void MonoSearchVisitor::operator()(NSType *ns) const
   throw std::runtime_error("no neighbor search model initialized");
 }
 
-
-BiSearchVisitor::BiSearchVisitor(const arma::mat& querySet,
-                                 const size_t k,
-                                 arma::Mat<size_t>& neighbors,
-                                 arma::mat& distances,
-                                 const size_t leafSize) :
+template<typename SortPolicy>
+BiSearchVisitor<SortPolicy>::BiSearchVisitor(const arma::mat& querySet,
+                                             const size_t k,
+                                             arma::Mat<size_t>& neighbors,
+                                             arma::mat& distances,
+                                             const size_t leafSize) :
     querySet(querySet),
     k(k),
     neighbors(neighbors),
@@ -45,11 +45,11 @@ BiSearchVisitor::BiSearchVisitor(const arma::mat& querySet,
     leafSize(leafSize)
 {}
 
-template<typename SortPolicy,
-         template<typename TreeMetricType,
+template<typename SortPolicy>
+template<template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType>
-void BiSearchVisitor::operator()(NSType<SortPolicy,TreeType> *ns) const
+void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<TreeType>* ns) const
 {
   if (ns)
     return ns->Search(querySet, k, neighbors, distances);
@@ -57,7 +57,7 @@ void BiSearchVisitor::operator()(NSType<SortPolicy,TreeType> *ns) const
 }
 
 template<typename SortPolicy>
-void BiSearchVisitor::operator()(NSType<SortPolicy,tree::KDTree> *ns) const
+void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::KDTree>* ns) const
 {
   if (ns)
     return SearchLeaf(ns);
@@ -65,15 +65,16 @@ void BiSearchVisitor::operator()(NSType<SortPolicy,tree::KDTree> *ns) const
 }
 
 template<typename SortPolicy>
-void BiSearchVisitor::operator()(NSType<SortPolicy,tree::BallTree> *ns) const
+void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
 {
   if (ns)
     return SearchLeaf(ns);
   throw std::runtime_error("no neighbor search model initialized");
 }
 
+template<typename SortPolicy>
 template<typename NSType>
-void BiSearchVisitor::SearchLeaf(NSType *ns) const
+void BiSearchVisitor<SortPolicy>::SearchLeaf(NSType *ns) const
 {
   if (!ns->Naive() && !ns->SingleMode())
   {
@@ -99,16 +100,18 @@ void BiSearchVisitor::SearchLeaf(NSType *ns) const
 }
 
 
-TrainVisitor::TrainVisitor(arma::mat&& referenceSet, const size_t leafSize) :
+template<typename SortPolicy>
+TrainVisitor<SortPolicy>::TrainVisitor(arma::mat&& referenceSet,
+                                       const size_t leafSize) :
     referenceSet(std::move(referenceSet)),
     leafSize(leafSize)
 {}
 
-template<typename SortPolicy,
-         template<typename TreeMetricType,
+template<typename SortPolicy>
+template<template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType>
-void TrainVisitor::operator()(NSType<SortPolicy,TreeType> *ns) const
+void TrainVisitor<SortPolicy>::operator()(NSTypeT<TreeType>* ns) const
 {
   if (ns)
     return ns->Train(std::move(referenceSet));
@@ -116,7 +119,7 @@ void TrainVisitor::operator()(NSType<SortPolicy,TreeType> *ns) const
 }
 
 template<typename SortPolicy>
-void TrainVisitor::operator ()(NSType<SortPolicy,tree::KDTree> *ns) const
+void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::KDTree>* ns) const
 {
   if (ns)
     return TrainLeaf(ns);
@@ -124,15 +127,16 @@ void TrainVisitor::operator ()(NSType<SortPolicy,tree::KDTree> *ns) const
 }
 
 template<typename SortPolicy>
-void TrainVisitor::operator ()(NSType<SortPolicy,tree::BallTree> *ns) const
+void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::BallTree>* ns) const
 {
   if (ns)
     return TrainLeaf(ns);
   throw std::runtime_error("no neighbor search model initialized");
 }
 
+template<typename SortPolicy>
 template<typename NSType>
-void TrainVisitor::TrainLeaf(NSType* ns) const
+void TrainVisitor<SortPolicy>::TrainLeaf(NSType* ns) const
 {
   if (ns->Naive())
     ns->Train(std::move(referenceSet));
@@ -345,7 +349,7 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
       break;
   }
 
-  TrainVisitor tn(std::move(referenceSet),leafSize);
+  TrainVisitor<SortPolicy> tn(std::move(referenceSet),leafSize);
   boost::apply_visitor(tn, nSearch);
 
   if (!naive)
@@ -374,7 +378,8 @@ void NSModel<SortPolicy>::Search(arma::mat&& querySet,
   else
     Log::Info << "brute-force (naive) search..." << std::endl;
 
-  BiSearchVisitor search(querySet, k, neighbors, distances, leafSize);
+  BiSearchVisitor<SortPolicy> search(querySet, k, neighbors, distances,
+      leafSize);
   boost::apply_visitor(search, nSearch);
 }
 




More information about the mlpack-git mailing list