[mlpack-git] master: Allow the dual-tree traverser to be changed easily. This probably should be done for the single-tree traverser too, but we'll start here for now. (0f31abb)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sat May 2 20:39:22 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/4379c3398c5a3e6d59a53183445a5bb932506f01...0f31abbdebcd34e2113d8acf47c1d0b087377921
>---------------------------------------------------------------
commit 0f31abbdebcd34e2113d8acf47c1d0b087377921
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun May 3 00:38:50 2015 +0000
Allow the dual-tree traverser to be changed easily.
This probably should be done for the single-tree traverser too, but we'll start here for now.
>---------------------------------------------------------------
0f31abbdebcd34e2113d8acf47c1d0b087377921
.../methods/neighbor_search/neighbor_search.hpp | 6 ++-
.../neighbor_search/neighbor_search_impl.hpp | 63 +++++++++++++++-------
2 files changed, 48 insertions(+), 21 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 80306e1..abdf08a 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -14,10 +14,12 @@
#include <mlpack/core/tree/binary_space_tree.hpp>
#include <mlpack/core/tree/rectangle_tree.hpp>
+#include <mlpack/core/tree/binary_space_tree/binary_space_tree.hpp>
#include <mlpack/core/metrics/lmetric.hpp>
#include "neighbor_search_stat.hpp"
#include "sort_policies/nearest_neighbor_sort.hpp"
+#include "neighbor_search_rules.hpp"
namespace mlpack {
namespace neighbor /** Neighbor-search routines. These include
@@ -45,7 +47,9 @@ namespace neighbor /** Neighbor-search routines. These include
template<typename SortPolicy = NearestNeighborSort,
typename MetricType = mlpack::metric::SquaredEuclideanDistance,
typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
- NeighborSearchStat<SortPolicy> > >
+ NeighborSearchStat<SortPolicy>>,
+ template<typename RuleType> class TraversalType =
+ TreeType::template DualTreeTraverser>
class NeighborSearch
{
public:
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 803f0f4..0664808 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -40,8 +40,11 @@ TreeType* BuildTree(
}
// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::
+template<typename SortPolicy,
+ typename MetricType,
+ typename TreeType,
+ template<typename> class TraversalType>
+NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
NeighborSearch(const typename TreeType::Mat& referenceSetIn,
const bool naive,
const bool singleMode,
@@ -78,11 +81,14 @@ NeighborSearch(const typename TreeType::Mat& referenceSetIn,
}
// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
- TreeType* referenceTree,
- const bool singleMode,
- const MetricType metric) :
+template<typename SortPolicy,
+ typename MetricType,
+ typename TreeType,
+ template<typename> class TraversalType>
+NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
+NeighborSearch(TreeType* referenceTree,
+ const bool singleMode,
+ const MetricType metric) :
referenceSet(referenceTree->Dataset()),
referenceTree(referenceTree),
treeOwner(false),
@@ -96,8 +102,12 @@ NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
}
// Clean memory.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::~NeighborSearch()
+template<typename SortPolicy,
+ typename MetricType,
+ typename TreeType,
+ template<typename> class TraversalType>
+NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
+ ~NeighborSearch()
{
if (treeOwner && referenceTree)
delete referenceTree;
@@ -107,8 +117,11 @@ NeighborSearch<SortPolicy, MetricType, TreeType>::~NeighborSearch()
* Computes the best neighbors and stores them in resultingNeighbors and
* distances.
*/
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
+template<typename SortPolicy,
+ typename MetricType,
+ typename TreeType,
+ template<typename> class TraversalType>
+void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
const typename TreeType::Mat& querySet,
const size_t k,
arma::Mat<size_t>& neighbors,
@@ -192,7 +205,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
Timer::Start("computing_neighbors");
// Create the traverser.
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+ TraversalType<RuleType> traverser(rules);
traverser.Traverse(*queryTree, *referenceTree);
@@ -267,8 +280,11 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
}
} // Search()
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
+template<typename SortPolicy,
+ typename MetricType,
+ typename TreeType,
+ template<typename> class TraversalType>
+void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
TreeType* queryTree,
const size_t k,
arma::Mat<size_t>& neighbors,
@@ -300,7 +316,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
RuleType rules(referenceSet, querySet, *neighborPtr, distances, metric);
// Create the traverser.
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+ TraversalType<RuleType> traverser(rules);
traverser.Traverse(*queryTree, *referenceTree);
Timer::Stop("computing_neighbors");
@@ -321,8 +337,11 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
}
}
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
+template<typename SortPolicy,
+ typename MetricType,
+ typename TreeType,
+ template<typename> class TraversalType>
+void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances)
@@ -374,7 +393,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
else
{
// Create the traverser.
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+ TraversalType<RuleType> traverser(rules);
traverser.Traverse(*referenceTree, *referenceTree);
@@ -408,8 +427,12 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
}
// Return a String of the Object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-std::string NeighborSearch<SortPolicy, MetricType, TreeType>::ToString() const
+template<typename SortPolicy,
+ typename MetricType,
+ typename TreeType,
+ template<typename> class TraversalType>
+std::string NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
+ ToString() const
{
std::ostringstream convert;
convert << "NeighborSearch [" << this << "]" << std::endl;
More information about the mlpack-git mailing list