[mlpack-git] master: Refactor into UpdateOwner(), instead of an ugly loop at the beginning of TreeUpdate(). (2633583)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:02:12 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 2633583a44a2ef6776721199f28291f02cae75b5
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Jan 28 16:17:01 2015 -0500

    Refactor into UpdateOwner(), instead of an ugly loop at the beginning of TreeUpdate().


>---------------------------------------------------------------

2633583a44a2ef6776721199f28291f02cae75b5
 src/mlpack/methods/kmeans/dual_tree_kmeans.hpp     |   4 +
 .../methods/kmeans/dual_tree_kmeans_impl.hpp       | 106 +++++++++------------
 2 files changed, 48 insertions(+), 62 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
index 9e0c17a..1d26273 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -61,6 +61,10 @@ class DualTreeKMeans
   void ClusterTreeUpdate(TreeType* node,
                          const arma::mat& distances);
 
+  void UpdateOwner(TreeType* node,
+                   const size_t clusters,
+                   const arma::Col<size_t>& assignments) const;
+
   void TreeUpdate(TreeType* node,
                   const size_t clusters,
                   const arma::vec& clusterDistances,
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index 7a65d3e..756cceb 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -127,6 +127,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
 
   // Update the tree with the centroid movement information.
   size_t hamerlyPruned = 0;
+  UpdateOwner(tree, centroids.n_cols, assignments);
   TreeUpdate(tree, centroids.n_cols, clusterDistances, assignments,
       oldCentroids, dataset, oldFromNewCentroids, hamerlyPruned);
 
@@ -174,6 +175,47 @@ bool IsDescendantOf(
 }
 
 template<typename MetricType, typename MatType, typename TreeType>
+void DualTreeKMeans<MetricType, MatType, TreeType>::UpdateOwner(
+    TreeType* node,
+    const size_t clusters,
+    const arma::Col<size_t>& assignments) const
+{
+  // Recursively recompute the owner of every node in the subtree rooted at
+  // `node`.  A node's owner is the cluster that every point held anywhere in
+  // its subtree is assigned to (according to `assignments`); if the points
+  // belong to more than one cluster, the owner is set to `clusters` instead.
+  // The value clusters + 1 is used as a sentinel for "no point seen yet"; it
+  // is also what a node with no points at all ends up with.
+  const size_t noOwnerYet = clusters + 1;
+  size_t owner = noOwnerYet;
+  bool same = true;
+
+  for (size_t i = 0; i < node->NumChildren(); ++i)
+  {
+    // Always recurse, even after ownership is known to be mixed: breaking out
+    // early would leave the remaining children with stale owners.
+    UpdateOwner(&node->Child(i), clusters, assignments);
+    const size_t childOwner = node->Child(i).Stat().Owner();
+    if (owner == noOwnerYet)
+      owner = childOwner;
+    else if (childOwner != noOwnerYet && childOwner != owner)
+      same = false; // Empty children (sentinel owner) are simply ignored.
+  }
+
+  // Then check the points held directly in this node, unless already mixed.
+  for (size_t i = 0; same && i < node->NumPoints(); ++i)
+  {
+    const size_t pointOwner = assignments[node->Point(i)];
+    if (owner == noOwnerYet)
+      owner = pointOwner;
+    else if (pointOwner != owner)
+      same = false;
+  }
+
+  node->Stat().Owner() = same ? owner : clusters;
+}
+
+template<typename MetricType, typename MatType, typename TreeType>
 void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
     TreeType* node,
     const size_t clusters,
@@ -186,45 +228,14 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
 {
   // This is basically IterationUpdate(), but pulled out to be separate from the
   // actual dual-tree algorithm.
-  if (node->Begin() == 26038)
-    Log::Warn << "r26038c" << node->Count() << " has owner " <<
-node->Stat().Owner() << ".\n";
-  if (node->Parent() != NULL && node->Parent()->Stat().Owner() < clusters)
-    node->Stat().Owner() = node->Parent()->Stat().Owner();
-  if (node->Begin() == 26038)
-    Log::Warn << "r26038c" << node->Count() << " has owner " <<
-node->Stat().Owner() << " after parent check.\n";
-
-  const size_t cluster = assignments[node->Descendant(0)];
-  bool allSame = true;
-  for (size_t i = 1; i < node->NumDescendants(); ++i)
-  {
-    if (assignments[node->Descendant(i)] != cluster)
-    {
-      allSame = false;
-      break;
-    }
-  }
-
-  if (allSame)
-    node->Stat().Owner() = cluster;
-  else
-    node->Stat().Owner() = centroids.n_cols;
-  if (node->Begin() == 26038)
-    Log::Warn << "r26038c" << node->Count() << " has manually set owner " <<
-node->Stat().Owner() << ".\n";
-
   const bool prunedLastIteration = node->Stat().HamerlyPruned();
   node->Stat().HamerlyPruned() = false;
 
-  if (node->Begin() == 26038)
-    Log::Warn << "r26038c" << node->Count() << " has owner " <<
-node->Stat().Owner() << ".\n";
-
   // The easy case: this node had an owner.
   if (node->Stat().Owner() < clusters)
   {
     // Verify correctness...
+    /*
     for (size_t i = 0; i < node->NumDescendants(); ++i)
     {
       size_t closest = clusters;
@@ -251,6 +262,7 @@ closest << "!  It's part of node r" << node->Begin() << "c" << node->Count() <<
 ".\n";
       }
     }
+    */
 
     // During the last iteration, this node was pruned.
     const size_t owner = node->Stat().Owner();
@@ -263,11 +275,6 @@ closest << "!  It's part of node r" << node->Begin() << "c" << node->Count() <<
     {
       // Can we continue being Hamerly pruned?  If not, we'll have to update the
       // bound next iteration.
-      if (node->Begin() == 26038)
-        Log::Warn << "r26038c" << node->Count() << ": check sustained Hamerly "
-            << "prune with MQND " << node->Stat().MaxQueryNodeDistance() << ", "
-            << "lscb " << node->Stat().LastSecondClosestBound() << ", cd "
-            << clusterDistances[clusters] << ".\n";
       if (node->Stat().MaxQueryNodeDistance() <
           node->Stat().LastSecondClosestBound() - clusterDistances[clusters])
       {
@@ -278,22 +285,6 @@ closest << "!  It's part of node r" << node->Begin() << "c" << node->Count() <<
     }
     else
     {
-      if (node->Begin() == 26038)
-      {
-        if (node->Stat().ClosestQueryNode() != NULL)
-          Log::Warn << "r26038c" << node->Count() << " CQN: " << ((TreeType*)
-  node->Stat().ClosestQueryNode())->Begin() << "c" << ((TreeType*)
-  node->Stat().ClosestQueryNode())->Count() << ".\n";
-        if (node->Stat().SecondClosestQueryNode() != NULL)
-          Log::Warn << "r26038c" << node->Count() << " SCQN: " << ((TreeType*)
-  node->Stat().SecondClosestQueryNode())->Begin() << "c" << ((TreeType*)
-  node->Stat().SecondClosestQueryNode())->Count() << ".\n";
-        Log::Warn << "Attempt hamerly prune r26038c" << node->Count() << " with "
-            << "MQND " << node->Stat().MaxQueryNodeDistance() << " and smqnd "
-            << node->Stat().SecondMinQueryNodeDistance() << " and cluster d "
-            << clusterDistances[clusters] << ".\n";
-      }
-
       // Now we check for a Hamerly prune.  We know that we have an accurate
       // second bound since nothing can be pruned.
       if (node->Stat().MaxQueryNodeDistance() /* already adjusted */ <
@@ -339,16 +330,7 @@ closest << "!  It's part of node r" << node->Begin() << "c" << node->Count() <<
     allPruned = false;
 
   if (allPruned && owner < clusters && !node->Stat().HamerlyPruned())
-  {
-    if (node->Begin() == 26038)
-      Log::Warn << "Set r" << node->Begin() << "c" << node->Count() << " to be "
-          << "Hamerly pruned.\n";
     node->Stat().HamerlyPruned() = true;
-  }
-
-  if (node->Begin() == 26038 && node->Stat().HamerlyPruned())
-    Log::Warn << "r" << node->Begin() << "c" << node->Count() << " is Hamerly "
-        << "pruned.\n";
 
   node->Stat().Iteration() = iteration;
   node->Stat().ClustersPruned() = (node->Parent() == NULL) ? 0 : -1;



More information about the mlpack-git mailing list