[mlpack-svn] r16760 - mlpack/trunk/src/mlpack/methods/emst

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Jul 3 16:46:27 EDT 2014


Author: rcurtin
Date: Thu Jul  3 16:46:27 2014
New Revision: 16760

Log:
Refactor tree construction in DualTreeBoruvka so that arbitrary tree types can be built.


Modified:
   mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
   mlpack/trunk/src/mlpack/methods/emst/dtb_stat.hpp

Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp	Thu Jul  3 16:46:27 2014
@@ -12,36 +12,30 @@
 namespace mlpack {
 namespace emst {
 
-// DTBStat
-
-/**
- * A generic initializer.
- */
-inline DTBStat::DTBStat() :
-    maxNeighborDistance(DBL_MAX),
-    minNeighborDistance(DBL_MAX),
-    bound(DBL_MAX),
-    componentMembership(-1)
+//! Call the tree constructor that takes oldFromNew (dataset is rearranged).
+template<typename TreeType>
+TreeType* BuildTree(
+    typename TreeType::Mat& dataset,
+    std::vector<size_t>& oldFromNew,
+    typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
+    >::type = 0)
 {
-  // Nothing to do.
+  return new TreeType(dataset, oldFromNew);
 }
 
-/**
- * An initializer for leaves.
- */
+//! Call the tree constructor without oldFromNew (dataset is not rearranged).
 template<typename TreeType>
-DTBStat::DTBStat(const TreeType& node) :
-    maxNeighborDistance(DBL_MAX),
-    minNeighborDistance(DBL_MAX),
-    bound(DBL_MAX),
-    componentMembership(((node.NumPoints() == 1) && (node.NumChildren() == 0)) ?
-        node.Point(0) : -1)
+TreeType* BuildTree(
+    const typename TreeType::Mat& dataset,
+    const std::vector<size_t>& /* oldFromNew */,
+    const typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
+    >::type = 0)
 {
-  // Nothing to do.
+  return new TreeType(dataset);
 }
 
-// DualTreeBoruvka
-
 /**
  * Takes in a reference to the data set.  Copies the data, builds the tree,
  * and initializes all of the member variables.
@@ -51,20 +45,24 @@
     const typename TreeType::Mat& dataset,
     const bool naive,
     const MetricType metric) :
-    dataCopy(dataset),
-    data(dataCopy), // The reference points to our copy of the data.
+    data((tree::TreeTraits<TreeType>::RearrangesDataset && !naive) ? dataCopy : dataset),
     ownTree(!naive),
     naive(naive),
-    connections(data.n_cols),
+    connections(dataset.n_cols),
     totalDist(0.0),
     metric(metric)
 {
   Timer::Start("emst/tree_building");
 
-  // Default leaf size is 1; this gives the best pruning, empirically.  Use
-  // leaf_size = 1 unless space is a big concern.
   if (!naive)
-    tree = new TreeType(dataCopy, oldFromNew);
+  {
+    // Copy the dataset only if tree construction will rearrange it.
+    if (tree::TreeTraits<TreeType>::RearrangesDataset)
+      dataCopy = dataset;
+
+    tree = BuildTree<TreeType>(const_cast<typename TreeType::Mat&>(data),
+        oldFromNew);
+  }
 
   Timer::Stop("emst/tree_building");
 
@@ -89,7 +87,7 @@
     totalDist(0.0),
     metric(metric)
 {
-  edges.reserve(data.n_cols - 1); // fill with EdgePairs
+  edges.reserve(data.n_cols - 1); // Fill with EdgePairs.
 
   neighborsInComponent.set_size(data.n_cols);
   neighborsOutComponent.set_size(data.n_cols);
@@ -205,7 +203,7 @@
   results.set_size(3, edges.size());
 
   // Need to unpermute the point labels.
-  if (!naive && ownTree)
+  if (!naive && ownTree && tree::TreeTraits<TreeType>::RearrangesDataset)
   {
     for (size_t i = 0; i < (data.n_cols - 1); i++)
     {
@@ -248,39 +246,34 @@
 template<typename MetricType, typename TreeType>
 void DualTreeBoruvka<MetricType, TreeType>::CleanupHelper(TreeType* tree)
 {
+  // Reset the statistic information.
   tree->Stat().MaxNeighborDistance() = DBL_MAX;
   tree->Stat().MinNeighborDistance() = DBL_MAX;
   tree->Stat().Bound() = DBL_MAX;
 
-  if (!tree->IsLeaf())
-  {
-    CleanupHelper(tree->Left());
-    CleanupHelper(tree->Right());
+  // Recurse into all children.
+  for (size_t i = 0; i < tree->NumChildren(); ++i)
+    CleanupHelper(&tree->Child(i));
+
+  // Get the component of the first child or point.  Then we will check to see
+  // if all other components of children and points are the same.
+  const int component = (tree->NumChildren() != 0) ?
+      tree->Child(0).Stat().ComponentMembership() :
+      connections.Find(tree->Point(0));
+
+  // Check components of children.
+  for (size_t i = 0; i < tree->NumChildren(); ++i)
+    if (tree->Child(i).Stat().ComponentMembership() != component)
+      return;
+
+  // Check components of points.
+  for (size_t i = 0; i < tree->NumPoints(); ++i)
+    if (connections.Find(tree->Point(i)) != component)
+      return;
 
-    if ((tree->Left()->Stat().ComponentMembership() >= 0)
-        && (tree->Left()->Stat().ComponentMembership() ==
-            tree->Right()->Stat().ComponentMembership()))
-    {
-      tree->Stat().ComponentMembership() =
-          tree->Left()->Stat().ComponentMembership();
-    }
-  }
-  else
-  {
-    size_t newMembership = connections.Find(tree->Begin());
-
-    for (size_t i = tree->Begin(); i < tree->End(); ++i)
-    {
-      if (newMembership != connections.Find(i))
-      {
-        newMembership = -1;
-        Log::Assert(tree->Stat().ComponentMembership() < 0);
-        return;
-      }
-    }
-    tree->Stat().ComponentMembership() = newMembership;
-  }
-} // CleanupHelper
+  // If we made it this far, all components are the same.
+  tree->Stat().ComponentMembership() = component;
+}
 
 /**
  * The values stored in the tree must be reset on each iteration.
@@ -289,14 +282,10 @@
 void DualTreeBoruvka<MetricType, TreeType>::Cleanup()
 {
   for (size_t i = 0; i < data.n_cols; i++)
-  {
     neighborsDistances[i] = DBL_MAX;
-  }
 
   if (!naive)
-  {
     CleanupHelper(tree);
-  }
 }
 
 // convert the object to a string
@@ -304,12 +293,12 @@
 std::string DualTreeBoruvka<MetricType, TreeType>::ToString() const
 {
   std::ostringstream convert;
-  convert << "Dual Tree Boruvka [" << this << "]" << std::endl;
+  convert << "DualTreeBoruvka [" << this << "]" << std::endl;
   convert << "  Data: " << data.n_rows << "x" << data.n_cols <<std::endl;
   convert << "  Total Distance: " << totalDist <<std::endl;
   convert << "  Naive: " << naive << std::endl;
   convert << "  Metric: " << std::endl;
-  convert << mlpack::util::Indent(metric.ToString(),2);
+  convert << util::Indent(metric.ToString(), 2);
   convert << std::endl;
   return convert.str();
 }

Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_stat.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_stat.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_stat.hpp	Thu Jul  3 16:46:27 2014
@@ -41,7 +41,11 @@
    * A generic initializer.  Sets the maximum neighbor distance to its default,
    * and the component membership to -1 (no component).
    */
-  DTBStat();
+  DTBStat() :
+      maxNeighborDistance(DBL_MAX),
+      minNeighborDistance(DBL_MAX),
+      bound(DBL_MAX),
+      componentMembership(-1) { }
 
   /**
    * This is called when a node is finished initializing.  We set the maximum
@@ -51,7 +55,13 @@
    * @param node Node that has been finished.
    */
   template<typename TreeType>
-  DTBStat(const TreeType& node);
+  DTBStat(const TreeType& node) :
+      maxNeighborDistance(DBL_MAX),
+      minNeighborDistance(DBL_MAX),
+      bound(DBL_MAX),
+      componentMembership(
+          ((node.NumPoints() == 1) && (node.NumChildren() == 0)) ?
+            node.Point(0) : -1) { }
 
   //! Get the maximum neighbor distance.
   double MaxNeighborDistance() const { return maxNeighborDistance; }



More information about the mlpack-svn mailing list