[mlpack-git] master: Add a new template parameter to NeighborSearch class, to set a specific SingleTreeTraverser. (70fbeab)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Aug 17 00:15:19 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit 70fbeab1e0be29113f0e21c688601133c23d4751
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Wed Aug 17 01:15:19 2016 -0300
Add a new template parameter to NeighborSearch class, to set a specific SingleTreeTraverser.
>---------------------------------------------------------------
70fbeab1e0be29113f0e21c688601133c23d4751
.../methods/neighbor_search/neighbor_search.hpp | 6 +-
.../neighbor_search/neighbor_search_impl.hpp | 153 +++++++++++----------
.../methods/neighbor_search/ns_model_impl.hpp | 6 +-
3 files changed, 93 insertions(+), 72 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index be16b05..9b5b3be 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -67,7 +67,11 @@ template<typename SortPolicy = NearestNeighborSort,
template<typename RuleType> class TraversalType =
TreeType<MetricType,
NeighborSearchStat<SortPolicy>,
- MatType>::template DualTreeTraverser>
+ MatType>::template DualTreeTraverser,
+ template<typename RuleType> class SingleTreeTraversalType =
+ TreeType<MetricType,
+ NeighborSearchStat<SortPolicy>,
+ MatType>::template SingleTreeTraverser>
class NeighborSearch
{
public:
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 238680e..8968e63 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -70,13 +70,14 @@ template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-NeighborSearch(const MatType& referenceSetIn,
- const bool naive,
- const bool singleMode,
- const double epsilon,
- const MetricType metric) :
+ template<typename> class TraversalType,
+ template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn,
+ const bool naive,
+ const bool singleMode,
+ const double epsilon,
+ const MetricType metric) :
referenceTree(naive ? NULL :
BuildTree<MatType, Tree>(referenceSetIn, oldFromNewReferences)),
referenceSet(naive ? &referenceSetIn : &referenceTree->Dataset()),
@@ -101,13 +102,14 @@ template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-NeighborSearch(MatType&& referenceSetIn,
- const bool naive,
- const bool singleMode,
- const double epsilon,
- const MetricType metric) :
+ template<typename> class TraversalType,
+ template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::NeighborSearch(MatType&& referenceSetIn,
+ const bool naive,
+ const bool singleMode,
+ const double epsilon,
+ const MetricType metric) :
referenceTree(naive ? NULL :
BuildTree<MatType, Tree>(std::move(referenceSetIn),
oldFromNewReferences)),
@@ -134,12 +136,13 @@ template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-NeighborSearch(Tree* referenceTree,
- const bool singleMode,
- const double epsilon,
- const MetricType metric) :
+ template<typename> class TraversalType,
+ template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
+ const bool singleMode,
+ const double epsilon,
+ const MetricType metric) :
referenceTree(referenceTree),
referenceSet(&referenceTree->Dataset()),
treeOwner(false),
@@ -163,12 +166,13 @@ template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
- NeighborSearch(const bool naive,
- const bool singleMode,
- const double epsilon,
- const MetricType metric) :
+ template<typename> class TraversalType,
+ template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::NeighborSearch(const bool naive,
+ const bool singleMode,
+ const double epsilon,
+ const MetricType metric) :
referenceTree(NULL),
referenceSet(new MatType()), // Empty matrix.
treeOwner(false),
@@ -199,9 +203,10 @@ template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-~NeighborSearch()
+ template<typename> class TraversalType,
+ template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::~NeighborSearch()
{
if (treeOwner && referenceTree)
delete referenceTree;
@@ -215,9 +220,10 @@ template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-Train(const MatType& referenceSet)
+ template<typename> class TraversalType,
+ template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Train(const MatType& referenceSet)
{
// Clean up the old tree, if we built one.
if (treeOwner && referenceTree)
@@ -252,9 +258,10 @@ template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-Train(MatType&& referenceSetIn)
+ template<typename> class TraversalType,
+ template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Train(MatType&& referenceSetIn)
{
// Clean up the old tree, if we built one.
if (treeOwner && referenceTree)
@@ -294,9 +301,10 @@ template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-Train(Tree* referenceTree)
+ template<typename> class TraversalType,
+ template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Train(Tree* referenceTree)
{
if (naive)
throw std::invalid_argument("cannot train on given reference tree when "
@@ -323,12 +331,13 @@ template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-Search(const MatType& querySet,
- const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
+ template<typename> class TraversalType,
+ template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Search(const MatType& querySet,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
{
if (k > referenceSet->n_cols)
{
@@ -391,7 +400,7 @@ Search(const MatType& querySet,
RuleType rules(*referenceSet, querySet, k, metric, epsilon);
// Create the traverser.
- typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
+ SingleTreeTraversalType<RuleType> traverser(rules);
// Now have it traverse for each point.
for (size_t i = 0; i < querySet.n_cols; ++i)
@@ -505,13 +514,14 @@ template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-Search(Tree* queryTree,
- const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
- bool sameSet)
+ template<typename> class TraversalType,
+ template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Search(Tree* queryTree,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ bool sameSet)
{
if (k > referenceSet->n_cols)
{
@@ -586,11 +596,12 @@ template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-Search(const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
+ template<typename> class TraversalType,
+ template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Search(const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
{
if (k > referenceSet->n_cols)
{
@@ -636,7 +647,7 @@ Search(const size_t k,
else if (singleMode)
{
// Create the traverser.
- typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
+ SingleTreeTraversalType<RuleType> traverser(rules);
// Now have it traverse for each point.
for (size_t i = 0; i < referenceSet->n_cols; ++i)
@@ -723,10 +734,11 @@ template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename> class TraversalType>
-double NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
-TraversalType>::EffectiveError(arma::mat& foundDistances,
- arma::mat& realDistances)
+ template<typename> class TraversalType,
+ template<typename> class SingleTreeTraversalType>
+double NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::EffectiveError(arma::mat& foundDistances,
+ arma::mat& realDistances)
{
if (foundDistances.n_rows != realDistances.n_rows ||
foundDistances.n_cols != realDistances.n_cols)
@@ -759,10 +771,11 @@ template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename> class TraversalType>
-double NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
-TraversalType>::Recall(arma::Mat<size_t>& foundNeighbors,
- arma::Mat<size_t>& realNeighbors)
+ template<typename> class TraversalType,
+ template<typename> class SingleTreeTraversalType>
+double NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Recall(arma::Mat<size_t>& foundNeighbors,
+ arma::Mat<size_t>& realNeighbors)
{
if (foundNeighbors.n_rows != realNeighbors.n_rows ||
foundNeighbors.n_cols != realNeighbors.n_cols)
@@ -788,10 +801,12 @@ template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename> class TraversalType>
+ template<typename> class TraversalType,
+ template<typename> class SingleTreeTraversalType>
template<typename Archive>
-void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
- Serialize(Archive& ar, const unsigned int /* version */)
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Serialize(Archive& ar,
+ const unsigned int /* version */)
{
using data::CreateNVP;
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 807ea57..79c8772 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -284,14 +284,16 @@ template<typename Archive,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType,
- template<typename RuleType> class TraversalType>
+ template<typename RuleType> class TraversalType,
+ template<typename RuleType> class SingleTreeTraversalType>
void serialize(
Archive& ar,
NeighborSearch<SortPolicy,
metric::EuclideanDistance,
arma::mat,
TreeType,
- TraversalType>& ns,
+ TraversalType,
+ SingleTreeTraversalType>& ns,
const unsigned int version)
{
ns.Serialize(ar, version);
More information about the mlpack-git mailing list.