[mlpack-git] master: Refactor NeighborSearch to new TreeType API. (aa28b5e)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:42:05 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b

>---------------------------------------------------------------

commit aa28b5e4fd91e6b21a08309db851f1fc161041bb
Author: ryan <ryan at ratml.org>
Date:   Fri Jul 24 14:26:57 2015 -0400

    Refactor NeighborSearch to new TreeType API.


>---------------------------------------------------------------

aa28b5e4fd91e6b21a08309db851f1fc161041bb
 src/mlpack/methods/neighbor_search/allkfn_main.cpp |  17 ++--
 src/mlpack/methods/neighbor_search/allknn_main.cpp |  25 +++--
 .../methods/neighbor_search/neighbor_search.hpp    |  30 +++---
 .../neighbor_search/neighbor_search_impl.hpp       | 110 ++++++++++++---------
 4 files changed, 98 insertions(+), 84 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/allkfn_main.cpp b/src/mlpack/methods/neighbor_search/allkfn_main.cpp
index 28d79bc..784faef 100644
--- a/src/mlpack/methods/neighbor_search/allkfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allkfn_main.cpp
@@ -18,6 +18,7 @@ using namespace std;
 using namespace mlpack;
 using namespace mlpack::neighbor;
 using namespace mlpack::tree;
+using namespace mlpack::metric;
 
 // Information about the program itself.
 PROGRAM_INFO("All K-Furthest-Neighbors",
@@ -126,8 +127,8 @@ int main(int argc, char *argv[])
     // Use default kd-tree.
     std::vector<size_t> oldFromNewRefs;
 
-    typedef BinarySpaceTree<bound::HRectBound<2>,
-        NeighborSearchStat<FurthestNeighborSort>> TreeType;
+    typedef KDTree<EuclideanDistance, NeighborSearchStat<FurthestNeighborSort>,
+        arma::mat> TreeType;
 
     // Build trees by hand, so we can save memory: if we pass a tree to
     // NeighborSearch, it does not copy the matrix.
@@ -192,12 +193,8 @@ int main(int argc, char *argv[])
     Log::Info << "Using R tree for furthest-neighbor calculation." << endl;
 
     // Convenience typedef.
-    typedef RectangleTree<
-        tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic,
-            NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
-        tree::RStarTreeDescentHeuristic,
-        NeighborSearchStat<FurthestNeighborSort>,
-        arma::mat> TreeType;
+    typedef RStarTree<EuclideanDistance,
+        NeighborSearchStat<FurthestNeighborSort>, arma::mat> TreeType;
 
     // Build trees by hand, so we can save memory: if we pass a tree to
     // NeighborSearch, it does not copy the matrix.
@@ -207,8 +204,8 @@ int main(int argc, char *argv[])
     Timer::Stop("tree_building");
     Log::Info << "Tree built." << endl;
 
-    typedef NeighborSearch<FurthestNeighborSort, metric::LMetric<2, true>,
-        TreeType> AllkFNType;
+    typedef NeighborSearch<FurthestNeighborSort, EuclideanDistance, arma::mat,
+        RStarTree> AllkFNType;
     AllkFNType allkfn(&refTree, singleMode);
 
     if (CLI::GetParam<string>("query_file") != "")
diff --git a/src/mlpack/methods/neighbor_search/allknn_main.cpp b/src/mlpack/methods/neighbor_search/allknn_main.cpp
index 5a83b4d..c7c2e9c 100644
--- a/src/mlpack/methods/neighbor_search/allknn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allknn_main.cpp
@@ -19,6 +19,7 @@ using namespace std;
 using namespace mlpack;
 using namespace mlpack::neighbor;
 using namespace mlpack::tree;
+using namespace mlpack::metric;
 
 // Information about the program itself.
 PROGRAM_INFO("All K-Nearest-Neighbors",
@@ -57,7 +58,7 @@ PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
     "dual-tree search).", "S");
 PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
     "(experimental, may be slow).", "c");
-PARAM_FLAG("r_tree", "If true, use an R-Tree to perform the search "
+PARAM_FLAG("r_tree", "If true, use an R*-Tree to perform the search "
     "(experimental, may be slow.).", "T");
 PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
     "random orthogonal basis.", "R");
@@ -189,8 +190,8 @@ int main(int argc, char *argv[])
       std::vector<size_t> oldFromNewRefs;
 
       // Convenience typedef.
-      typedef BinarySpaceTree<bound::HRectBound<2>,
-          NeighborSearchStat<NearestNeighborSort>> TreeType;
+      typedef KDTree<EuclideanDistance, NeighborSearchStat<NearestNeighborSort>,
+          arma::mat> TreeType;
 
       // Build trees by hand, so we can save memory: if we pass a tree to
       // NeighborSearch, it does not copy the matrix.
@@ -255,12 +256,8 @@ int main(int argc, char *argv[])
       Log::Info << "Using R tree for nearest-neighbor calculation." << endl;
 
       // Convenience typedef.
-      typedef RectangleTree<
-          tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic,
-              NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-          tree::RStarTreeDescentHeuristic,
-          NeighborSearchStat<NearestNeighborSort>,
-          arma::mat> TreeType;
+      typedef RStarTree<EuclideanDistance,
+          NeighborSearchStat<NearestNeighborSort>, arma::mat> TreeType;
 
       // Build tree by hand in order to apply user options.
       Log::Info << "Building reference tree..." << endl;
@@ -269,8 +266,8 @@ int main(int argc, char *argv[])
       Timer::Stop("tree_building");
       Log::Info << "Tree built." << endl;
 
-      typedef NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
-          TreeType> AllkNNType;
+      typedef NeighborSearch<NearestNeighborSort, EuclideanDistance, arma::mat,
+          RStarTree> AllkNNType;
       AllkNNType allknn(&refTree, singleMode);
 
       if (CLI::GetParam<string>("query_file") != "")
@@ -307,8 +304,8 @@ int main(int argc, char *argv[])
     Log::Info << "Using cover trees for nearest-neighbor calculation." << endl;
 
     // Convenience typedef.
-    typedef CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
-        NeighborSearchStat<NearestNeighborSort>> TreeType;
+    typedef StandardCoverTree<metric::EuclideanDistance,
+        NeighborSearchStat<NearestNeighborSort>, arma::mat> TreeType;
 
     // Build our reference tree.
     Log::Info << "Building reference tree..." << endl;
@@ -317,7 +314,7 @@ int main(int argc, char *argv[])
     Timer::Stop("tree_building");
 
     typedef NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
-        TreeType> AllkNNType;
+        arma::mat, StandardCoverTree> AllkNNType;
     AllkNNType allknn(&refTree, singleMode);
 
     // See if we have query data.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index d7a1040..af3cf3d 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -45,14 +45,20 @@ namespace neighbor /** Neighbor-search routines.  These include
  * @tparam TreeType The tree type to use.
  */
 template<typename SortPolicy = NearestNeighborSort,
-         typename MetricType = mlpack::metric::SquaredEuclideanDistance,
-         typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
-             NeighborSearchStat<SortPolicy>>,
+         typename MetricType = mlpack::metric::EuclideanDistance,
+         typename MatType = arma::mat,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType = tree::KDTree,
          template<typename RuleType> class TraversalType =
-             TreeType::template DualTreeTraverser>
+             TreeType<MetricType,
+                      NeighborSearchStat<SortPolicy>,
+                      MatType>::template DualTreeTraverser>
 class NeighborSearch
 {
  public:
+  //! Convenience typedef.
+  typedef TreeType<MetricType, NeighborSearchStat<SortPolicy>, MatType> Tree;
+
   /**
    * Initialize the NeighborSearch object, passing a reference dataset (this is
    * the dataset which is searched).  Optionally, perform the computation in
@@ -71,7 +77,7 @@ class NeighborSearch
    *      dual-tree search).
    * @param metric An optional instance of the MetricType class.
    */
-  NeighborSearch(const typename TreeType::Mat& referenceSet,
+  NeighborSearch(const MatType& referenceSet,
                  const bool naive = false,
                  const bool singleMode = false,
                  const MetricType metric = MetricType());
@@ -101,7 +107,7 @@ class NeighborSearch
    *      opposed to dual-tree computation).
    * @param metric Instantiated distance metric.
    */
-  NeighborSearch(TreeType* referenceTree,
+  NeighborSearch(Tree* referenceTree,
                  const bool singleMode = false,
                  const MetricType metric = MetricType());
 
@@ -128,7 +134,7 @@ class NeighborSearch
    * @param distances Matrix storing distances of neighbors for each query
    *     point.
    */
-  void Search(const typename TreeType::Mat& querySet,
+  void Search(const MatType& querySet,
               const size_t k,
               arma::Mat<size_t>& neighbors,
               arma::mat& distances);
@@ -146,7 +152,7 @@ class NeighborSearch
    * @param distances Matrix storing distances of neighbors for each query
    *      point.
    */
-  void Search(TreeType* queryTree,
+  void Search(Tree* queryTree,
               const size_t k,
               arma::Mat<size_t>& neighbors,
               arma::mat& distances);
@@ -197,9 +203,9 @@ class NeighborSearch
   //! Permutations of reference points during tree building.
   std::vector<size_t> oldFromNewReferences;
   //! Pointer to the root of the reference tree.
-  TreeType* referenceTree;
+  Tree* referenceTree;
   //! Reference to reference dataset.
-  const typename TreeType::Mat& referenceSet;
+  const MatType& referenceSet;
 
   //! If true, this object created the trees and is responsible for them.
   bool treeOwner;
@@ -219,8 +225,8 @@ class NeighborSearch
 
 }; // class NeighborSearch
 
-}; // namespace neighbor
-}; // namespace mlpack
+} // namespace neighbor
+} // namespace mlpack
 
 // Include implementation.
 #include "neighbor_search_impl.hpp"
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 48a65fa..57216e1 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -16,9 +16,9 @@ namespace mlpack {
 namespace neighbor {
 
 //! Call the tree constructor that does mapping.
-template<typename TreeType>
+template<typename MatType, typename TreeType>
 TreeType* BuildTree(
-    const typename TreeType::Mat& dataset,
+    const MatType& dataset,
     std::vector<size_t>& oldFromNew,
     typename boost::enable_if_c<
         tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
@@ -28,9 +28,9 @@ TreeType* BuildTree(
 }
 
 //! Call the tree constructor that does not do mapping.
-template<typename TreeType>
+template<typename MatType, typename TreeType>
 TreeType* BuildTree(
-    const typename TreeType::Mat& dataset,
+    const MatType& dataset,
     const std::vector<size_t>& /* oldFromNew */,
     const typename boost::enable_if_c<
         tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
@@ -42,15 +42,17 @@ TreeType* BuildTree(
 // Construct the object.
 template<typename SortPolicy,
          typename MetricType,
-         typename TreeType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType,
          template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
-NeighborSearch(const typename TreeType::Mat& referenceSetIn,
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+NeighborSearch(const MatType& referenceSetIn,
                const bool naive,
                const bool singleMode,
                const MetricType metric) :
     referenceTree(naive ? NULL :
-        BuildTree<TreeType>(referenceSetIn, oldFromNewReferences)),
+        BuildTree<MatType, Tree>(referenceSetIn, oldFromNewReferences)),
     referenceSet(naive ? referenceSetIn : referenceTree->Dataset()),
     treeOwner(!naive), // False if a tree was passed.  If naive, then no trees.
     naive(naive),
@@ -65,10 +67,12 @@ NeighborSearch(const typename TreeType::Mat& referenceSetIn,
 // Construct the object.
 template<typename SortPolicy,
          typename MetricType,
-         typename TreeType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType,
          template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
-NeighborSearch(TreeType* referenceTree,
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+NeighborSearch(Tree* referenceTree,
                const bool singleMode,
                const MetricType metric) :
     referenceTree(referenceTree),
@@ -86,9 +90,11 @@ NeighborSearch(TreeType* referenceTree,
 // Clean memory.
 template<typename SortPolicy,
          typename MetricType,
-         typename TreeType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType,
          template<typename> class TraversalType>
-NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
     ~NeighborSearch()
 {
   if (treeOwner && referenceTree)
@@ -101,13 +107,15 @@ NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
  */
 template<typename SortPolicy,
          typename MetricType,
-         typename TreeType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType,
          template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
-    const typename TreeType::Mat& querySet,
-    const size_t k,
-    arma::Mat<size_t>& neighbors,
-    arma::mat& distances)
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+Search(const MatType& querySet,
+       const size_t k,
+       arma::Mat<size_t>& neighbors,
+       arma::mat& distances)
 {
   Timer::Start("computing_neighbors");
 
@@ -122,7 +130,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
   arma::mat* distancePtr = &distances;
 
   // Mapping is only necessary if the tree rearranges points.
-  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (tree::TreeTraits<Tree>::RearrangesDataset)
   {
     if (!singleMode && !naive)
       distancePtr = new arma::mat; // Query indices need to be mapped.
@@ -137,7 +145,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
   distancePtr->set_size(k, querySet.n_cols);
   distancePtr->fill(SortPolicy::WorstDistance());
 
-  typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
+  typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
 
   if (naive)
   {
@@ -157,7 +165,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
     RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
 
     // Create the traverser.
-    typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+    typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
     // Now have it traverse for each point.
     for (size_t i = 0; i < querySet.n_cols; ++i)
@@ -174,7 +182,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
     // Build the query tree.
     Timer::Stop("computing_neighbors");
     Timer::Start("tree_building");
-    TreeType* queryTree = BuildTree<TreeType>(querySet, oldFromNewQueries);
+    Tree* queryTree = BuildTree<MatType, Tree>(querySet, oldFromNewQueries);
     Timer::Stop("tree_building");
     Timer::Start("computing_neighbors");
 
@@ -199,7 +207,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
   Timer::Stop("computing_neighbors");
 
   // Map points back to original indices, if necessary.
-  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (tree::TreeTraits<Tree>::RearrangesDataset)
   {
     if (!singleMode && !naive && treeOwner)
     {
@@ -260,18 +268,20 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
 
 template<typename SortPolicy,
          typename MetricType,
-         typename TreeType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType,
          template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
-    TreeType* queryTree,
-    const size_t k,
-    arma::Mat<size_t>& neighbors,
-    arma::mat& distances)
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+Search(Tree* queryTree,
+       const size_t k,
+       arma::Mat<size_t>& neighbors,
+       arma::mat& distances)
 {
   Timer::Start("computing_neighbors");
 
   // Get a reference to the query set.
-  const typename TreeType::Mat& querySet = queryTree->Dataset();
+  const MatType& querySet = queryTree->Dataset();
 
   // Make sure we are in dual-tree mode.
   if (singleMode || naive)
@@ -281,7 +291,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
   // We won't need to map query indices, but will we need to map distances?
   arma::Mat<size_t>* neighborPtr = &neighbors;
 
-  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
     neighborPtr = new arma::Mat<size_t>;
 
   neighborPtr->set_size(k, querySet.n_cols);
@@ -290,7 +300,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
   distances.fill(SortPolicy::WorstDistance());
 
   // Create the helper object for the traversal.
-  typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
+  typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
   RuleType rules(referenceSet, querySet, *neighborPtr, distances, metric);
 
   // Create the traverser.
@@ -303,7 +313,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
   Timer::Stop("computing_neighbors");
 
   // Do we need to map indices?
-  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
   {
     // We must map reference indices only.
     neighbors.set_size(k, querySet.n_cols);
@@ -320,19 +330,21 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
 
 template<typename SortPolicy,
          typename MetricType,
-         typename TreeType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType,
          template<typename> class TraversalType>
-void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
-    const size_t k,
-    arma::Mat<size_t>& neighbors,
-    arma::mat& distances)
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+Search(const size_t k,
+       arma::Mat<size_t>& neighbors,
+       arma::mat& distances)
 {
   Timer::Start("computing_neighbors");
 
   arma::Mat<size_t>* neighborPtr = &neighbors;
   arma::mat* distancePtr = &distances;
 
-  if (tree::TreeTraits<TreeType>::RearrangesDataset && treeOwner)
+  if (tree::TreeTraits<Tree>::RearrangesDataset && treeOwner)
   {
     // We will always need to rearrange in this case.
     distancePtr = new arma::mat;
@@ -346,7 +358,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
   distancePtr->fill(SortPolicy::WorstDistance());
 
   // Create the helper object for the traversal.
-  typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
+  typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
   RuleType rules(referenceSet, referenceSet, *neighborPtr, *distancePtr,
       metric, true /* don't return the same point as nearest neighbor */);
 
@@ -362,7 +374,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
   else if (singleMode)
   {
     // Create the traverser.
-    typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+    typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
     // Now have it traverse for each point.
     for (size_t i = 0; i < referenceSet.n_cols; ++i)
@@ -391,7 +403,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
   Timer::Stop("computing_neighbors");
 
   // Do we need to map the reference indices?
-  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
   {
     neighbors.set_size(k, referenceSet.n_cols);
     distances.set_size(k, referenceSet.n_cols);
@@ -416,10 +428,12 @@ void NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::Search(
 // Return a String of the Object.
 template<typename SortPolicy,
          typename MetricType,
-         typename TreeType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType,
          template<typename> class TraversalType>
-std::string NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
-    ToString() const
+std::string NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+                           TraversalType>::ToString() const
 {
   std::ostringstream convert;
   convert << "NeighborSearch [" << this << "]" << std::endl;
@@ -434,7 +448,7 @@ std::string NeighborSearch<SortPolicy, MetricType, TreeType, TraversalType>::
   return convert.str();
 }
 
-}; // namespace neighbor
-}; // namespace mlpack
+} // namespace neighbor
+} // namespace mlpack
 
 #endif



More information about the mlpack-git mailing list