[mlpack-git] master: Refactor DTB to use new TreeType API. (b37e976)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:41:35 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b

>---------------------------------------------------------------

commit b37e976adbb17d94a6f7fb5b76696faaa23b0697
Author: ryan <ryan at ratml.org>
Date:   Wed Jul 22 23:21:44 2015 -0400

    Refactor DTB to use new TreeType API.


>---------------------------------------------------------------

b37e976adbb17d94a6f7fb5b76696faaa23b0697
 src/mlpack/methods/emst/dtb.hpp       |  28 +++++---
 src/mlpack/methods/emst/dtb_impl.hpp  | 127 +++++++++++++++++++++++-----------
 src/mlpack/methods/emst/emst_main.cpp |   5 +-
 3 files changed, 105 insertions(+), 55 deletions(-)

diff --git a/src/mlpack/methods/emst/dtb.hpp b/src/mlpack/methods/emst/dtb.hpp
index 667da69..f706b13 100644
--- a/src/mlpack/methods/emst/dtb.hpp
+++ b/src/mlpack/methods/emst/dtb.hpp
@@ -70,19 +70,25 @@ namespace emst /** Euclidean Minimum Spanning Trees. */ {
  * @tparam TreeType Type of tree to use.  Should use DTBStat as a statistic.
  */
 template<
-  typename MetricType = metric::EuclideanDistance,
-  typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat>
+    typename MetricType = metric::EuclideanDistance,
+    typename MatType = arma::mat,
+    template<typename MetricType, typename StatisticType, typename MatType>
+        class TreeType = tree::KDTree
 >
 class DualTreeBoruvka
 {
+ public:
+  //! Convenience typedef.
+  typedef TreeType<MetricType, DTBStat, MatType> Tree;
+
  private:
   //! Copy of the data (if necessary).
-  typename TreeType::Mat dataCopy;
+  MatType dataCopy;
   //! Reference to the data (this is what should be used for accessing data).
-  const typename TreeType::Mat& data;
+  const MatType& data;
 
   //! Pointer to the root of the tree.
-  TreeType* tree;
+  Tree* tree;
   //! Indicates whether or not we "own" the tree.
   bool ownTree;
 
@@ -128,7 +134,7 @@ class DualTreeBoruvka
    * @param naive Whether the computation should be done in O(n^2) naive mode.
    * @param leafSize The leaf size to be used during tree construction.
    */
-  DualTreeBoruvka(const typename TreeType::Mat& dataset,
+  DualTreeBoruvka(const MatType& dataset,
                   const bool naive = false,
                   const MetricType metric = MetricType());
 
@@ -149,8 +155,8 @@ class DualTreeBoruvka
    * @param tree Pre-built tree.
    * @param dataset Dataset corresponding to the pre-built tree.
    */
-  DualTreeBoruvka(TreeType* tree,
-                  const typename TreeType::Mat& dataset,
+  DualTreeBoruvka(Tree* tree,
+                  const MatType& dataset,
                   const MetricType metric = MetricType());
 
   /**
@@ -194,7 +200,7 @@ class DualTreeBoruvka
    * This function resets the values in the nodes of the tree nearest neighbor
    * distance, and checks for fully connected nodes.
    */
-  void CleanupHelper(TreeType* tree);
+  void CleanupHelper(Tree* tree);
 
   /**
    * The values stored in the tree must be reset on each iteration.
@@ -203,8 +209,8 @@ class DualTreeBoruvka
 
 }; // class DualTreeBoruvka
 
-}; // namespace emst
-}; // namespace mlpack
+} // namespace emst
+} // namespace mlpack
 
 #include "dtb_impl.hpp"
 
diff --git a/src/mlpack/methods/emst/dtb_impl.hpp b/src/mlpack/methods/emst/dtb_impl.hpp
index 1a14ad7..68580b3 100644
--- a/src/mlpack/methods/emst/dtb_impl.hpp
+++ b/src/mlpack/methods/emst/dtb_impl.hpp
@@ -13,9 +13,9 @@ namespace mlpack {
 namespace emst {
 
 //! Call the tree constructor that does mapping.
-template<typename TreeType>
+template<typename MatType, typename TreeType>
 TreeType* BuildTree(
-    typename TreeType::Mat& dataset,
+    MatType& dataset,
     std::vector<size_t>& oldFromNew,
     typename boost::enable_if_c<
         tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
@@ -25,9 +25,9 @@ TreeType* BuildTree(
 }
 
 //! Call the tree constructor that does not do mapping.
-template<typename TreeType>
+template<typename MatType, typename TreeType>
 TreeType* BuildTree(
-    const typename TreeType::Mat& dataset,
+    const MatType& dataset,
     const std::vector<size_t>& /* oldFromNew */,
     const typename boost::enable_if_c<
         tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
@@ -40,12 +40,17 @@ TreeType* BuildTree(
  * Takes in a reference to the data set.  Copies the data, builds the tree,
  * and initializes all of the member variables.
  */
-template<typename MetricType, typename TreeType>
-DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
-    const typename TreeType::Mat& dataset,
+template<
+    typename MetricType,
+    typename MatType,
+    template<typename MetricType, typename StatisticType, typename MatType>
+        class TreeType>
+DualTreeBoruvka<MetricType, MatType, TreeType>::DualTreeBoruvka(
+    const MatType& dataset,
     const bool naive,
     const MetricType metric) :
-    data((tree::TreeTraits<TreeType>::RearrangesDataset && !naive) ? dataCopy : dataset),
+    data((tree::TreeTraits<Tree>::RearrangesDataset && !naive) ? dataCopy :
+        dataset),
     ownTree(!naive),
     naive(naive),
     connections(dataset.n_cols),
@@ -57,11 +62,10 @@ DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
   if (!naive)
   {
     // Copy the dataset, if it will be modified during tree construction.
-    if (tree::TreeTraits<TreeType>::RearrangesDataset)
+    if (tree::TreeTraits<Tree>::RearrangesDataset)
       dataCopy = dataset;
 
-    tree = BuildTree<TreeType>(const_cast<typename TreeType::Mat&>(data),
-        oldFromNew);
+    tree = BuildTree<MatType, Tree>(const_cast<MatType&>(data), oldFromNew);
   }
 
   Timer::Stop("emst/tree_building");
@@ -72,12 +76,16 @@ DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
   neighborsOutComponent.set_size(data.n_cols);
   neighborsDistances.set_size(data.n_cols);
   neighborsDistances.fill(DBL_MAX);
-} // Constructor
+}
 
-template<typename MetricType, typename TreeType>
-DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
-    TreeType* tree,
-    const typename TreeType::Mat& dataset,
+template<
+    typename MetricType,
+    typename MatType,
+    template<typename MetricType, typename StatisticType, typename MatType>
+        class TreeType>
+DualTreeBoruvka<MetricType, MatType, TreeType>::DualTreeBoruvka(
+    Tree* tree,
+    const MatType& dataset,
     const MetricType metric) :
     data(dataset),
     tree(tree),
@@ -95,8 +103,12 @@ DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
   neighborsDistances.fill(DBL_MAX);
 }
 
-template<typename MetricType, typename TreeType>
-DualTreeBoruvka<MetricType, TreeType>::~DualTreeBoruvka()
+template<
+    typename MetricType,
+    typename MatType,
+    template<typename MetricType, typename StatisticType, typename MatType>
+        class TreeType>
+DualTreeBoruvka<MetricType, MatType, TreeType>::~DualTreeBoruvka()
 {
   if (ownTree)
     delete tree;
@@ -106,14 +118,19 @@ DualTreeBoruvka<MetricType, TreeType>::~DualTreeBoruvka()
  * Iteratively find the nearest neighbor of each component until the MST is
  * complete.
  */
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::ComputeMST(arma::mat& results)
+template<
+    typename MetricType,
+    typename MatType,
+    template<typename MetricType, typename StatisticType, typename MatType>
+        class TreeType>
+void DualTreeBoruvka<MetricType, MatType, TreeType>::ComputeMST(
+    arma::mat& results)
 {
   Timer::Start("emst/mst_computation");
 
   totalDist = 0; // Reset distance.
 
-  typedef DTBRules<MetricType, TreeType> RuleType;
+  typedef DTBRules<MetricType, Tree> RuleType;
   RuleType rules(data, connections, neighborsDistances, neighborsInComponent,
                  neighborsOutComponent, metric);
   while (edges.size() < (data.n_cols - 1))
@@ -127,7 +144,7 @@ void DualTreeBoruvka<MetricType, TreeType>::ComputeMST(arma::mat& results)
     }
     else
     {
-      typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+      typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
       traverser.Traverse(*tree, *tree);
     }
 
@@ -154,10 +171,15 @@ void DualTreeBoruvka<MetricType, TreeType>::ComputeMST(arma::mat& results)
 /**
  * Adds a single edge to the edge list
  */
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::AddEdge(const size_t e1,
-                                        const size_t e2,
-                                        const double distance)
+template<
+    typename MetricType,
+    typename MatType,
+    template<typename MetricType, typename StatisticType, typename MatType>
+        class TreeType>
+void DualTreeBoruvka<MetricType, MatType, TreeType>::AddEdge(
+    const size_t e1,
+    const size_t e2,
+    const double distance)
 {
   Log::Assert((distance >= 0.0),
       "DualTreeBoruvka::AddEdge(): distance cannot be negative.");
@@ -166,13 +188,17 @@ void DualTreeBoruvka<MetricType, TreeType>::AddEdge(const size_t e1,
     edges.push_back(EdgePair(e1, e2, distance));
   else
     edges.push_back(EdgePair(e2, e1, distance));
-} // AddEdge
+}
 
 /**
  * Adds all the edges found in one iteration to the list of neighbors.
  */
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::AddAllEdges()
+template<
+    typename MetricType,
+    typename MatType,
+    template<typename MetricType, typename StatisticType, typename MatType>
+        class TreeType>
+void DualTreeBoruvka<MetricType, MatType, TreeType>::AddAllEdges()
 {
   for (size_t i = 0; i < data.n_cols; i++)
   {
@@ -188,13 +214,18 @@ void DualTreeBoruvka<MetricType, TreeType>::AddAllEdges()
       connections.Union(inEdge, outEdge);
     }
   }
-} // AddAllEdges
+}
 
 /**
  * Unpermute the edge list (if necessary) and output it to results.
  */
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::EmitResults(arma::mat& results)
+template<
+    typename MetricType,
+    typename MatType,
+    template<typename MetricType, typename StatisticType, typename MatType>
+        class TreeType>
+void DualTreeBoruvka<MetricType, MatType, TreeType>::EmitResults(
+    arma::mat& results)
 {
   // Sort the edges.
   std::sort(edges.begin(), edges.end(), SortFun);
@@ -203,7 +234,7 @@ void DualTreeBoruvka<MetricType, TreeType>::EmitResults(arma::mat& results)
   results.set_size(3, edges.size());
 
   // Need to unpermute the point labels.
-  if (!naive && ownTree && tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (!naive && ownTree && tree::TreeTraits<Tree>::RearrangesDataset)
   {
     for (size_t i = 0; i < (data.n_cols - 1); i++)
     {
@@ -237,14 +268,18 @@ void DualTreeBoruvka<MetricType, TreeType>::EmitResults(arma::mat& results)
       results(2, i) = edges[i].Distance();
     }
   }
-} // EmitResults
+}
 
 /**
  * This function resets the values in the nodes of the tree nearest neighbor
  * distance and checks for fully connected nodes.
  */
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::CleanupHelper(TreeType* tree)
+template<
+    typename MetricType,
+    typename MatType,
+    template<typename MetricType, typename StatisticType, typename MatType>
+        class TreeType>
+void DualTreeBoruvka<MetricType, MatType, TreeType>::CleanupHelper(Tree* tree)
 {
   // Reset the statistic information.
   tree->Stat().MaxNeighborDistance() = DBL_MAX;
@@ -278,8 +313,12 @@ void DualTreeBoruvka<MetricType, TreeType>::CleanupHelper(TreeType* tree)
 /**
  * The values stored in the tree must be reset on each iteration.
  */
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::Cleanup()
+template<
+    typename MetricType,
+    typename MatType,
+    template<typename MetricType, typename StatisticType, typename MatType>
+        class TreeType>
+void DualTreeBoruvka<MetricType, MatType, TreeType>::Cleanup()
 {
   for (size_t i = 0; i < data.n_cols; i++)
     neighborsDistances[i] = DBL_MAX;
@@ -289,8 +328,12 @@ void DualTreeBoruvka<MetricType, TreeType>::Cleanup()
 }
 
 // convert the object to a string
-template<typename MetricType, typename TreeType>
-std::string DualTreeBoruvka<MetricType, TreeType>::ToString() const
+template<
+    typename MetricType,
+    typename MatType,
+    template<typename MetricType, typename StatisticType, typename MatType>
+        class TreeType>
+std::string DualTreeBoruvka<MetricType, MatType, TreeType>::ToString() const
 {
   std::ostringstream convert;
   convert << "DualTreeBoruvka [" << this << "]" << std::endl;
@@ -303,7 +346,7 @@ std::string DualTreeBoruvka<MetricType, TreeType>::ToString() const
   return convert.str();
 }
 
-}; // namespace emst
-}; // namespace mlpack
+} // namespace emst
+} // namespace mlpack
 
 #endif
diff --git a/src/mlpack/methods/emst/emst_main.cpp b/src/mlpack/methods/emst/emst_main.cpp
index 56a9c4b..6f51673 100644
--- a/src/mlpack/methods/emst/emst_main.cpp
+++ b/src/mlpack/methods/emst/emst_main.cpp
@@ -44,6 +44,7 @@ PARAM_INT("leaf_size", "Leaf size in the kd-tree.  One-element leaves give the "
 using namespace mlpack;
 using namespace mlpack::emst;
 using namespace mlpack::tree;
+using namespace mlpack::metric;
 using namespace std;
 
 int main(int argc, char* argv[])
@@ -86,8 +87,8 @@ int main(int argc, char* argv[])
 
     Timer::Start("tree_building");
     std::vector<size_t> oldFromNew;
-    tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat> tree(dataPoints,
-        oldFromNew, leafSize);
+    KDTree<EuclideanDistance, DTBStat, arma::mat> tree(dataPoints, oldFromNew,
+        leafSize);
     metric::LMetric<2, true> metric;
     Timer::Stop("tree_building");
 



More information about the mlpack-git mailing list