[mlpack-git] master: Adapt to trees that aren't just binary. (55a7ba8)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed May 20 23:06:04 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/77d750c8fd46140b1d6060424f68768a21c89377...7e9cd46afb53817ae93ccbd02637d7726137ce4d

>---------------------------------------------------------------

commit 55a7ba8909ed26d5f9ce4427aae129d58f88cd27
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri May 8 13:13:34 2015 -0400

    Adapt to trees that aren't just binary.


>---------------------------------------------------------------

55a7ba8909ed26d5f9ce4427aae129d58f88cd27
 .../methods/kmeans/dual_tree_kmeans_impl.hpp       | 51 +++++++++++-----------
 .../methods/kmeans/dual_tree_kmeans_statistic.hpp  | 23 +++++-----
 2 files changed, 36 insertions(+), 38 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index b9c9c5b..fa66007 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -509,37 +509,38 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::CoalesceTree(
   // If this is the root node, we can't coalesce.
   if (node.Parent() != NULL)
   {
-    if (node.Child(0).Stat().StaticPruned() &&
-        !node.Child(1).Stat().StaticPruned())
+    // First, we should coalesce those nodes that aren't statically pruned.
+    size_t numStaticallyPruned = 0;
+    size_t notPrunedIndex = 0;
+    for (size_t i = 0; i < node.NumChildren(); ++i)
     {
-      CoalesceTree(node.Child(1), 1);
-
-      // Link the right child to the parent.
-      node.Child(1).Parent() = node.Parent();
-      node.Parent()->ChildPtr(child) = node.ChildPtr(1);
+      if (node.Child(i).Stat().StaticPruned())
+      {
+        ++numStaticallyPruned;
+      }
+      else
+      {
+        CoalesceTree(node.Child(i), i);
+        notPrunedIndex = i;
+      }
     }
-    else if (!node.Child(0).Stat().StaticPruned() &&
-             node.Child(1).Stat().StaticPruned())
-    {
-      CoalesceTree(node.Child(0), 0);
-
-      // Link the left child to the parent.
-      node.Child(0).Parent() = node.Parent();
-      node.Parent()->ChildPtr(child) = node.ChildPtr(0);
 
-    }
-    else if (!node.Child(0).Stat().StaticPruned() &&
-             !node.Child(1).Stat().StaticPruned())
+    // If we've pruned all but one child, then notPrunedIndex will contain the
+    // index of that child, and we can coalesce this node entirely.  Note that
+    // the case where all children are statically pruned should not happen,
+    // because then this node should itself be statically pruned.
+    if (numStaticallyPruned == node.NumChildren() - 1)
     {
-      // The conditional is probably not necessary.
-      CoalesceTree(node.Child(0), 0);
-      CoalesceTree(node.Child(1), 1);
+      node.Child(notPrunedIndex).Parent() = node.Parent();
+      node.Parent()->ChildPtr(child) = node.ChildPtr(notPrunedIndex);
     }
   }
   else
   {
-    CoalesceTree(node.Child(0), 0);
-    CoalesceTree(node.Child(1), 1);
+    // We can't coalesce the root, so call the children individually and
+    // coalesce them.
+    for (size_t i = 0; i < node.NumChildren(); ++i)
+      CoalesceTree(node.Child(i), i);
   }
 }
 
@@ -548,8 +549,8 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::DecoalesceTree(
     TreeType& node)
 {
   node.Parent() = (TreeType*) node.Stat().TrueParent();
-  node.ChildPtr(0) = (TreeType*) node.Stat().TrueLeft();
-  node.ChildPtr(1) = (TreeType*) node.Stat().TrueRight();
+  for (size_t i = 0; i < node.NumChildren(); ++i)
+    node.ChildPtr(i) = (TreeType*) node.Stat().TrueChild(i);
 
   if (node.NumChildren() > 0)
   {
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
index 8aaec4b..09a14a0 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
@@ -26,9 +26,7 @@ class DualTreeKMeansStatistic : public
       staticUpperBoundMovement(0.0),
       staticLowerBoundMovement(0.0),
       centroid(),
-      trueParent(NULL),
-      trueLeft(NULL),
-      trueRight(NULL)
+      trueParent(NULL)
   {
     // Nothing to do.
   }
@@ -43,9 +41,7 @@ class DualTreeKMeansStatistic : public
       staticPruned(false),
       staticUpperBoundMovement(0.0),
       staticLowerBoundMovement(0.0),
-      trueParent(node.Parent()),
-      trueLeft((node.NumChildren() == 0) ? NULL : &node.Child(0)),
-      trueRight((node.NumChildren() == 0) ? NULL : &node.Child(1))
+      trueParent(node.Parent())
   {
     // Empirically calculate the centroid.
     centroid.zeros(node.Dataset().n_rows);
@@ -57,6 +53,11 @@ class DualTreeKMeansStatistic : public
           node.Child(i).Stat().Centroid();
 
     centroid /= node.NumDescendants();
+
+    // Set the true children correctly.
+    trueChildren.resize(node.NumChildren());
+    for (size_t i = 0; i < node.NumChildren(); ++i)
+      trueChildren[i] = &node.Child(i);
   }
 
   double UpperBound() const { return upperBound; }
@@ -86,11 +87,8 @@ class DualTreeKMeansStatistic : public
   void* TrueParent() const { return trueParent; }
   void*& TrueParent() { return trueParent; }
 
-  void* TrueLeft() const { return trueLeft; }
-  void*& TrueLeft() { return trueLeft; }
-
-  void* TrueRight() const { return trueRight; }
-  void*& TrueRight() { return trueRight; }
+  void* TrueChild(const size_t i) const { return trueChildren[i]; }
+  void*& TrueChild(const size_t i) { return trueChildren[i]; }
 
   std::string ToString() const
   {
@@ -114,8 +112,7 @@ class DualTreeKMeansStatistic : public
   double staticLowerBoundMovement;
   arma::vec centroid;
   void* trueParent;
-  void* trueLeft;
-  void* trueRight;
+  std::vector<void*> trueChildren;
 };
 
 } // namespace kmeans



More information about the mlpack-git mailing list