[mlpack-git] master: Allow the dual-tree traverser to be changed easily. This probably should be done for the single-tree traverser too, but we'll start here for now. (0f31abb)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sat May 2 20:39:22 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/4379c3398c5a3e6d59a53183445a5bb932506f01...0f31abbdebcd34e2113d8acf47c1d0b087377921

>---------------------------------------------------------------

commit 0f31abbdebcd34e2113d8acf47c1d0b087377921
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun May 3 00:38:50 2015 +0000

    Allow the dual-tree traverser to be changed easily.
    This probably should be done for the single-tree traverser too, but we'll start here for now.


>---------------------------------------------------------------

0f31abbdebcd34e2113d8acf47c1d0b087377921
 .../methods/neighbor_search/neighbor_search.hpp    |  6 ++-
 .../neighbor_search/neighbor_search_impl.hpp       | 63 +++++++++++++++-------
 2 files changed, 48 insertions(+), 21 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 80306e1..abdf08a 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -14,10 +14,12 @@
 
 #include <mlpack/core/tree/binary_space_tree.hpp>
 #include <mlpack/core/tree/rectangle_tree.hpp>
+#include <mlpack/core/tree/binary_space_tree/binary_space_tree.hpp>
 
 #include <mlpack/core/metrics/lmetric.hpp>
 #include "neighbor_search_stat.hpp"
 #include "sort_policies/nearest_neighbor_sort.hpp"
+#include "neighbor_search_rules.hpp"
 
 namespace mlpack {
 namespace neighbor /** Neighbor-search routines.  These include
@@ -45,7 +47,9 @@ namespace neighbor /** Neighbor-search routines.  These include
 template<typename SortPolicy = NearestNeighborSort,
          typename MetricType = mlpack::metric::SquaredEuclideanDistance,
          typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
-             NeighborSearchStat<SortPolicy> > >
+             NeighborSearchStat<SortPolicy>>,
+         template<typename RuleType> class TraversalType =
+             TreeType::template DualTreeTraverser>
 class NeighborSearch
 {
  public:
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 803f0f4..0664808 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -40,8 +40,11 @@ TreeType* BuildTree(
 }
 
 // Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::
+template<typename SortPolicy,
+         typename MetricType,
+         typename TreeType,
+         template<typename> class TraversalType>
+NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
 NeighborSearch(const typename TreeType::Mat& referenceSetIn,
                const bool naive,
                const bool singleMode,
@@ -78,11 +81,14 @@ NeighborSearch(const typename TreeType::Mat& referenceSetIn,
 }
 
 // Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
-    TreeType* referenceTree,
-    const bool singleMode,
-    const MetricType metric) :
+template<typename SortPolicy,
+         typename MetricType,
+         typename TreeType,
+         template<typename> class TraversalType>
+NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
+NeighborSearch(TreeType* referenceTree,
+               const bool singleMode,
+               const MetricType metric) :
     referenceSet(referenceTree->Dataset()),
     referenceTree(referenceTree),
     treeOwner(false),
@@ -96,8 +102,12 @@ NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
 }
 
 // Clean memory.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::~NeighborSearch()
+template<typename SortPolicy,
+         typename MetricType,
+         typename TreeType,
+         template<typename> class TraversalType>
+NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
+    ~NeighborSearch()
 {
   if (treeOwner && referenceTree)
     delete referenceTree;
@@ -107,8 +117,11 @@ NeighborSearch<SortPolicy, MetricType, TreeType>::~NeighborSearch()
  * Computes the best neighbors and stores them in resultingNeighbors and
  * distances.
  */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
+template<typename SortPolicy,
+         typename MetricType,
+         typename TreeType,
+         template<typename> class TraversalType>
+void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
     const typename TreeType::Mat& querySet,
     const size_t k,
     arma::Mat<size_t>& neighbors,
@@ -192,7 +205,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
     Timer::Start("computing_neighbors");
 
     // Create the traverser.
-    typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+    TraversalType<RuleType> traverser(rules);
 
     traverser.Traverse(*queryTree, *referenceTree);
 
@@ -267,8 +280,11 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
   }
 } // Search()
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
+template<typename SortPolicy,
+         typename MetricType,
+         typename TreeType,
+         template<typename> class TraversalType>
+void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
     TreeType* queryTree,
     const size_t k,
     arma::Mat<size_t>& neighbors,
@@ -300,7 +316,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
   RuleType rules(referenceSet, querySet, *neighborPtr, distances, metric);
 
   // Create the traverser.
-  typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+  TraversalType<RuleType> traverser(rules);
   traverser.Traverse(*queryTree, *referenceTree);
 
   Timer::Stop("computing_neighbors");
@@ -321,8 +337,11 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
   }
 }
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
+template<typename SortPolicy,
+         typename MetricType,
+         typename TreeType,
+         template<typename> class TraversalType>
+void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
     const size_t k,
     arma::Mat<size_t>& neighbors,
     arma::mat& distances)
@@ -374,7 +393,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
   else
   {
     // Create the traverser.
-    typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+    TraversalType<RuleType> traverser(rules);
 
     traverser.Traverse(*referenceTree, *referenceTree);
 
@@ -408,8 +427,12 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
 }
 
 // Return a String of the Object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-std::string NeighborSearch<SortPolicy, MetricType, TreeType>::ToString() const
+template<typename SortPolicy,
+         typename MetricType,
+         typename TreeType,
+         template<typename> class TraversalType>
+std::string NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
+    ToString() const
 {
   std::ostringstream convert;
   convert << "NeighborSearch [" << this << "]" << std::endl;



More information about the mlpack-git mailing list