[mlpack-git] master: Avoid iterating over every point when pruned. Cache the amount the upper bounds and lower bounds must change when the node becomes unpruned. (4ab1226)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:56 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 4ab12266cb68eef2a7c41dc280aba37ea98c312a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Feb 17 18:40:47 2015 -0500

    Avoid iterating over every point when a node is pruned. Instead, cache the amounts by which the upper and lower bounds must change, and apply them when the node becomes unpruned.


>---------------------------------------------------------------

4ab12266cb68eef2a7c41dc280aba37ea98c312a
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 25 ++++++++++++++++++++-----
 src/mlpack/methods/kmeans/dtnn_statistic.hpp   | 14 +++++++++++++-
 2 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 5b43bc6..985c359 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -176,6 +176,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     std::vector<size_t>& oldFromNewCentroids,
     arma::mat& newCentroids)
 {
+  const bool prunedLastIteration = node.Stat().StaticPruned();
   node.Stat().StaticPruned() = false;
 
   // Grab information from the parent, if we can.
@@ -225,6 +226,14 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
         continue; // We didn't visit it and we don't have valid bounds -- so we
                   // can't prune it.
 
+      if (prunedLastIteration)
+      {
+        // It was pruned last iteration but not this iteration.
+        // Set the bounds correctly.
+        upperBounds[index] += node.Stat().StaticUpperBoundMovement();
+        lowerBounds[index] -= node.Stat().StaticLowerBoundMovement();
+      }
+
       prunedPoints[index] = false;
       const size_t owner = assignments[node.Point(i)];
       const double lowerBound = std::min(lowerBounds[index] -
@@ -257,12 +266,18 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
   }
   else
   {
-    // Adjust bounds for individual points.
-    for (size_t i = 0; i < node.NumDescendants(); ++i)
+    if (prunedLastIteration)
+    {
+      // Track total movement while pruned.
+      node.Stat().StaticUpperBoundMovement() +=
+          clusterDistances[node.Stat().Owner()];
+      node.Stat().StaticLowerBoundMovement() +=
+          clusterDistances[newCentroids.n_cols];
+    }
+    else
     {
-      upperBounds[node.Descendant(i)] += clusterDistances[node.Stat().Owner()];
-      lowerBounds[node.Descendant(i)] -=
-          clusterDistances[newCentroids.n_cols - 1];
+      node.Stat().StaticUpperBoundMovement() = 0.0;
+      node.Stat().StaticLowerBoundMovement() = 0.0;
     }
   }
 
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index 82e02c1..2601378 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -23,6 +23,8 @@ class DTNNStatistic : public
       owner(size_t(-1)),
       pruned(size_t(-1)),
       staticPruned(false),
+      staticUpperBoundMovement(0.0),
+      staticLowerBoundMovement(0.0),
       centroid()
   {
     // Nothing to do.
@@ -35,7 +37,9 @@ class DTNNStatistic : public
       lowerBound(DBL_MAX),
       owner(size_t(-1)),
       pruned(size_t(-1)),
-      staticPruned(false)
+      staticPruned(false),
+      staticUpperBoundMovement(0.0),
+      staticLowerBoundMovement(0.0)
   {
     // Empirically calculate the centroid.
     centroid.zeros(node.Dataset().n_rows);
@@ -67,6 +71,12 @@ class DTNNStatistic : public
   bool StaticPruned() const { return staticPruned; }
   bool& StaticPruned() { return staticPruned; }
 
+  double StaticUpperBoundMovement() const { return staticUpperBoundMovement; }
+  double& StaticUpperBoundMovement() { return staticUpperBoundMovement; }
+
+  double StaticLowerBoundMovement() const { return staticLowerBoundMovement; }
+  double& StaticLowerBoundMovement() { return staticLowerBoundMovement; }
+
   std::string ToString() const
   {
     std::ostringstream o;
@@ -85,6 +95,8 @@ class DTNNStatistic : public
   size_t owner;
   size_t pruned;
   bool staticPruned;
+  double staticUpperBoundMovement;
+  double staticLowerBoundMovement;
   arma::vec centroid;
 };
 



More information about the mlpack-git mailing list