[mlpack-git] master: Refactor for new TreeType API. (5ee25ba)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:42:36 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b

>---------------------------------------------------------------

commit 5ee25baa085be32d7244f52ad7f742906112158e
Author: ryan <ryan at ratml.org>
Date:   Sun Jul 26 23:07:37 2015 -0400

    Refactor for new TreeType API.


>---------------------------------------------------------------

5ee25baa085be32d7244f52ad7f742906112158e
 src/mlpack/methods/range_search/range_search.hpp   |  32 +++---
 .../methods/range_search/range_search_impl.hpp     | 109 ++++++++++++---------
 .../methods/range_search/range_search_main.cpp     |   9 +-
 3 files changed, 89 insertions(+), 61 deletions(-)

diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index b168cfe..04d3071 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -21,13 +21,21 @@ namespace range /** Range-search routines. */ {
  * is implemented in the style of a generalized tree-independent dual-tree
  * algorithm; for more details on the actual algorithm, see the RangeSearchRules
  * class.
+ *
+ * @tparam MetricType Metric to use for range search calculations.
+ * @tparam MatType Type of data to use.
+ * @tparam TreeType Type of tree to use; must satisfy the TreeType policy API.
  */
-template<typename MetricType = mlpack::metric::EuclideanDistance,
-         typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
-                                                   RangeSearchStat> >
+template<typename MetricType = metric::EuclideanDistance,
+         typename MatType = arma::mat,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType = tree::KDTree>
 class RangeSearch
 {
  public:
+  //! Convenience typedef.
+  typedef TreeType<MetricType, RangeSearchStat, MatType> Tree;
+
   /**
    * Initialize the RangeSearch object with a given reference dataset (this is
    * the dataset which is searched).  Optionally, perform the computation in
@@ -44,7 +52,7 @@ class RangeSearch
    *      opposed to dual-tree computation).
    * @param metric Instantiated distance metric.
    */
-  RangeSearch(const typename TreeType::Mat& referenceSet,
+  RangeSearch(const MatType& referenceSet,
               const bool naive = false,
               const bool singleMode = false,
               const MetricType metric = MetricType());
@@ -73,7 +81,7 @@ class RangeSearch
    *      opposed to dual-tree computation).
    * @param metric Instantiated distance metric.
    */
-  RangeSearch(TreeType* referenceTree,
+  RangeSearch(Tree* referenceTree,
               const bool singleMode = false,
               const MetricType metric = MetricType());
 
@@ -110,7 +118,7 @@ class RangeSearch
    * @param distances Object which will hold the list of distances for each
    *      point which fell into the given range, for each query point.
    */
-  void Search(const typename TreeType::Mat& querySet,
+  void Search(const MatType& querySet,
               const math::Range& range,
               std::vector<std::vector<size_t>>& neighbors,
               std::vector<std::vector<double>>& distances);
@@ -151,7 +159,7 @@ class RangeSearch
    * @param distances Object which will hold the list of distances for each
    *      point which fell into the given range, for each query point.
    */
-  void Search(TreeType* queryTree,
+  void Search(Tree* queryTree,
               const math::Range& range,
               std::vector<std::vector<size_t>>& neighbors,
               std::vector<std::vector<double>>& distances);
@@ -195,11 +203,11 @@ class RangeSearch
 
  private:
   //! Copy of reference matrix; used when a tree is built internally.
-  typename TreeType::Mat referenceCopy;
+  MatType referenceCopy;
   //! Reference set (data should be accessed using this).
-  const typename TreeType::Mat& referenceSet;
+  const MatType& referenceSet;
   //! Reference tree.
-  TreeType* referenceTree;
+  Tree* referenceTree;
   //! Mappings to old reference indices (used when this object builds trees).
   std::vector<size_t> oldFromNewReferences;
 
@@ -215,8 +223,8 @@ class RangeSearch
   MetricType metric;
 };
 
-}; // namespace range
-}; // namespace mlpack
+} // namespace range
+} // namespace mlpack
 
 // Include implementation.
 #include "range_search_impl.hpp"
diff --git a/src/mlpack/methods/range_search/range_search_impl.hpp b/src/mlpack/methods/range_search/range_search_impl.hpp
index 97e7e20..a7ad027 100644
--- a/src/mlpack/methods/range_search/range_search_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_impl.hpp
@@ -39,13 +39,16 @@ TreeType* BuildTree(
   return new TreeType(dataset);
 }
 
-template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(
-    const typename TreeType::Mat& referenceSetIn,
+template<typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
+    const MatType& referenceSetIn,
     const bool naive,
     const bool singleMode,
     const MetricType metric) :
-    referenceSet((tree::TreeTraits<TreeType>::RearrangesDataset && !naive)
+    referenceSet((tree::TreeTraits<Tree>::RearrangesDataset && !naive)
         ? referenceCopy : referenceSetIn),
     referenceTree(NULL),
     treeOwner(!naive), // If in naive mode, we are not building any trees.
@@ -60,23 +63,25 @@ RangeSearch<MetricType, TreeType>::RangeSearch(
   if (!naive)
   {
     // Copy the dataset, if it will be modified during tree building.
-    if (tree::TreeTraits<TreeType>::RearrangesDataset)
+    if (tree::TreeTraits<Tree>::RearrangesDataset)
       referenceCopy = referenceSetIn;
 
     // The const_cast is safe; if RearrangesDataset == false, then it'll be
     // casted back to const anyway, and if not, referenceSet points to
     // referenceCopy, which isn't const.
-    referenceTree = BuildTree<TreeType>(
-        const_cast<typename TreeType::Mat&>(referenceSet),
+    referenceTree = BuildTree<Tree>(const_cast<MatType&>(referenceSet),
         oldFromNewReferences);
   }
 
   Timer::Stop("range_search/tree_building");
 }
 
-template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(
-    TreeType* referenceTree,
+template<typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
+    Tree* referenceTree,
     const bool singleMode,
     const MetricType metric) :
     referenceSet(referenceTree->Dataset()),
@@ -89,16 +94,22 @@ RangeSearch<MetricType, TreeType>::RangeSearch(
   // Nothing else to initialize.
 }
 
-template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::~RangeSearch()
+template<typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+RangeSearch<MetricType, MatType, TreeType>::~RangeSearch()
 {
   if (treeOwner && referenceTree)
     delete referenceTree;
 }
 
-template<typename MetricType, typename TreeType>
-void RangeSearch<MetricType, TreeType>::Search(
-    const typename TreeType::Mat& querySet,
+template<typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+void RangeSearch<MetricType, MatType, TreeType>::Search(
+    const MatType& querySet,
     const math::Range& range,
     std::vector<std::vector<size_t>>& neighbors,
     std::vector<std::vector<double>>& distances)
@@ -110,14 +121,13 @@ void RangeSearch<MetricType, TreeType>::Search(
 
   // If we will be building a tree and it will modify the query set, make a copy
   // of the dataset.
-  typename TreeType::Mat queryCopy;
+  MatType queryCopy;
   const bool needsCopy = (!naive && !singleMode &&
-      tree::TreeTraits<TreeType>::RearrangesDataset);
+      tree::TreeTraits<Tree>::RearrangesDataset);
   if (needsCopy)
     queryCopy = querySet;
 
-  const typename TreeType::Mat& querySetRef = (needsCopy) ? queryCopy :
-      querySet;
+  const MatType& querySetRef = (needsCopy) ? queryCopy : querySet;
 
   // If we have built the trees ourselves, then we will have to map all the
   // indices back to their original indices when this computation is finished.
@@ -127,7 +137,7 @@ void RangeSearch<MetricType, TreeType>::Search(
   std::vector<std::vector<double>>* distancePtr = &distances;
 
   // Mapping is only necessary if the tree rearranges points.
-  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (tree::TreeTraits<Tree>::RearrangesDataset)
   {
     // Query indices only need to be mapped if we are building the query tree
     // ourselves.
@@ -147,7 +157,7 @@ void RangeSearch<MetricType, TreeType>::Search(
   distancePtr->resize(querySet.n_cols);
 
   // Create the helper object for the traversal.
-  typedef RangeSearchRules<MetricType, TreeType> RuleType;
+  typedef RangeSearchRules<MetricType, Tree> RuleType;
   RuleType rules(referenceSet, querySetRef, range, *neighborPtr, *distancePtr,
       metric);
 
@@ -161,7 +171,7 @@ void RangeSearch<MetricType, TreeType>::Search(
   else if (singleMode)
   {
     // Create the traverser.
-    typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+    typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
     // Now have it traverse for each point.
     for (size_t i = 0; i < querySet.n_cols; ++i)
@@ -172,13 +182,13 @@ void RangeSearch<MetricType, TreeType>::Search(
     // Build the query tree.
     Timer::Stop("range_search/computing_neighbors");
     Timer::Start("range_search/tree_building");
-    TreeType* queryTree = BuildTree<TreeType>(
-        const_cast<typename TreeType::Mat&>(querySetRef), oldFromNewQueries);
+    Tree* queryTree = BuildTree<Tree>(const_cast<MatType&>(querySetRef),
+        oldFromNewQueries);
     Timer::Stop("range_search/tree_building");
     Timer::Start("range_search/computing_neighbors");
 
     // Create the traverser.
-    typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+    typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
 
     traverser.Traverse(*queryTree, *referenceTree);
 
@@ -189,7 +199,7 @@ void RangeSearch<MetricType, TreeType>::Search(
   Timer::Stop("range_search/computing_neighbors");
 
   // Map points back to original indices, if necessary.
-  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (tree::TreeTraits<Tree>::RearrangesDataset)
   {
     if (!singleMode && !naive && treeOwner)
     {
@@ -255,9 +265,12 @@ void RangeSearch<MetricType, TreeType>::Search(
   }
 }
 
-template<typename MetricType, typename TreeType>
-void RangeSearch<MetricType, TreeType>::Search(
-    TreeType* queryTree,
+template<typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+void RangeSearch<MetricType, MatType, TreeType>::Search(
+    Tree* queryTree,
     const math::Range& range,
     std::vector<std::vector<size_t>>& neighbors,
     std::vector<std::vector<double>>& distances)
@@ -265,7 +278,7 @@ void RangeSearch<MetricType, TreeType>::Search(
   Timer::Start("range_search/computing_neighbors");
 
   // Get a reference to the query set.
-  const typename TreeType::Mat& querySet = queryTree->Dataset();
+  const MatType& querySet = queryTree->Dataset();
 
   // Make sure we are in dual-tree mode.
   if (singleMode || naive)
@@ -275,7 +288,7 @@ void RangeSearch<MetricType, TreeType>::Search(
   // We won't need to map query indices, but will we need to map distances?
   std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
 
-  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
     neighborPtr = new std::vector<std::vector<size_t>>;
 
   // Resize each vector.
@@ -285,19 +298,19 @@ void RangeSearch<MetricType, TreeType>::Search(
   distances.resize(querySet.n_cols);
 
   // Create the helper object for the traversal.
-  typedef RangeSearchRules<MetricType, TreeType> RuleType;
+  typedef RangeSearchRules<MetricType, Tree> RuleType;
   RuleType rules(referenceSet, queryTree->Dataset(), range, *neighborPtr,
       distances, metric);
 
   // Create the traverser.
-  typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+  typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
 
   traverser.Traverse(*queryTree, *referenceTree);
 
   Timer::Stop("range_search/computing_neighbors");
 
   // Do we need to map indices?
-  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
   {
     // We must map reference indices only.
     neighbors.clear();
@@ -315,8 +328,11 @@ void RangeSearch<MetricType, TreeType>::Search(
   }
 }
 
-template<typename MetricType, typename TreeType>
-void RangeSearch<MetricType, TreeType>::Search(
+template<typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+void RangeSearch<MetricType, MatType, TreeType>::Search(
     const math::Range& range,
     std::vector<std::vector<size_t>>& neighbors,
     std::vector<std::vector<double>>& distances)
@@ -327,7 +343,7 @@ void RangeSearch<MetricType, TreeType>::Search(
   std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
   std::vector<std::vector<double>>* distancePtr = &distances;
 
-  if (tree::TreeTraits<TreeType>::RearrangesDataset && treeOwner)
+  if (tree::TreeTraits<Tree>::RearrangesDataset && treeOwner)
   {
     // We will always need to rearrange in this case.
     distancePtr = new std::vector<std::vector<double>>;
@@ -341,7 +357,7 @@ void RangeSearch<MetricType, TreeType>::Search(
   distancePtr->resize(referenceSet.n_cols);
 
   // Create the helper object for the traversal.
-  typedef RangeSearchRules<MetricType, TreeType> RuleType;
+  typedef RangeSearchRules<MetricType, Tree> RuleType;
   RuleType rules(referenceSet, referenceSet, range, *neighborPtr, *distancePtr,
       metric, true /* don't return the query point in the results */);
 
@@ -355,7 +371,7 @@ void RangeSearch<MetricType, TreeType>::Search(
   else if (singleMode)
   {
     // Create the traverser.
-    typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+    typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
     // Now have it traverse for each point.
     for (size_t i = 0; i < referenceSet.n_cols; ++i)
@@ -364,7 +380,7 @@ void RangeSearch<MetricType, TreeType>::Search(
   else // Dual-tree recursion.
   {
     // Create the traverser.
-    typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+    typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
 
     traverser.Traverse(*referenceTree, *referenceTree);
   }
@@ -372,7 +388,7 @@ void RangeSearch<MetricType, TreeType>::Search(
   Timer::Stop("range_search/computing_neighbors");
 
   // Do we need to map the reference indices?
-  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
   {
     neighbors.clear();
     neighbors.resize(referenceSet.n_cols);
@@ -399,8 +415,11 @@ void RangeSearch<MetricType, TreeType>::Search(
   }
 }
 
-template<typename MetricType, typename TreeType>
-std::string RangeSearch<MetricType, TreeType>::ToString() const
+template<typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+std::string RangeSearch<MetricType, MatType, TreeType>::ToString() const
 {
   std::ostringstream convert;
   convert << "Range Search  [" << this << "]" << std::endl;
@@ -413,7 +432,7 @@ std::string RangeSearch<MetricType, TreeType>::ToString() const
   return convert.str();
 }
 
-}; // namespace range
-}; // namespace mlpack
+} // namespace range
+} // namespace mlpack
 
 #endif
diff --git a/src/mlpack/methods/range_search/range_search_main.cpp b/src/mlpack/methods/range_search/range_search_main.cpp
index 31eba17..783132c 100644
--- a/src/mlpack/methods/range_search/range_search_main.cpp
+++ b/src/mlpack/methods/range_search/range_search_main.cpp
@@ -16,6 +16,7 @@ using namespace std;
 using namespace mlpack;
 using namespace mlpack::range;
 using namespace mlpack::tree;
+using namespace mlpack::metric;
 
 // Information about the program itself.
 PROGRAM_INFO("Range Search",
@@ -66,9 +67,9 @@ PARAM_FLAG("cover_tree", "If true, use a cover tree for range searching "
     "(instead of a kd-tree).", "c");
 
 typedef RangeSearch<> RSType;
-typedef CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
-    RangeSearchStat> CoverTreeType;
-typedef RangeSearch<metric::EuclideanDistance, CoverTreeType> RSCoverType;
+typedef CoverTree<EuclideanDistance, RangeSearchStat> CoverTreeType;
+typedef RangeSearch<EuclideanDistance, arma::mat, StandardCoverTree>
+    RSCoverType;
 
 int main(int argc, char *argv[])
 {
@@ -162,7 +163,7 @@ int main(int argc, char *argv[])
   }
   else
   {
-    typedef BinarySpaceTree<bound::HRectBound<2>, RangeSearchStat> TreeType;
+    typedef KDTree<EuclideanDistance, RangeSearchStat, arma::mat> TreeType;
 
     // Track mappings.
     Log::Info << "Building reference tree..." << endl;



More information about the mlpack-git mailing list