[mlpack-git] master: Add a new template parameter to the NeighborSearch class to set a specific SingleTreeTraverser. (70fbeab)

gitdub at mlpack.org gitdub at mlpack.org
Wed Aug 17 00:15:19 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit 70fbeab1e0be29113f0e21c688601133c23d4751
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Wed Aug 17 01:15:19 2016 -0300

    Add a new template parameter to the NeighborSearch class to set a specific SingleTreeTraverser.


>---------------------------------------------------------------

70fbeab1e0be29113f0e21c688601133c23d4751
 .../methods/neighbor_search/neighbor_search.hpp    |   6 +-
 .../neighbor_search/neighbor_search_impl.hpp       | 153 +++++++++++----------
 .../methods/neighbor_search/ns_model_impl.hpp      |   6 +-
 3 files changed, 93 insertions(+), 72 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index be16b05..9b5b3be 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -67,7 +67,11 @@ template<typename SortPolicy = NearestNeighborSort,
          template<typename RuleType> class TraversalType =
              TreeType<MetricType,
                       NeighborSearchStat<SortPolicy>,
-                      MatType>::template DualTreeTraverser>
+                      MatType>::template DualTreeTraverser,
+         template<typename RuleType> class SingleTreeTraversalType =
+             TreeType<MetricType,
+                      NeighborSearchStat<SortPolicy>,
+                      MatType>::template SingleTreeTraverser>
 class NeighborSearch
 {
  public:
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 238680e..8968e63 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -70,13 +70,14 @@ template<typename SortPolicy,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-NeighborSearch(const MatType& referenceSetIn,
-               const bool naive,
-               const bool singleMode,
-               const double epsilon,
-               const MetricType metric) :
+         template<typename> class TraversalType,
+         template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn,
+                                         const bool naive,
+                                         const bool singleMode,
+                                         const double epsilon,
+                                         const MetricType metric) :
     referenceTree(naive ? NULL :
         BuildTree<MatType, Tree>(referenceSetIn, oldFromNewReferences)),
     referenceSet(naive ? &referenceSetIn : &referenceTree->Dataset()),
@@ -101,13 +102,14 @@ template<typename SortPolicy,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-NeighborSearch(MatType&& referenceSetIn,
-               const bool naive,
-               const bool singleMode,
-               const double epsilon,
-               const MetricType metric) :
+         template<typename> class TraversalType,
+         template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::NeighborSearch(MatType&& referenceSetIn,
+                                         const bool naive,
+                                         const bool singleMode,
+                                         const double epsilon,
+                                         const MetricType metric) :
     referenceTree(naive ? NULL :
         BuildTree<MatType, Tree>(std::move(referenceSetIn),
                                  oldFromNewReferences)),
@@ -134,12 +136,13 @@ template<typename SortPolicy,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-NeighborSearch(Tree* referenceTree,
-               const bool singleMode,
-               const double epsilon,
-               const MetricType metric) :
+         template<typename> class TraversalType,
+         template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
+                                         const bool singleMode,
+                                         const double epsilon,
+                                         const MetricType metric) :
     referenceTree(referenceTree),
     referenceSet(&referenceTree->Dataset()),
     treeOwner(false),
@@ -163,12 +166,13 @@ template<typename SortPolicy,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-    NeighborSearch(const bool naive,
-                   const bool singleMode,
-                   const double epsilon,
-                   const MetricType metric) :
+         template<typename> class TraversalType,
+         template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::NeighborSearch(const bool naive,
+                                         const bool singleMode,
+                                         const double epsilon,
+                                         const MetricType metric) :
     referenceTree(NULL),
     referenceSet(new MatType()), // Empty matrix.
     treeOwner(false),
@@ -199,9 +203,10 @@ template<typename SortPolicy,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-~NeighborSearch()
+         template<typename> class TraversalType,
+         template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::~NeighborSearch()
 {
   if (treeOwner && referenceTree)
     delete referenceTree;
@@ -215,9 +220,10 @@ template<typename SortPolicy,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-Train(const MatType& referenceSet)
+         template<typename> class TraversalType,
+         template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Train(const MatType& referenceSet)
 {
   // Clean up the old tree, if we built one.
   if (treeOwner && referenceTree)
@@ -252,9 +258,10 @@ template<typename SortPolicy,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-Train(MatType&& referenceSetIn)
+         template<typename> class TraversalType,
+         template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Train(MatType&& referenceSetIn)
 {
   // Clean up the old tree, if we built one.
   if (treeOwner && referenceTree)
@@ -294,9 +301,10 @@ template<typename SortPolicy,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-Train(Tree* referenceTree)
+         template<typename> class TraversalType,
+         template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Train(Tree* referenceTree)
 {
   if (naive)
     throw std::invalid_argument("cannot train on given reference tree when "
@@ -323,12 +331,13 @@ template<typename SortPolicy,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-Search(const MatType& querySet,
-       const size_t k,
-       arma::Mat<size_t>& neighbors,
-       arma::mat& distances)
+         template<typename> class TraversalType,
+         template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Search(const MatType& querySet,
+                                 const size_t k,
+                                 arma::Mat<size_t>& neighbors,
+                                 arma::mat& distances)
 {
   if (k > referenceSet->n_cols)
   {
@@ -391,7 +400,7 @@ Search(const MatType& querySet,
     RuleType rules(*referenceSet, querySet, k, metric, epsilon);
 
     // Create the traverser.
-    typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
+    SingleTreeTraversalType<RuleType> traverser(rules);
 
     // Now have it traverse for each point.
     for (size_t i = 0; i < querySet.n_cols; ++i)
@@ -505,13 +514,14 @@ template<typename SortPolicy,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-Search(Tree* queryTree,
-       const size_t k,
-       arma::Mat<size_t>& neighbors,
-       arma::mat& distances,
-       bool sameSet)
+         template<typename> class TraversalType,
+         template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Search(Tree* queryTree,
+                                 const size_t k,
+                                 arma::Mat<size_t>& neighbors,
+                                 arma::mat& distances,
+                                 bool sameSet)
 {
   if (k > referenceSet->n_cols)
   {
@@ -586,11 +596,12 @@ template<typename SortPolicy,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-Search(const size_t k,
-       arma::Mat<size_t>& neighbors,
-       arma::mat& distances)
+         template<typename> class TraversalType,
+         template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Search(const size_t k,
+                                 arma::Mat<size_t>& neighbors,
+                                 arma::mat& distances)
 {
   if (k > referenceSet->n_cols)
   {
@@ -636,7 +647,7 @@ Search(const size_t k,
   else if (singleMode)
   {
     // Create the traverser.
-    typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
+    SingleTreeTraversalType<RuleType> traverser(rules);
 
     // Now have it traverse for each point.
     for (size_t i = 0; i < referenceSet->n_cols; ++i)
@@ -723,10 +734,11 @@ template<typename SortPolicy,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename> class TraversalType>
-double NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
-TraversalType>::EffectiveError(arma::mat& foundDistances,
-                               arma::mat& realDistances)
+         template<typename> class TraversalType,
+         template<typename> class SingleTreeTraversalType>
+double NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::EffectiveError(arma::mat& foundDistances,
+                                         arma::mat& realDistances)
 {
   if (foundDistances.n_rows != realDistances.n_rows ||
       foundDistances.n_cols != realDistances.n_cols)
@@ -759,10 +771,11 @@ template<typename SortPolicy,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename> class TraversalType>
-double NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
-TraversalType>::Recall(arma::Mat<size_t>& foundNeighbors,
-                       arma::Mat<size_t>& realNeighbors)
+         template<typename> class TraversalType,
+         template<typename> class SingleTreeTraversalType>
+double NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Recall(arma::Mat<size_t>& foundNeighbors,
+                                 arma::Mat<size_t>& realNeighbors)
 {
   if (foundNeighbors.n_rows != realNeighbors.n_rows ||
       foundNeighbors.n_cols != realNeighbors.n_cols)
@@ -788,10 +801,12 @@ template<typename SortPolicy,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename> class TraversalType>
+         template<typename> class TraversalType,
+         template<typename> class SingleTreeTraversalType>
 template<typename Archive>
-void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
-    Serialize(Archive& ar, const unsigned int /* version */)
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType,
+SingleTreeTraversalType>::Serialize(Archive& ar,
+                                    const unsigned int /* version */)
 {
   using data::CreateNVP;
 
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 807ea57..79c8772 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -284,14 +284,16 @@ template<typename Archive,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType,
-         template<typename RuleType> class TraversalType>
+         template<typename RuleType> class TraversalType,
+         template<typename RuleType> class SingleTreeTraversalType>
 void serialize(
     Archive& ar,
     NeighborSearch<SortPolicy,
                    metric::EuclideanDistance,
                    arma::mat,
                    TreeType,
-                   TraversalType>& ns,
+                   TraversalType,
+                   SingleTreeTraversalType>& ns,
     const unsigned int version)
 {
   ns.Serialize(ar, version);




More information about the mlpack-git mailing list