[mlpack-git] master: Refactor for new TreeType API. (5ee25ba)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:42:36 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b
>---------------------------------------------------------------
commit 5ee25baa085be32d7244f52ad7f742906112158e
Author: ryan <ryan at ratml.org>
Date: Sun Jul 26 23:07:37 2015 -0400
Refactor for new TreeType API.
>---------------------------------------------------------------
5ee25baa085be32d7244f52ad7f742906112158e
src/mlpack/methods/range_search/range_search.hpp | 32 +++---
.../methods/range_search/range_search_impl.hpp | 109 ++++++++++++---------
.../methods/range_search/range_search_main.cpp | 9 +-
3 files changed, 89 insertions(+), 61 deletions(-)
diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index b168cfe..04d3071 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -21,13 +21,21 @@ namespace range /** Range-search routines. */ {
* is implemented in the style of a generalized tree-independent dual-tree
* algorithm; for more details on the actual algorithm, see the RangeSearchRules
* class.
+ *
+ * @tparam MetricType Metric to use for range search calculations.
+ * @tparam MatType Type of data to use.
+ * @tparam TreeType Type of tree to use; must satisfy the TreeType policy API.
*/
-template<typename MetricType = mlpack::metric::EuclideanDistance,
- typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
- RangeSearchStat> >
+template<typename MetricType = metric::EuclideanDistance,
+ typename MatType = arma::mat,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType = tree::KDTree>
class RangeSearch
{
public:
+ //! Convenience typedef.
+ typedef TreeType<MetricType, RangeSearchStat, MatType> Tree;
+
/**
* Initialize the RangeSearch object with a given reference dataset (this is
* the dataset which is searched). Optionally, perform the computation in
@@ -44,7 +52,7 @@ class RangeSearch
* opposed to dual-tree computation).
* @param metric Instantiated distance metric.
*/
- RangeSearch(const typename TreeType::Mat& referenceSet,
+ RangeSearch(const MatType& referenceSet,
const bool naive = false,
const bool singleMode = false,
const MetricType metric = MetricType());
@@ -73,7 +81,7 @@ class RangeSearch
* opposed to dual-tree computation).
* @param metric Instantiated distance metric.
*/
- RangeSearch(TreeType* referenceTree,
+ RangeSearch(Tree* referenceTree,
const bool singleMode = false,
const MetricType metric = MetricType());
@@ -110,7 +118,7 @@ class RangeSearch
* @param distances Object which will hold the list of distances for each
* point which fell into the given range, for each query point.
*/
- void Search(const typename TreeType::Mat& querySet,
+ void Search(const MatType& querySet,
const math::Range& range,
std::vector<std::vector<size_t>>& neighbors,
std::vector<std::vector<double>>& distances);
@@ -151,7 +159,7 @@ class RangeSearch
* @param distances Object which will hold the list of distances for each
* point which fell into the given range, for each query point.
*/
- void Search(TreeType* queryTree,
+ void Search(Tree* queryTree,
const math::Range& range,
std::vector<std::vector<size_t>>& neighbors,
std::vector<std::vector<double>>& distances);
@@ -195,11 +203,11 @@ class RangeSearch
private:
//! Copy of reference matrix; used when a tree is built internally.
- typename TreeType::Mat referenceCopy;
+ MatType referenceCopy;
//! Reference set (data should be accessed using this).
- const typename TreeType::Mat& referenceSet;
+ const MatType& referenceSet;
//! Reference tree.
- TreeType* referenceTree;
+ Tree* referenceTree;
//! Mappings to old reference indices (used when this object builds trees).
std::vector<size_t> oldFromNewReferences;
@@ -215,8 +223,8 @@ class RangeSearch
MetricType metric;
};
-}; // namespace range
-}; // namespace mlpack
+} // namespace range
+} // namespace mlpack
// Include implementation.
#include "range_search_impl.hpp"
diff --git a/src/mlpack/methods/range_search/range_search_impl.hpp b/src/mlpack/methods/range_search/range_search_impl.hpp
index 97e7e20..a7ad027 100644
--- a/src/mlpack/methods/range_search/range_search_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_impl.hpp
@@ -39,13 +39,16 @@ TreeType* BuildTree(
return new TreeType(dataset);
}
-template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(
- const typename TreeType::Mat& referenceSetIn,
+template<typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
+ const MatType& referenceSetIn,
const bool naive,
const bool singleMode,
const MetricType metric) :
- referenceSet((tree::TreeTraits<TreeType>::RearrangesDataset && !naive)
+ referenceSet((tree::TreeTraits<Tree>::RearrangesDataset && !naive)
? referenceCopy : referenceSetIn),
referenceTree(NULL),
treeOwner(!naive), // If in naive mode, we are not building any trees.
@@ -60,23 +63,25 @@ RangeSearch<MetricType, TreeType>::RangeSearch(
if (!naive)
{
// Copy the dataset, if it will be modified during tree building.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (tree::TreeTraits<Tree>::RearrangesDataset)
referenceCopy = referenceSetIn;
// The const_cast is safe; if RearrangesDataset == false, then it'll be
// casted back to const anyway, and if not, referenceSet points to
// referenceCopy, which isn't const.
- referenceTree = BuildTree<TreeType>(
- const_cast<typename TreeType::Mat&>(referenceSet),
+ referenceTree = BuildTree<Tree>(const_cast<MatType&>(referenceSet),
oldFromNewReferences);
}
Timer::Stop("range_search/tree_building");
}
-template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(
- TreeType* referenceTree,
+template<typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
+ Tree* referenceTree,
const bool singleMode,
const MetricType metric) :
referenceSet(referenceTree->Dataset()),
@@ -89,16 +94,22 @@ RangeSearch<MetricType, TreeType>::RangeSearch(
// Nothing else to initialize.
}
-template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::~RangeSearch()
+template<typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+RangeSearch<MetricType, MatType, TreeType>::~RangeSearch()
{
if (treeOwner && referenceTree)
delete referenceTree;
}
-template<typename MetricType, typename TreeType>
-void RangeSearch<MetricType, TreeType>::Search(
- const typename TreeType::Mat& querySet,
+template<typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+void RangeSearch<MetricType, MatType, TreeType>::Search(
+ const MatType& querySet,
const math::Range& range,
std::vector<std::vector<size_t>>& neighbors,
std::vector<std::vector<double>>& distances)
@@ -110,14 +121,13 @@ void RangeSearch<MetricType, TreeType>::Search(
// If we will be building a tree and it will modify the query set, make a copy
// of the dataset.
- typename TreeType::Mat queryCopy;
+ MatType queryCopy;
const bool needsCopy = (!naive && !singleMode &&
- tree::TreeTraits<TreeType>::RearrangesDataset);
+ tree::TreeTraits<Tree>::RearrangesDataset);
if (needsCopy)
queryCopy = querySet;
- const typename TreeType::Mat& querySetRef = (needsCopy) ? queryCopy :
- querySet;
+ const MatType& querySetRef = (needsCopy) ? queryCopy : querySet;
// If we have built the trees ourselves, then we will have to map all the
// indices back to their original indices when this computation is finished.
@@ -127,7 +137,7 @@ void RangeSearch<MetricType, TreeType>::Search(
std::vector<std::vector<double>>* distancePtr = &distances;
// Mapping is only necessary if the tree rearranges points.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (tree::TreeTraits<Tree>::RearrangesDataset)
{
// Query indices only need to be mapped if we are building the query tree
// ourselves.
@@ -147,7 +157,7 @@ void RangeSearch<MetricType, TreeType>::Search(
distancePtr->resize(querySet.n_cols);
// Create the helper object for the traversal.
- typedef RangeSearchRules<MetricType, TreeType> RuleType;
+ typedef RangeSearchRules<MetricType, Tree> RuleType;
RuleType rules(referenceSet, querySetRef, range, *neighborPtr, *distancePtr,
metric);
@@ -161,7 +171,7 @@ void RangeSearch<MetricType, TreeType>::Search(
else if (singleMode)
{
// Create the traverser.
- typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
// Now have it traverse for each point.
for (size_t i = 0; i < querySet.n_cols; ++i)
@@ -172,13 +182,13 @@ void RangeSearch<MetricType, TreeType>::Search(
// Build the query tree.
Timer::Stop("range_search/computing_neighbors");
Timer::Start("range_search/tree_building");
- TreeType* queryTree = BuildTree<TreeType>(
- const_cast<typename TreeType::Mat&>(querySetRef), oldFromNewQueries);
+ Tree* queryTree = BuildTree<Tree>(const_cast<MatType&>(querySetRef),
+ oldFromNewQueries);
Timer::Stop("range_search/tree_building");
Timer::Start("range_search/computing_neighbors");
// Create the traverser.
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
traverser.Traverse(*queryTree, *referenceTree);
@@ -189,7 +199,7 @@ void RangeSearch<MetricType, TreeType>::Search(
Timer::Stop("range_search/computing_neighbors");
// Map points back to original indices, if necessary.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (tree::TreeTraits<Tree>::RearrangesDataset)
{
if (!singleMode && !naive && treeOwner)
{
@@ -255,9 +265,12 @@ void RangeSearch<MetricType, TreeType>::Search(
}
}
-template<typename MetricType, typename TreeType>
-void RangeSearch<MetricType, TreeType>::Search(
- TreeType* queryTree,
+template<typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+void RangeSearch<MetricType, MatType, TreeType>::Search(
+ Tree* queryTree,
const math::Range& range,
std::vector<std::vector<size_t>>& neighbors,
std::vector<std::vector<double>>& distances)
@@ -265,7 +278,7 @@ void RangeSearch<MetricType, TreeType>::Search(
Timer::Start("range_search/computing_neighbors");
// Get a reference to the query set.
- const typename TreeType::Mat& querySet = queryTree->Dataset();
+ const MatType& querySet = queryTree->Dataset();
// Make sure we are in dual-tree mode.
if (singleMode || naive)
@@ -275,7 +288,7 @@ void RangeSearch<MetricType, TreeType>::Search(
// We won't need to map query indices, but will we need to map distances?
std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
- if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
neighborPtr = new std::vector<std::vector<size_t>>;
// Resize each vector.
@@ -285,19 +298,19 @@ void RangeSearch<MetricType, TreeType>::Search(
distances.resize(querySet.n_cols);
// Create the helper object for the traversal.
- typedef RangeSearchRules<MetricType, TreeType> RuleType;
+ typedef RangeSearchRules<MetricType, Tree> RuleType;
RuleType rules(referenceSet, queryTree->Dataset(), range, *neighborPtr,
distances, metric);
// Create the traverser.
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
traverser.Traverse(*queryTree, *referenceTree);
Timer::Stop("range_search/computing_neighbors");
// Do we need to map indices?
- if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
{
// We must map reference indices only.
neighbors.clear();
@@ -315,8 +328,11 @@ void RangeSearch<MetricType, TreeType>::Search(
}
}
-template<typename MetricType, typename TreeType>
-void RangeSearch<MetricType, TreeType>::Search(
+template<typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+void RangeSearch<MetricType, MatType, TreeType>::Search(
const math::Range& range,
std::vector<std::vector<size_t>>& neighbors,
std::vector<std::vector<double>>& distances)
@@ -327,7 +343,7 @@ void RangeSearch<MetricType, TreeType>::Search(
std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
std::vector<std::vector<double>>* distancePtr = &distances;
- if (tree::TreeTraits<TreeType>::RearrangesDataset && treeOwner)
+ if (tree::TreeTraits<Tree>::RearrangesDataset && treeOwner)
{
// We will always need to rearrange in this case.
distancePtr = new std::vector<std::vector<double>>;
@@ -341,7 +357,7 @@ void RangeSearch<MetricType, TreeType>::Search(
distancePtr->resize(referenceSet.n_cols);
// Create the helper object for the traversal.
- typedef RangeSearchRules<MetricType, TreeType> RuleType;
+ typedef RangeSearchRules<MetricType, Tree> RuleType;
RuleType rules(referenceSet, referenceSet, range, *neighborPtr, *distancePtr,
metric, true /* don't return the query point in the results */);
@@ -355,7 +371,7 @@ void RangeSearch<MetricType, TreeType>::Search(
else if (singleMode)
{
// Create the traverser.
- typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
// Now have it traverse for each point.
for (size_t i = 0; i < referenceSet.n_cols; ++i)
@@ -364,7 +380,7 @@ void RangeSearch<MetricType, TreeType>::Search(
else // Dual-tree recursion.
{
// Create the traverser.
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
traverser.Traverse(*referenceTree, *referenceTree);
}
@@ -372,7 +388,7 @@ void RangeSearch<MetricType, TreeType>::Search(
Timer::Stop("range_search/computing_neighbors");
// Do we need to map the reference indices?
- if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
{
neighbors.clear();
neighbors.resize(referenceSet.n_cols);
@@ -399,8 +415,11 @@ void RangeSearch<MetricType, TreeType>::Search(
}
}
-template<typename MetricType, typename TreeType>
-std::string RangeSearch<MetricType, TreeType>::ToString() const
+template<typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+std::string RangeSearch<MetricType, MatType, TreeType>::ToString() const
{
std::ostringstream convert;
convert << "Range Search [" << this << "]" << std::endl;
@@ -413,7 +432,7 @@ std::string RangeSearch<MetricType, TreeType>::ToString() const
return convert.str();
}
-}; // namespace range
-}; // namespace mlpack
+} // namespace range
+} // namespace mlpack
#endif
diff --git a/src/mlpack/methods/range_search/range_search_main.cpp b/src/mlpack/methods/range_search/range_search_main.cpp
index 31eba17..783132c 100644
--- a/src/mlpack/methods/range_search/range_search_main.cpp
+++ b/src/mlpack/methods/range_search/range_search_main.cpp
@@ -16,6 +16,7 @@ using namespace std;
using namespace mlpack;
using namespace mlpack::range;
using namespace mlpack::tree;
+using namespace mlpack::metric;
// Information about the program itself.
PROGRAM_INFO("Range Search",
@@ -66,9 +67,9 @@ PARAM_FLAG("cover_tree", "If true, use a cover tree for range searching "
"(instead of a kd-tree).", "c");
typedef RangeSearch<> RSType;
-typedef CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
- RangeSearchStat> CoverTreeType;
-typedef RangeSearch<metric::EuclideanDistance, CoverTreeType> RSCoverType;
+typedef CoverTree<EuclideanDistance, RangeSearchStat> CoverTreeType;
+typedef RangeSearch<EuclideanDistance, arma::mat, StandardCoverTree>
+ RSCoverType;
int main(int argc, char *argv[])
{
@@ -162,7 +163,7 @@ int main(int argc, char *argv[])
}
else
{
- typedef BinarySpaceTree<bound::HRectBound<2>, RangeSearchStat> TreeType;
+ typedef KDTree<EuclideanDistance, RangeSearchStat, arma::mat> TreeType;
// Track mappings.
Log::Info << "Building reference tree..." << endl;
More information about the mlpack-git
mailing list