[mlpack-git] master: Refactor for new TreeType API. Also handle the fact that trees now copy the dataset internally (this will be changed again). (b0f82db)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:42:17 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b

>---------------------------------------------------------------

commit b0f82db41d15d7f18b2ac6232448cb5babb720e5
Author: ryan <ryan at ratml.org>
Date:   Sun Jul 26 23:07:08 2015 -0400

    Refactor for new TreeType API.
    Also handle the fact that trees now copy the dataset internally (this will be
    changed again).


>---------------------------------------------------------------

b0f82db41d15d7f18b2ac6232448cb5babb720e5
 src/mlpack/methods/emst/dtb.hpp       | 12 ++++--------
 src/mlpack/methods/emst/dtb_impl.hpp  | 21 ++++-----------------
 src/mlpack/methods/emst/emst_main.cpp |  2 +-
 3 files changed, 9 insertions(+), 26 deletions(-)

diff --git a/src/mlpack/methods/emst/dtb.hpp b/src/mlpack/methods/emst/dtb.hpp
index 898e586..3560ccc 100644
--- a/src/mlpack/methods/emst/dtb.hpp
+++ b/src/mlpack/methods/emst/dtb.hpp
@@ -81,13 +81,12 @@ class DualTreeBoruvka
   typedef TreeType<MetricType, DTBStat, MatType> Tree;
 
  private:
-  //! Copy of the data (if necessary).
-  MatType dataCopy;
-  //! Reference to the data (this is what should be used for accessing data).
-  const MatType& data;
-
+  //! Permutations of points during tree building.
+  std::vector<size_t> oldFromNew;
   //! Pointer to the root of the tree.
   Tree* tree;
+  //! Reference to the data (this is what should be used for accessing data).
+  const MatType& data;
   //! Indicates whether or not we "own" the tree.
   bool ownTree;
 
@@ -100,8 +99,6 @@ class DualTreeBoruvka
   //! Connections.
   UnionFind connections;
 
-  //! Permutations of points during tree building.
-  std::vector<size_t> oldFromNew;
   //! List of edge nodes.
   arma::Col<size_t> neighborsInComponent;
   //! List of edge nodes.
@@ -155,7 +152,6 @@ class DualTreeBoruvka
    * @param dataset Dataset corresponding to the pre-built tree.
    */
   DualTreeBoruvka(Tree* tree,
-                  const MatType& dataset,
                   const MetricType metric = MetricType());
 
   /**
diff --git a/src/mlpack/methods/emst/dtb_impl.hpp b/src/mlpack/methods/emst/dtb_impl.hpp
index 68580b3..d8402e2 100644
--- a/src/mlpack/methods/emst/dtb_impl.hpp
+++ b/src/mlpack/methods/emst/dtb_impl.hpp
@@ -49,27 +49,15 @@ DualTreeBoruvka<MetricType, MatType, TreeType>::DualTreeBoruvka(
     const MatType& dataset,
     const bool naive,
     const MetricType metric) :
-    data((tree::TreeTraits<Tree>::RearrangesDataset && !naive) ? dataCopy :
-        dataset),
+    tree(naive ? NULL : BuildTree<MatType, Tree>(const_cast<MatType&>(dataset),
+        oldFromNew)),
+    data(naive ? dataset : tree->Dataset()),
     ownTree(!naive),
     naive(naive),
     connections(dataset.n_cols),
     totalDist(0.0),
     metric(metric)
 {
-  Timer::Start("emst/tree_building");
-
-  if (!naive)
-  {
-    // Copy the dataset, if it will be modified during tree construction.
-    if (tree::TreeTraits<Tree>::RearrangesDataset)
-      dataCopy = dataset;
-
-    tree = BuildTree<MatType, Tree>(const_cast<MatType&>(data), oldFromNew);
-  }
-
-  Timer::Stop("emst/tree_building");
-
   edges.reserve(data.n_cols - 1); // Set size.
 
   neighborsInComponent.set_size(data.n_cols);
@@ -85,10 +73,9 @@ template<
         class TreeType>
 DualTreeBoruvka<MetricType, MatType, TreeType>::DualTreeBoruvka(
     Tree* tree,
-    const MatType& dataset,
     const MetricType metric) :
-    data(dataset),
     tree(tree),
+    data(tree->Dataset()),
     ownTree(false),
     naive(false),
     connections(data.n_cols),
diff --git a/src/mlpack/methods/emst/emst_main.cpp b/src/mlpack/methods/emst/emst_main.cpp
index 6f51673..828693e 100644
--- a/src/mlpack/methods/emst/emst_main.cpp
+++ b/src/mlpack/methods/emst/emst_main.cpp
@@ -92,7 +92,7 @@ int main(int argc, char* argv[])
     metric::LMetric<2, true> metric;
     Timer::Stop("tree_building");
 
-    DualTreeBoruvka<> dtb(&tree, dataPoints, metric);
+    DualTreeBoruvka<> dtb(&tree, metric);
 
     // Run the DTB algorithm.
     Log::Info << "Calculating minimum spanning tree." << endl;



More information about the mlpack-git mailing list