[mlpack-git] master: Refactor NeighborSearch to new TreeType API. (aa28b5e)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:42:05 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b
>---------------------------------------------------------------
commit aa28b5e4fd91e6b21a08309db851f1fc161041bb
Author: ryan <ryan at ratml.org>
Date: Fri Jul 24 14:26:57 2015 -0400
Refactor NeighborSearch to new TreeType API.
>---------------------------------------------------------------
aa28b5e4fd91e6b21a08309db851f1fc161041bb
src/mlpack/methods/neighbor_search/allkfn_main.cpp | 17 ++--
src/mlpack/methods/neighbor_search/allknn_main.cpp | 25 +++--
.../methods/neighbor_search/neighbor_search.hpp | 30 +++---
.../neighbor_search/neighbor_search_impl.hpp | 110 ++++++++++++---------
4 files changed, 98 insertions(+), 84 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/allkfn_main.cpp b/src/mlpack/methods/neighbor_search/allkfn_main.cpp
index 28d79bc..784faef 100644
--- a/src/mlpack/methods/neighbor_search/allkfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allkfn_main.cpp
@@ -18,6 +18,7 @@ using namespace std;
using namespace mlpack;
using namespace mlpack::neighbor;
using namespace mlpack::tree;
+using namespace mlpack::metric;
// Information about the program itself.
PROGRAM_INFO("All K-Furthest-Neighbors",
@@ -126,8 +127,8 @@ int main(int argc, char *argv[])
// Use default kd-tree.
std::vector<size_t> oldFromNewRefs;
- typedef BinarySpaceTree<bound::HRectBound<2>,
- NeighborSearchStat<FurthestNeighborSort>> TreeType;
+ typedef KDTree<EuclideanDistance, NeighborSearchStat<FurthestNeighborSort>,
+ arma::mat> TreeType;
// Build trees by hand, so we can save memory: if we pass a tree to
// NeighborSearch, it does not copy the matrix.
@@ -192,12 +193,8 @@ int main(int argc, char *argv[])
Log::Info << "Using R tree for furthest-neighbor calculation." << endl;
// Convenience typedef.
- typedef RectangleTree<
- tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic,
- NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
- tree::RStarTreeDescentHeuristic,
- NeighborSearchStat<FurthestNeighborSort>,
- arma::mat> TreeType;
+ typedef RStarTree<EuclideanDistance,
+ NeighborSearchStat<FurthestNeighborSort>, arma::mat> TreeType;
// Build trees by hand, so we can save memory: if we pass a tree to
// NeighborSearch, it does not copy the matrix.
@@ -207,8 +204,8 @@ int main(int argc, char *argv[])
Timer::Stop("tree_building");
Log::Info << "Tree built." << endl;
- typedef NeighborSearch<FurthestNeighborSort, metric::LMetric<2, true>,
- TreeType> AllkFNType;
+ typedef NeighborSearch<FurthestNeighborSort, EuclideanDistance, arma::mat,
+ RStarTree> AllkFNType;
AllkFNType allkfn(&refTree, singleMode);
if (CLI::GetParam<string>("query_file") != "")
diff --git a/src/mlpack/methods/neighbor_search/allknn_main.cpp b/src/mlpack/methods/neighbor_search/allknn_main.cpp
index 5a83b4d..c7c2e9c 100644
--- a/src/mlpack/methods/neighbor_search/allknn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allknn_main.cpp
@@ -19,6 +19,7 @@ using namespace std;
using namespace mlpack;
using namespace mlpack::neighbor;
using namespace mlpack::tree;
+using namespace mlpack::metric;
// Information about the program itself.
PROGRAM_INFO("All K-Nearest-Neighbors",
@@ -57,7 +58,7 @@ PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
"dual-tree search).", "S");
PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
"(experimental, may be slow).", "c");
-PARAM_FLAG("r_tree", "If true, use an R-Tree to perform the search "
+PARAM_FLAG("r_tree", "If true, use an R*-Tree to perform the search "
"(experimental, may be slow.).", "T");
PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
"random orthogonal basis.", "R");
@@ -189,8 +190,8 @@ int main(int argc, char *argv[])
std::vector<size_t> oldFromNewRefs;
// Convenience typedef.
- typedef BinarySpaceTree<bound::HRectBound<2>,
- NeighborSearchStat<NearestNeighborSort>> TreeType;
+ typedef KDTree<EuclideanDistance, NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> TreeType;
// Build trees by hand, so we can save memory: if we pass a tree to
// NeighborSearch, it does not copy the matrix.
@@ -255,12 +256,8 @@ int main(int argc, char *argv[])
Log::Info << "Using R tree for nearest-neighbor calculation." << endl;
// Convenience typedef.
- typedef RectangleTree<
- tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic,
- NeighborSearchStat<NearestNeighborSort>, arma::mat>,
- tree::RStarTreeDescentHeuristic,
- NeighborSearchStat<NearestNeighborSort>,
- arma::mat> TreeType;
+ typedef RStarTree<EuclideanDistance,
+ NeighborSearchStat<NearestNeighborSort>, arma::mat> TreeType;
// Build tree by hand in order to apply user options.
Log::Info << "Building reference tree..." << endl;
@@ -269,8 +266,8 @@ int main(int argc, char *argv[])
Timer::Stop("tree_building");
Log::Info << "Tree built." << endl;
- typedef NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
- TreeType> AllkNNType;
+ typedef NeighborSearch<NearestNeighborSort, EuclideanDistance, arma::mat,
+ RStarTree> AllkNNType;
AllkNNType allknn(&refTree, singleMode);
if (CLI::GetParam<string>("query_file") != "")
@@ -307,8 +304,8 @@ int main(int argc, char *argv[])
Log::Info << "Using cover trees for nearest-neighbor calculation." << endl;
// Convenience typedef.
- typedef CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
- NeighborSearchStat<NearestNeighborSort>> TreeType;
+ typedef StandardCoverTree<metric::EuclideanDistance,
+ NeighborSearchStat<NearestNeighborSort>, arma::mat> TreeType;
// Build our reference tree.
Log::Info << "Building reference tree..." << endl;
@@ -317,7 +314,7 @@ int main(int argc, char *argv[])
Timer::Stop("tree_building");
typedef NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
- TreeType> AllkNNType;
+ arma::mat, StandardCoverTree> AllkNNType;
AllkNNType allknn(&refTree, singleMode);
// See if we have query data.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index d7a1040..af3cf3d 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -45,14 +45,20 @@ namespace neighbor /** Neighbor-search routines. These include
* @tparam TreeType The tree type to use.
*/
template<typename SortPolicy = NearestNeighborSort,
- typename MetricType = mlpack::metric::SquaredEuclideanDistance,
- typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
- NeighborSearchStat<SortPolicy>>,
+ typename MetricType = mlpack::metric::EuclideanDistance,
+ typename MatType = arma::mat,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType = tree::KDTree,
template<typename RuleType> class TraversalType =
- TreeType::template DualTreeTraverser>
+ TreeType<MetricType,
+ NeighborSearchStat<SortPolicy>,
+ MatType>::template DualTreeTraverser>
class NeighborSearch
{
public:
+ //! Convenience typedef.
+ typedef TreeType<MetricType, NeighborSearchStat<SortPolicy>, MatType> Tree;
+
/**
* Initialize the NeighborSearch object, passing a reference dataset (this is
* the dataset which is searched). Optionally, perform the computation in
@@ -71,7 +77,7 @@ class NeighborSearch
* dual-tree search).
* @param metric An optional instance of the MetricType class.
*/
- NeighborSearch(const typename TreeType::Mat& referenceSet,
+ NeighborSearch(const MatType& referenceSet,
const bool naive = false,
const bool singleMode = false,
const MetricType metric = MetricType());
@@ -101,7 +107,7 @@ class NeighborSearch
* opposed to dual-tree computation).
* @param metric Instantiated distance metric.
*/
- NeighborSearch(TreeType* referenceTree,
+ NeighborSearch(Tree* referenceTree,
const bool singleMode = false,
const MetricType metric = MetricType());
@@ -128,7 +134,7 @@ class NeighborSearch
* @param distances Matrix storing distances of neighbors for each query
* point.
*/
- void Search(const typename TreeType::Mat& querySet,
+ void Search(const MatType& querySet,
const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances);
@@ -146,7 +152,7 @@ class NeighborSearch
* @param distances Matrix storing distances of neighbors for each query
* point.
*/
- void Search(TreeType* queryTree,
+ void Search(Tree* queryTree,
const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances);
@@ -197,9 +203,9 @@ class NeighborSearch
//! Permutations of reference points during tree building.
std::vector<size_t> oldFromNewReferences;
//! Pointer to the root of the reference tree.
- TreeType* referenceTree;
+ Tree* referenceTree;
//! Reference to reference dataset.
- const typename TreeType::Mat& referenceSet;
+ const MatType& referenceSet;
//! If true, this object created the trees and is responsible for them.
bool treeOwner;
@@ -219,8 +225,8 @@ class NeighborSearch
}; // class NeighborSearch
-}; // namespace neighbor
-}; // namespace mlpack
+} // namespace neighbor
+} // namespace mlpack
// Include implementation.
#include "neighbor_search_impl.hpp"
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 48a65fa..57216e1 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -16,9 +16,9 @@ namespace mlpack {
namespace neighbor {
//! Call the tree constructor that does mapping.
-template<typename TreeType>
+template<typename MatType, typename TreeType>
TreeType* BuildTree(
- const typename TreeType::Mat& dataset,
+ const MatType& dataset,
std::vector<size_t>& oldFromNew,
typename boost::enable_if_c<
tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
@@ -28,9 +28,9 @@ TreeType* BuildTree(
}
//! Call the tree constructor that does not do mapping.
-template<typename TreeType>
+template<typename MatType, typename TreeType>
TreeType* BuildTree(
- const typename TreeType::Mat& dataset,
+ const MatType& dataset,
const std::vector<size_t>& /* oldFromNew */,
const typename boost::enable_if_c<
tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
@@ -42,15 +42,17 @@ TreeType* BuildTree(
// Construct the object.
template<typename SortPolicy,
typename MetricType,
- typename TreeType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType,
template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
-NeighborSearch(const typename TreeType::Mat& referenceSetIn,
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+NeighborSearch(const MatType& referenceSetIn,
const bool naive,
const bool singleMode,
const MetricType metric) :
referenceTree(naive ? NULL :
- BuildTree<TreeType>(referenceSetIn, oldFromNewReferences)),
+ BuildTree<MatType, Tree>(referenceSetIn, oldFromNewReferences)),
referenceSet(naive ? referenceSetIn : referenceTree->Dataset()),
treeOwner(!naive), // False if a tree was passed. If naive, then no trees.
naive(naive),
@@ -65,10 +67,12 @@ NeighborSearch(const typename TreeType::Mat& referenceSetIn,
// Construct the object.
template<typename SortPolicy,
typename MetricType,
- typename TreeType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType,
template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
-NeighborSearch(TreeType* referenceTree,
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+NeighborSearch(Tree* referenceTree,
const bool singleMode,
const MetricType metric) :
referenceTree(referenceTree),
@@ -86,9 +90,11 @@ NeighborSearch(TreeType* referenceTree,
// Clean memory.
template<typename SortPolicy,
typename MetricType,
- typename TreeType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType,
template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
~NeighborSearch()
{
if (treeOwner && referenceTree)
@@ -101,13 +107,15 @@ NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
*/
template<typename SortPolicy,
typename MetricType,
- typename TreeType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType,
template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
- const typename TreeType::Mat& querySet,
- const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+Search(const MatType& querySet,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
{
Timer::Start("computing_neighbors");
@@ -122,7 +130,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
arma::mat* distancePtr = &distances;
// Mapping is only necessary if the tree rearranges points.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (tree::TreeTraits<Tree>::RearrangesDataset)
{
if (!singleMode && !naive)
distancePtr = new arma::mat; // Query indices need to be mapped.
@@ -137,7 +145,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
distancePtr->set_size(k, querySet.n_cols);
distancePtr->fill(SortPolicy::WorstDistance());
- typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
+ typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
if (naive)
{
@@ -157,7 +165,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
// Create the traverser.
- typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
// Now have it traverse for each point.
for (size_t i = 0; i < querySet.n_cols; ++i)
@@ -174,7 +182,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
// Build the query tree.
Timer::Stop("computing_neighbors");
Timer::Start("tree_building");
- TreeType* queryTree = BuildTree<TreeType>(querySet, oldFromNewQueries);
+ Tree* queryTree = BuildTree<MatType, Tree>(querySet, oldFromNewQueries);
Timer::Stop("tree_building");
Timer::Start("computing_neighbors");
@@ -199,7 +207,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
Timer::Stop("computing_neighbors");
// Map points back to original indices, if necessary.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (tree::TreeTraits<Tree>::RearrangesDataset)
{
if (!singleMode && !naive && treeOwner)
{
@@ -260,18 +268,20 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
template<typename SortPolicy,
typename MetricType,
- typename TreeType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType,
template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
- TreeType* queryTree,
- const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+Search(Tree* queryTree,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
{
Timer::Start("computing_neighbors");
// Get a reference to the query set.
- const typename TreeType::Mat& querySet = queryTree->Dataset();
+ const MatType& querySet = queryTree->Dataset();
// Make sure we are in dual-tree mode.
if (singleMode || naive)
@@ -281,7 +291,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
// We won't need to map query indices, but will we need to map distances?
arma::Mat<size_t>* neighborPtr = &neighbors;
- if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
neighborPtr = new arma::Mat<size_t>;
neighborPtr->set_size(k, querySet.n_cols);
@@ -290,7 +300,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
distances.fill(SortPolicy::WorstDistance());
// Create the helper object for the traversal.
- typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
+ typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
RuleType rules(referenceSet, querySet, *neighborPtr, distances, metric);
// Create the traverser.
@@ -303,7 +313,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
Timer::Stop("computing_neighbors");
// Do we need to map indices?
- if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
{
// We must map reference indices only.
neighbors.set_size(k, querySet.n_cols);
@@ -320,19 +330,21 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
template<typename SortPolicy,
typename MetricType,
- typename TreeType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType,
template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
- const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+Search(const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
{
Timer::Start("computing_neighbors");
arma::Mat<size_t>* neighborPtr = &neighbors;
arma::mat* distancePtr = &distances;
- if (tree::TreeTraits<TreeType>::RearrangesDataset && treeOwner)
+ if (tree::TreeTraits<Tree>::RearrangesDataset && treeOwner)
{
// We will always need to rearrange in this case.
distancePtr = new arma::mat;
@@ -346,7 +358,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
distancePtr->fill(SortPolicy::WorstDistance());
// Create the helper object for the traversal.
- typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
+ typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
RuleType rules(referenceSet, referenceSet, *neighborPtr, *distancePtr,
metric, true /* don't return the same point as nearest neighbor */);
@@ -362,7 +374,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
else if (singleMode)
{
// Create the traverser.
- typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
// Now have it traverse for each point.
for (size_t i = 0; i < referenceSet.n_cols; ++i)
@@ -391,7 +403,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
Timer::Stop("computing_neighbors");
// Do we need to map the reference indices?
- if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
{
neighbors.set_size(k, referenceSet.n_cols);
distances.set_size(k, referenceSet.n_cols);
@@ -416,10 +428,12 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
// Return a String of the Object.
template<typename SortPolicy,
typename MetricType,
- typename TreeType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType,
template<typename> class TraversalType>
-std::string NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
- ToString() const
+std::string NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+ TraversalType>::ToString() const
{
std::ostringstream convert;
convert << "NeighborSearch [" << this << "]" << std::endl;
@@ -434,7 +448,7 @@ std::string NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
return convert.str();
}
-}; // namespace neighbor
-}; // namespace mlpack
+} // namespace neighbor
+} // namespace mlpack
#endif
More information about the mlpack-git mailing list