[mlpack-git] master: Refactoring to reduce runtime of tree update. It speeds things up in terms of distance computations too, a bit. (72f2ddd)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:17 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 72f2ddd51921b3417a4c264e396438aaf8330d13
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Feb 3 20:36:29 2015 -0500

    Refactoring to reduce runtime of tree update. It speeds things up in terms of distance computations too, a bit.


>---------------------------------------------------------------

72f2ddd51921b3417a4c264e396438aaf8330d13
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 55 +++++++++++++-------------
 1 file changed, 28 insertions(+), 27 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 9d50113..842feae 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -250,7 +250,10 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
       if (owner == centroids.n_cols + 1)
         owner = c;
       else if (owner != c)
+      {
         singleOwner = false;
+        break;
+      }
 
       // Update maximum cluster distance and second cluster bound.
       if (!prunedPoints[node.Point(i)])
@@ -274,10 +277,12 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     {
       if (owner == centroids.n_cols + 1)
         owner = node.Child(i).Stat().Owner();
-      else if (node.Child(i).Stat().Owner() == centroids.n_cols)
-        singleOwner = false;
-      else if (owner != node.Child(i).Stat().Owner())
+      else if ((node.Child(i).Stat().Owner() == centroids.n_cols) ||
+               (owner != node.Child(i).Stat().Owner()))
+      {
         singleOwner = false;
+        break;
+      }
 
       // Update maximum cluster distance and second cluster bound.
       if (node.Child(i).Stat().MaxClusterDistance() > newMaxClusterDistance)
@@ -286,19 +291,19 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
         newSecondClusterBound = node.Child(i).Stat().SecondClusterBound();
     }
 
-    // What do we do with the new cluster bounds?
-    if (newMaxClusterDistance > 0.0 && newMaxClusterDistance <
-        node.Stat().MaxClusterDistance())
-      node.Stat().MaxClusterDistance() = newMaxClusterDistance;
-    if (newSecondClusterBound != DBL_MAX && newSecondClusterBound >
-        node.Stat().SecondClusterBound())
-      node.Stat().SecondClusterBound() = newSecondClusterBound;
-
     // Okay, now we know if it's owned or not, and by which cluster.
     if (singleOwner)
     {
       node.Stat().Owner() = owner;
 
+      // What do we do with the new cluster bounds?
+      if (newMaxClusterDistance > 0.0 && newMaxClusterDistance <
+          node.Stat().MaxClusterDistance())
+        node.Stat().MaxClusterDistance() = newMaxClusterDistance;
+      if (newSecondClusterBound != DBL_MAX && newSecondClusterBound >
+          node.Stat().SecondClusterBound())
+        node.Stat().SecondClusterBound() = newSecondClusterBound;
+
       // Sanity check: ensure the owner is right.
 /*
       for (size_t i = 0; i < node.NumPoints(); ++i)
@@ -394,10 +399,10 @@ assignments(0, node.Point(i - 1)) << ".\n";
 //node.Child(0) << ", r\n" << node.Child(1) << ".\n";
 
       // Adjust the bounds for next iteration.
-      node.Stat().MaxClusterDistance() += clusterDistances[centroids.n_cols];
-      node.Stat().SecondClusterBound() = std::max(0.0,
-          node.Stat().SecondClusterBound() -
-          clusterDistances[centroids.n_cols]);
+//      node.Stat().MaxClusterDistance() += clusterDistances[centroids.n_cols];
+//      node.Stat().SecondClusterBound() = std::max(0.0,
+//          node.Stat().SecondClusterBound() -
+//          clusterDistances[centroids.n_cols]);
     }
   }
   else if (node.Stat().Pruned())
@@ -496,21 +501,17 @@ assignments(0, node.Point(i - 1)) << ".\n";
     {
       const size_t index = node.Point(i);
       size_t owner;
-      if (!prunedLastIteration && !prunedPoints[index])
-      {
-        owner = assignments(0, index);
-        // Establish bounds, since these points were searched this iteration.
-        lowerSecondBounds[index] = distances(1, index);
-      }
-      else if (prunedLastIteration && node.Stat().Owner() < centroids.n_cols)
-      {
+      if (prunedLastIteration && node.Stat().Owner() < centroids.n_cols)
         owner = node.Stat().Owner();
-      }
       else
-      {
-        owner = lastOwners[index];
-      }
+        owner = assignments(0, index);
+
+      // Update lower bound, if possible.
+      if (!prunedLastIteration && !prunedPoints[index])
+        lowerSecondBounds[index] = distances(1, index);
 
+      const double upperPointBound = distances(0, index) +
+          clusterDistances[owner];
       if (distances(0, index) + clusterDistances[owner] <
           lowerSecondBounds[index] - clusterDistances[centroids.n_cols])
       {



More information about the mlpack-git mailing list