[mlpack-git] master: Prune nodes whose points and children are pruned. This gives a significant speedup, and on my little test dataset, this is the fastest algorithm yet that I have created. (536c639)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:27 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 536c639f8872a53c856c5636b7795edac6e53c44
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Feb 19 13:45:30 2015 -0500

    Prune nodes whose points and children are pruned. This gives a significant speedup, and on my little test dataset, this is the fastest algorithm yet that I have created.


>---------------------------------------------------------------

536c639f8872a53c856c5636b7795edac6e53c44
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 47 +++++++++++++++++++-------
 1 file changed, 34 insertions(+), 13 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 9b70887..8a76d06 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -241,6 +241,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
   }
 
+  bool allPointsPruned = true;
   if (!node.Stat().StaticPruned())
   {
     // Try to prune individual points.
@@ -248,8 +249,11 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     {
       const size_t index = node.Point(i);
       if (!visited[index] && !prunedPoints[index])
+      {
+        allPointsPruned = false;
         continue; // We didn't visit it and we don't have valid bounds -- so we
                   // can't prune it.
+      }
 
       if (prunedLastIteration)
       {
@@ -287,11 +291,40 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
           // Point cannot be pruned.
           upperBounds[index] = DBL_MAX;
           lowerBounds[index] = DBL_MAX;
+          allPointsPruned = false;
         }
       }
     }
   }
-  else
+
+  // Recurse into children, and if all the children (and all the points) are
+  // pruned, then we can mark this as statically pruned.
+  bool allChildrenPruned = true;
+  for (size_t i = 0; i < node.NumChildren(); ++i)
+  {
+    UpdateTree(node.Child(i), centroids, interclusterDistances);
+    if (!node.Child(i).Stat().StaticPruned())
+      allChildrenPruned = false;
+  }
+
+  // If all of the children and points are pruned, we may mark this node as
+  // pruned.
+  if (allChildrenPruned && allPointsPruned && !node.Stat().StaticPruned())
+  {
+    node.Stat().StaticPruned() = true;
+    node.Stat().Owner() = centroids.n_cols; // Invalid owner.
+    node.Stat().Pruned() = size_t(-1);
+  }
+
+  if (!node.Stat().StaticPruned())
+  {
+    node.Stat().UpperBound() = DBL_MAX;
+    node.Stat().LowerBound() = DBL_MAX;
+    node.Stat().Pruned() = size_t(-1);
+    node.Stat().Owner() = centroids.n_cols;
+    node.Stat().StaticPruned() = false;
+  }
+  else // The node is now pruned.
   {
     if (prunedLastIteration)
     {
@@ -307,18 +340,6 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
       node.Stat().StaticLowerBoundMovement() = 0.0;
     }
   }
-
-  for (size_t i = 0; i < node.NumChildren(); ++i)
-    UpdateTree(node.Child(i), centroids, interclusterDistances);
-
-  if (!node.Stat().StaticPruned())
-  {
-    node.Stat().UpperBound() = DBL_MAX;
-    node.Stat().LowerBound() = DBL_MAX;
-    node.Stat().Pruned() = size_t(-1);
-    node.Stat().Owner() = centroids.n_cols;
-    node.Stat().StaticPruned() = false;
-  }
 }
 
 template<typename MetricType, typename MatType, typename TreeType>



More information about the mlpack-git mailing list