[mlpack-git] master: Refactoring, and tighten a bound for minor speedup. (820ba74)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:36 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 820ba7490a2a244aae6b9e27b34492bf636b1a7a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Feb 3 21:22:54 2015 -0500

    Refactoring, and tighten a bound for minor speedup.


>---------------------------------------------------------------

820ba7490a2a244aae6b9e27b34492bf636b1a7a
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 204 ++++---------------------
 1 file changed, 26 insertions(+), 178 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 842feae..40dec27 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -304,105 +304,29 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
           node.Stat().SecondClusterBound())
         node.Stat().SecondClusterBound() = newSecondClusterBound;
 
-      // Sanity check: ensure the owner is right.
-/*
-      for (size_t i = 0; i < node.NumPoints(); ++i)
-      {
-        arma::vec dists(centroids.n_cols);
-        size_t trueOwner = centroids.n_cols;
-        double trueDist = DBL_MAX;
-        for (size_t j = 0; j < centroids.n_cols; ++j)
-        {
-          const double dist = metric.Evaluate(dataset.col(node.Point(i)),
-              lastIterationCentroids.col(j));
-          dists(j) = dist;
-          if (dist < trueDist)
-          {
-            trueDist = dist;
-            trueOwner = j;
-          }
-        }
-
-        if (trueOwner != owner)
-        {
-          Log::Warn << node << "...\n" << *node.Parent();
-          Log::Warn << dists.t();
-          Log::Warn << "Assignment: " << assignments(0, node.Point(i)) << ".\n";
-          Log::Warn << "Dists: " << distances(0, node.Point(i)) << ", " <<
-distances(1, node.Point(i)) << ".\n";
-//            TreeType* n = node.Parent()->Parent();
-//            while (n != NULL)
-//            {
-//              Log::Warn << "...\n" << *n;
-//              n = n->Parent();
-//            }
-          Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
-                << owner << " but has true owner " << trueOwner << "! [" <<
-assignments(0, node.Point(i)) << " -- " <<
-metric.Evaluate(dataset.col(node.Point(i)),
-centroids.col(assignments(0, node.Point(i)))) << "] " <<
-distances(0, node.Point(i)) << " " <<
-assignments(0, node.Point(i)) << " " <<
-assignments(0, node.Point(i - 1)) << ".\n";
-        }
-      }
-*/
-
-      if (node.NumPoints() == 0 && childrenPruned)
-      {
-        // Pruned because its children are all pruned.
+      // Convenience variables to clean up the expressions.
+      const double mcd = node.Stat().MaxClusterDistance();
+      const double scb = node.Stat().SecondClusterBound();
+      const double ownerMovement = clusterDistances[owner];
+      const double maxMovement = clusterDistances[centroids.n_cols];
+      const double closestClusterDistance =
+          interclusterDistances[newFromOldCentroids[owner]];
+      if ((node.NumPoints() == 0 && childrenPruned) ||
+          (mcd + ownerMovement < scb - maxMovement) ||
+          (mcd < 0.5 * closestClusterDistance))
         node.Stat().Pruned() = true;
-      }
-      // What is the maximum distance to the closest cluster in the node?
-      else if (node.Stat().MaxClusterDistance() +
-          clusterDistances[node.Stat().Owner()] <
-          node.Stat().SecondClusterBound() - clusterDistances[centroids.n_cols])
-      {
-        node.Stat().Pruned() = true;
-      }
-      else
-      {
-        // Also do between-cluster prune.
-        if (node.Stat().MaxClusterDistance() < 0.5 *
-            interclusterDistances[newFromOldCentroids[owner]])
-        {
-          node.Stat().Pruned() = true;
-        }
-      }
 
-      // Adjust for next iteration.
-      node.Stat().MaxClusterDistance() +=
-          clusterDistances[node.Stat().Owner()];
-      node.Stat().SecondClusterBound() -= clusterDistances[centroids.n_cols];
+      // Adjust bounds for next iteration, whether or not the node was pruned.
+      // (Does this adjustment need to happen if there is no prune?)
+      node.Stat().MaxClusterDistance() += ownerMovement;
+      node.Stat().SecondClusterBound() -= maxMovement;
     }
-    else
+    else if (childrenPruned && node.NumChildren() > 0 && node.NumPoints() == 0)
     {
       // The node isn't owned by a single cluster.  But if it has no points and
       // its children are all pruned, we may prune it too.
-      if (childrenPruned && node.NumChildren() > 0)
-      {
-//        Log::Warn << "Prune parent node " << node.Point(0) << "c" <<
-//node.NumDescendants() << ".\n";
-        node.Stat().Pruned() = true;
-        node.Stat().Owner() = centroids.n_cols;
-      }
-//      if (node.NumChildren() > 0)
-//        if (node.Child(0).Stat().Pruned() && !node.Child(1).Stat().Pruned())
-//          Log::Warn << "Node left child pruned but right child not:\n" <<
-//node.Child(0) << ", r\n" << node.Child(1) << ", this:\n" << node;
-//      if (node.NumChildren() > 0)
-//        if (node.Child(1).Stat().Pruned() && !node.Child(0).Stat().Pruned())
-//          Log::Warn << "Node right child pruned but left child not:\n" <<
-//node.Child(0) << ", r\n" << node.Child(1) << ", this:\n" << node;
-//      if (node.NumChildren() > 0)
-//        Log::Warn << "Node has more than 0 children: " << node << ".l\n" <<
-//node.Child(0) << ", r\n" << node.Child(1) << ".\n";
-
-      // Adjust the bounds for next iteration.
-//      node.Stat().MaxClusterDistance() += clusterDistances[centroids.n_cols];
-//      node.Stat().SecondClusterBound() = std::max(0.0,
-//          node.Stat().SecondClusterBound() -
-//          clusterDistances[centroids.n_cols]);
+      node.Stat().Pruned() = true;
+      node.Stat().Owner() = centroids.n_cols;
     }
   }
   else if (node.Stat().Pruned())
@@ -447,37 +371,6 @@ assignments(0, node.Point(i - 1)) << ".\n";
         }
       }
     }
-/*
-    if (node.Stat().Pruned() && node.Stat().Owner() != centroids.n_cols)
-    {
-      for (size_t i = 0; i < node.NumPoints(); ++i)
-      {
-        size_t trueOwner = 0;
-        double ownerDist = DBL_MAX;
-        arma::vec distances(centroids.n_cols);
-        for (size_t j = 0; j < centroids.n_cols; ++j)
-        {
-          const double dist = metric.Evaluate(dataset.col(node.Point(i)),
-              lastIterationCentroids.col(j));
-          distances(j) = dist;
-          if (dist < ownerDist)
-          {
-            trueOwner = j;
-            ownerDist = dist;
-          }
-        }
-
-        if (trueOwner != node.Stat().Owner())
-        {
-            Log::Warn << node << "...\n" << *node.Parent();
-            Log::Warn << distances.t();
-            Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
-                << node.Stat().Owner() << " but has true owner " << trueOwner <<
-"!\n";
-        }
-      }
-    }
-*/
   }
   else
   {
@@ -512,47 +405,12 @@ assignments(0, node.Point(i - 1)) << ".\n";
 
       const double upperPointBound = distances(0, index) +
           clusterDistances[owner];
-      if (distances(0, index) + clusterDistances[owner] <
-          lowerSecondBounds[index] - clusterDistances[centroids.n_cols])
-      {
-/*
-        // Sanity check.
-        size_t trueOwner;
-        double trueDist = DBL_MAX;
-        arma::vec distances(centroids.n_cols);
-        for (size_t j = 0; j < centroids.n_cols; ++j)
-        {
-          const double dist = metric.Evaluate(lastIterationCentroids.col(j),
-                                              dataset.col(index));
-          distances(j) = dist;
-          if (dist < trueDist)
-          {
-            trueOwner = j;
-            trueDist = dist;
-          }
-        }
-
-        if (trueOwner != owner)
-        {
-          Log::Warn << "Point " << index << ", ub " << distances(0, index) << ","
-              << " lb " << lowerSecondBounds[index] << ", pruned " <<
-prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
-"owner!\n";
-          Log::Warn << distances.t();
-          Log::Fatal << "Assigned owner " << owner << " but true owner is "
-              << trueOwner << "!\n";
-        }
-*/
-        prunedPoints[index] = true;
-        distances(0, index) += clusterDistances[owner];
-        lastOwners[index] = owner;
-        distances(1, index) += clusterDistances[centroids.n_cols];
-        lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
-        prunedCentroids.col(owner) += dataset.col(index);
-        prunedCounts(owner)++;
-      }
-      else if (distances(0, index) + clusterDistances[owner] < 0.5 *
-               interclusterDistances[newFromOldCentroids[owner]])
+      const double lowerSecondBound = lowerSecondBounds[index] -
+          clusterDistances[centroids.n_cols];
+      const double closestClusterDistance =
+          interclusterDistances[newFromOldCentroids[owner]];
+      if ((upperPointBound < lowerSecondBound) ||
+          (upperPointBound < 0.5 * closestClusterDistance))
       {
         prunedPoints[index] = true;
         distances(0, index) += clusterDistances[owner];
@@ -568,18 +426,9 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
         distances(0, index) = metric.Evaluate(centroids.col(owner),
                                              dataset.col(index));
         ++distanceCalculations;
-        if (distances(0, index) < lowerSecondBounds[index] -
-            clusterDistances[centroids.n_cols])
-        {
-          prunedPoints[index] = true;
-          lastOwners[index] = owner;
-          lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
-          distances(1, index) += clusterDistances[centroids.n_cols];
-          prunedCentroids.col(owner) += dataset.col(index);
-          prunedCounts(owner)++;
-        }
-        else if (distances(0, index) < 0.5 *
-                 interclusterDistances[newFromOldCentroids[owner]])
+
+        if ((distances(0, index) < lowerSecondBound) ||
+            (distances(0, index) < 0.5 * closestClusterDistance))
         {
           prunedPoints[index] = true;
           lastOwners[index] = owner;
@@ -593,7 +442,6 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
           prunedPoints[index] = false;
           allPruned = false;
           // Still update these anyway.
-          distances(0, index) += clusterDistances[owner];
           distances(1, index) += clusterDistances[centroids.n_cols];
         }
       }



More information about the mlpack-git mailing list