[mlpack-git] master: Coalesce and decoalesce the tree. (928fd6d)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:33 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 928fd6dde0a6b8187cc8455e3cd4a8ac0cb16c55
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Feb 18 09:57:28 2015 -0500

    Coalesce and decoalesce the tree.


>---------------------------------------------------------------

928fd6dde0a6b8187cc8455e3cd4a8ac0cb16c55
 src/mlpack/methods/kmeans/dtnn_kmeans.hpp      |  3 ++
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 69 ++++++++++++++++++++++++++
 src/mlpack/methods/kmeans/dtnn_statistic.hpp   | 22 +++++++-
 3 files changed, 92 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index 26c90df..bf83533 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -104,6 +104,9 @@ class DTNNKMeans
                         arma::Col<size_t>& newCounts,
                         std::vector<size_t>& oldFromNewCentroids,
                         arma::mat& centroids);
+  //! Coalesce: hide non-root nodes with one pruned child; Decoalesce: undo.
+  void CoalesceTree(TreeType& node, const size_t child = 0);
+  void DecoalesceTree(TreeType& node);
 };
 
 //! A template typedef for the DTNNKMeans algorithm with the default tree type
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 985c359..0ab1bd4 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -121,11 +121,19 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   typename TreeType::template BreadthFirstDualTreeTraverser<RuleType>
       traverser(rules);
 
+  Timer::Start("tree_mod");
+  CoalesceTree(*tree);
+  Timer::Stop("tree_mod");
+
   // Set the number of pruned centroids in the root to 0.
   tree->Stat().Pruned() = 0;
   traverser.Traverse(*tree, *centroidTree);
   distanceCalculations += rules.BaseCases() + rules.Scores();
 
+  Timer::Start("tree_mod");
+  DecoalesceTree(*tree);
+  Timer::Stop("tree_mod");
+
   // Now we need to extract the clusters.
   newCentroids.zeros(centroids.n_rows, centroids.n_cols);
   counts.zeros(centroids.n_cols);
@@ -383,6 +391,67 @@ assignments[node.Point(0)] << " with ub " << upperBounds[node.Point(0)] <<
   }
 }
 
+template<typename MetricType, typename MatType, typename TreeType>
+void DTNNKMeans<MetricType, MatType, TreeType>::CoalesceTree(
+    TreeType& node,
+    const size_t child /* Which child are we? */)
+{
+  // Hide each non-root node with exactly one statically pruned child by linking
+  // its surviving child straight to its parent.  Assumes a binary tree (zero or
+  // exactly two children), e.g. BinarySpaceTree.
+  if (node.NumChildren() == 0)
+    return; // Leaves have nothing to hide.
+
+  // The root is never hidden; only try to coalesce beneath it.
+  // NOTE(review): unlike below, this also descends into pruned children.
+  if (node.Parent() == NULL)
+  {
+    CoalesceTree(node.Child(0), 0);
+    CoalesceTree(node.Child(1), 1);
+    return;
+  }
+
+  const bool leftPruned = node.Child(0).Stat().StaticPruned();
+  const bool rightPruned = node.Child(1).Stat().StaticPruned();
+
+  // Both children pruned: leave this subtree untouched.
+  if (leftPruned && rightPruned)
+    return;
+
+  // Neither child pruned: keep this node and coalesce beneath both children.
+  if (!leftPruned && !rightPruned)
+  {
+    CoalesceTree(node.Child(0), 0);
+    CoalesceTree(node.Child(1), 1);
+    return;
+  }
+
+  // Exactly one child survives.  Coalesce beneath it first (which may replace
+  // node.ChildPtr(survivor)), then link whatever is there now to our parent,
+  // taking over this node's slot so that this node is hidden.
+  const size_t survivor = leftPruned ? 1 : 0;
+  CoalesceTree(node.Child(survivor), survivor);
+  node.Child(survivor).Parent() = node.Parent();
+  node.Parent()->ChildPtr(child) = node.ChildPtr(survivor);
+}
+
+//! Undo CoalesceTree() by restoring each node's true links from its statistic.
+template<typename MetricType, typename MatType, typename TreeType>
+void DTNNKMeans<MetricType, MatType, TreeType>::DecoalesceTree(TreeType& node)
+{
+  // The statistic stores the links as void*, so cast them back.
+  node.Parent() = static_cast<TreeType*>(node.Stat().TrueParent());
+  node.ChildPtr(0) = static_cast<TreeType*>(node.Stat().TrueLeft());
+  node.ChildPtr(1) = static_cast<TreeType*>(node.Stat().TrueRight());
+
+  // Recurse through the restored children, reaching hidden nodes as well.
+  if (node.NumChildren() > 0)
+  {
+    DecoalesceTree(node.Child(0));
+    DecoalesceTree(node.Child(1));
+  }
+}
+
 } // namespace kmeans
 } // namespace mlpack
 
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index 2601378..51f4f60 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -25,7 +25,10 @@ class DTNNStatistic : public
       staticPruned(false),
       staticUpperBoundMovement(0.0),
       staticLowerBoundMovement(0.0),
-      centroid()
+      centroid(),
+      trueParent(NULL),
+      trueLeft(NULL),
+      trueRight(NULL)
   {
     // Nothing to do.
   }
@@ -39,7 +42,10 @@ class DTNNStatistic : public
       pruned(size_t(-1)),
       staticPruned(false),
       staticUpperBoundMovement(0.0),
-      staticLowerBoundMovement(0.0)
+      staticLowerBoundMovement(0.0),
+      trueParent(node.Parent()),
+      trueLeft((node.NumChildren() == 0) ? NULL : &node.Child(0)),
+      trueRight((node.NumChildren() == 0) ? NULL : &node.Child(1))
   {
     // Empirically calculate the centroid.
     centroid.zeros(node.Dataset().n_rows);
@@ -77,6 +83,15 @@ class DTNNStatistic : public
   double StaticLowerBoundMovement() const { return staticLowerBoundMovement; }
   double& StaticLowerBoundMovement() { return staticLowerBoundMovement; }
 
+  void* TrueParent() const { return trueParent; } //!< Get original parent.
+  void*& TrueParent() { return trueParent; } //!< Modify original parent.
+
+  void* TrueLeft() const { return trueLeft; } //!< Get original left child.
+  void*& TrueLeft() { return trueLeft; } //!< Modify original left child.
+
+  void* TrueRight() const { return trueRight; } //!< Get original right child.
+  void*& TrueRight() { return trueRight; } //!< Modify original right child.
+
   std::string ToString() const
   {
     std::ostringstream o;
@@ -98,6 +113,9 @@ class DTNNStatistic : public
   double staticUpperBoundMovement;
   double staticLowerBoundMovement;
   arma::vec centroid;
+  void* trueParent; //!< Parent node at construction (type-erased pointer).
+  void* trueLeft; //!< Left child at construction, or NULL for a leaf.
+  void* trueRight; //!< Right child at construction, or NULL for a leaf.
 };
 
 } // namespace kmeans



More information about the mlpack-git mailing list