[mlpack-git] master: Set SortPolicy as template parameter of classes BiSearchVisitor and TrainVisitor. (ddb252c)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Jun 15 10:44:16 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/a7e8d3bac60d6aa2717f27e0d4f6c53ff20607f5...779556fed748819a18cc898d9a6f69900740ef23
>---------------------------------------------------------------
commit ddb252ce03520b699e8335f5261ecd08b32d5a60
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Wed Jun 15 11:44:16 2016 -0300
Set SortPolicy as a template parameter of the BiSearchVisitor and TrainVisitor classes.
Also, add a more specific definition of the NSTypeT alias template, to avoid Visual C++ (MSVC) compiler errors.
>---------------------------------------------------------------
ddb252ce03520b699e8335f5261ecd08b32d5a60
.../methods/neighbor_search/neighbor_search.hpp | 3 +-
src/mlpack/methods/neighbor_search/ns_model.hpp | 124 +++++++++++----------
.../methods/neighbor_search/ns_model_impl.hpp | 47 ++++----
3 files changed, 93 insertions(+), 81 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index bf62ae7..999f261 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -27,6 +27,7 @@ namespace neighbor /** Neighbor-search routines. These include
* searches. */ {
// Forward declaration.
+template<typename SortPolicy>
class TrainVisitor;
/**
@@ -307,7 +308,7 @@ class NeighborSearch
bool treeNeedsReset;
//! The NSModel class should have access to internal members.
- friend class TrainVisitor;
+ friend class TrainVisitor<SortPolicy>;
}; // class NeighborSearch
} // namespace neighbor
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index f269e13..613ef45 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -52,116 +52,122 @@ struct NSModelName<FurthestNeighborSort>
class MonoSearchVisitor : public boost::static_visitor<void>
{
private:
- const size_t k;
- arma::Mat<size_t>& neighbors;
- arma::mat& distances;
+ const size_t k;
+ arma::Mat<size_t>& neighbors;
+ arma::mat& distances;
public:
- template<typename NSType>
- void operator()(NSType *ns) const;
+ template<typename NSType>
+ void operator()(NSType *ns) const;
- MonoSearchVisitor(const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances);
+ MonoSearchVisitor(const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
};
+template<typename SortPolicy>
class BiSearchVisitor : public boost::static_visitor<void>
{
private:
- const arma::mat& querySet;
- const size_t k;
- arma::Mat<size_t>& neighbors;
- arma::mat& distances;
- const size_t leafSize;
+ const arma::mat& querySet;
+ const size_t k;
+ arma::Mat<size_t>& neighbors;
+ arma::mat& distances;
+ const size_t leafSize;
- template<typename NSType>
- void SearchLeaf(NSType *ns) const;
+ template<typename NSType>
+ void SearchLeaf(NSType* ns) const;
public:
- template<typename SortPolicy,
- template<typename TreeMetricType,
- typename TreeStatType,
- typename TreeMatType> class TreeType>
- void operator()(NSType<SortPolicy,TreeType> *ns) const;
-
- template<typename SortPolicy>
- void operator()(NSType<SortPolicy,tree::KDTree> *ns) const;
-
- template<typename SortPolicy>
- void operator()(NSType<SortPolicy,tree::BallTree> *ns) const;
-
- BiSearchVisitor(const arma::mat& querySet,
- const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
- const size_t leafSize);
+ template<template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+ using NSTypeT = NSType<SortPolicy, TreeType>;
+
+ template<template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+ void operator()(NSTypeT<TreeType>* ns) const;
+
+ void operator()(NSTypeT<tree::KDTree>* ns) const;
+
+ void operator()(NSTypeT<tree::BallTree>* ns) const;
+
+ BiSearchVisitor(const arma::mat& querySet,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ const size_t leafSize);
};
+template<typename SortPolicy>
class TrainVisitor : public boost::static_visitor<void>
{
private:
- arma::mat&& referenceSet;
- size_t leafSize;
+ arma::mat&& referenceSet;
+ size_t leafSize;
- template<typename NSType>
- void TrainLeaf(NSType* ns) const;
+ template<typename NSType>
+ void TrainLeaf(NSType* ns) const;
public:
- template<typename SortPolicy,
- template<typename TreeMetricType,
- typename TreeStatType,
- typename TreeMatType> class TreeType>
- void operator()(NSType<SortPolicy,TreeType> *ns) const;
+ template<template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+ using NSTypeT = NSType<SortPolicy, TreeType>;
+
+ template<template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+ void operator()(NSTypeT<TreeType>* ns) const;
- template<typename SortPolicy>
- void operator()(NSType<SortPolicy,tree::KDTree> *ns) const;
+ void operator()(NSTypeT<tree::KDTree>* ns) const;
- template<typename SortPolicy>
- void operator()(NSType<SortPolicy,tree::BallTree> *ns) const;
+ void operator()(NSTypeT<tree::BallTree>* ns) const;
- TrainVisitor(arma::mat&& referenceSet, const size_t leafSize);
+ TrainVisitor(arma::mat&& referenceSet, const size_t leafSize);
};
class SingleModeVisitor : public boost::static_visitor<bool&>
{
public:
- template<typename NSType>
- bool& operator()(NSType *ns) const;
+ template<typename NSType>
+ bool& operator()(NSType *ns) const;
};
class NaiveVisitor : public boost::static_visitor<bool&>
{
public:
- template<typename NSType>
- bool& operator()(NSType *ns) const;
+ template<typename NSType>
+ bool& operator()(NSType *ns) const;
};
class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
{
public:
- template<typename NSType>
- const arma::mat& operator()(NSType *ns) const;
+ template<typename NSType>
+ const arma::mat& operator()(NSType *ns) const;
};
class DeleteVisitor : public boost::static_visitor<void>
{
public:
- template<typename NSType>
- void operator()(NSType *ns) const;
+ template<typename NSType>
+ void operator()(NSType *ns) const;
};
template<typename Archive>
class SerializeVisitor : public boost::static_visitor<void>
{
private:
- Archive& ar;
- const std::string& name;
+ Archive& ar;
+ const std::string& name;
public:
- template<typename NSType>
- void operator()(NSType *ns) const;
+ template<typename NSType>
+ void operator()(NSType *ns) const;
- SerializeVisitor(Archive& ar, const std::string& name);
+ SerializeVisitor(Archive& ar, const std::string& name);
};
template<typename SortPolicy>
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index ef67d18..bcded0e 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -32,12 +32,12 @@ void MonoSearchVisitor::operator()(NSType *ns) const
throw std::runtime_error("no neighbor search model initialized");
}
-
-BiSearchVisitor::BiSearchVisitor(const arma::mat& querySet,
- const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
- const size_t leafSize) :
+template<typename SortPolicy>
+BiSearchVisitor<SortPolicy>::BiSearchVisitor(const arma::mat& querySet,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ const size_t leafSize) :
querySet(querySet),
k(k),
neighbors(neighbors),
@@ -45,11 +45,11 @@ BiSearchVisitor::BiSearchVisitor(const arma::mat& querySet,
leafSize(leafSize)
{}
-template<typename SortPolicy,
- template<typename TreeMetricType,
+template<typename SortPolicy>
+template<template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType>
-void BiSearchVisitor::operator()(NSType<SortPolicy,TreeType> *ns) const
+void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<TreeType>* ns) const
{
if (ns)
return ns->Search(querySet, k, neighbors, distances);
@@ -57,7 +57,7 @@ void BiSearchVisitor::operator()(NSType<SortPolicy,TreeType> *ns) const
}
template<typename SortPolicy>
-void BiSearchVisitor::operator()(NSType<SortPolicy,tree::KDTree> *ns) const
+void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::KDTree>* ns) const
{
if (ns)
return SearchLeaf(ns);
@@ -65,15 +65,16 @@ void BiSearchVisitor::operator()(NSType<SortPolicy,tree::KDTree> *ns) const
}
template<typename SortPolicy>
-void BiSearchVisitor::operator()(NSType<SortPolicy,tree::BallTree> *ns) const
+void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
{
if (ns)
return SearchLeaf(ns);
throw std::runtime_error("no neighbor search model initialized");
}
+template<typename SortPolicy>
template<typename NSType>
-void BiSearchVisitor::SearchLeaf(NSType *ns) const
+void BiSearchVisitor<SortPolicy>::SearchLeaf(NSType *ns) const
{
if (!ns->Naive() && !ns->SingleMode())
{
@@ -99,16 +100,18 @@ void BiSearchVisitor::SearchLeaf(NSType *ns) const
}
-TrainVisitor::TrainVisitor(arma::mat&& referenceSet, const size_t leafSize) :
+template<typename SortPolicy>
+TrainVisitor<SortPolicy>::TrainVisitor(arma::mat&& referenceSet,
+ const size_t leafSize) :
referenceSet(std::move(referenceSet)),
leafSize(leafSize)
{}
-template<typename SortPolicy,
- template<typename TreeMetricType,
+template<typename SortPolicy>
+template<template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType>
-void TrainVisitor::operator()(NSType<SortPolicy,TreeType> *ns) const
+void TrainVisitor<SortPolicy>::operator()(NSTypeT<TreeType>* ns) const
{
if (ns)
return ns->Train(std::move(referenceSet));
@@ -116,7 +119,7 @@ void TrainVisitor::operator()(NSType<SortPolicy,TreeType> *ns) const
}
template<typename SortPolicy>
-void TrainVisitor::operator ()(NSType<SortPolicy,tree::KDTree> *ns) const
+void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::KDTree>* ns) const
{
if (ns)
return TrainLeaf(ns);
@@ -124,15 +127,16 @@ void TrainVisitor::operator ()(NSType<SortPolicy,tree::KDTree> *ns) const
}
template<typename SortPolicy>
-void TrainVisitor::operator ()(NSType<SortPolicy,tree::BallTree> *ns) const
+void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::BallTree>* ns) const
{
if (ns)
return TrainLeaf(ns);
throw std::runtime_error("no neighbor search model initialized");
}
+template<typename SortPolicy>
template<typename NSType>
-void TrainVisitor::TrainLeaf(NSType* ns) const
+void TrainVisitor<SortPolicy>::TrainLeaf(NSType* ns) const
{
if (ns->Naive())
ns->Train(std::move(referenceSet));
@@ -345,7 +349,7 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
break;
}
- TrainVisitor tn(std::move(referenceSet),leafSize);
+ TrainVisitor<SortPolicy> tn(std::move(referenceSet),leafSize);
boost::apply_visitor(tn, nSearch);
if (!naive)
@@ -374,7 +378,8 @@ void NSModel<SortPolicy>::Search(arma::mat&& querySet,
else
Log::Info << "brute-force (naive) search..." << std::endl;
- BiSearchVisitor search(querySet, k, neighbors, distances, leafSize);
+ BiSearchVisitor<SortPolicy> search(querySet, k, neighbors, distances,
+ leafSize);
boost::apply_visitor(search, nSearch);
}
More information about the mlpack-git
mailing list