[mlpack-git] master: Add static point prunes. Fairly significant runtime improvement. (99b8d56)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:31 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 99b8d56dc1d29bea6096648e92e88782d002eb4d
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Feb 17 18:34:02 2015 -0500

    Add static point prunes. Fairly significant runtime improvement.


>---------------------------------------------------------------

99b8d56dc1d29bea6096648e92e88782d002eb4d
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 109 +++++++++++++++++++++++--
 src/mlpack/tests/kmeans_test.cpp               |  12 ++-
 2 files changed, 109 insertions(+), 12 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 4f5377c..5b43bc6 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -69,6 +69,12 @@ DTNNKMeans<MetricType, MatType, TreeType>::DTNNKMeans(const MatType& dataset,
   tree = new TreeType(const_cast<typename TreeType::Mat&>(this->dataset));
 
   Timer::Stop("tree_building");
+
+  for (size_t i = 0; i < dataset.n_cols; ++i)
+    prunedPoints[i] = false;
+  assignments.fill(size_t(-1));
+  upperBounds.fill(DBL_MAX);
+  lowerBounds.fill(DBL_MAX);
 }
 
 template<typename MetricType, typename MatType, typename TreeType>
@@ -86,13 +92,8 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
     arma::Col<size_t>& counts)
 {
   // Reset information.
-  upperBounds.fill(DBL_MAX);
-  lowerBounds.fill(DBL_MAX);
   for (size_t i = 0; i < dataset.n_cols; ++i)
-  {
-    prunedPoints[i] = false;
     visited[i] = false;
-  }
 
   // Build a tree on the centroids.
   arma::mat oldCentroids(centroids); // Slow. :(
@@ -177,6 +178,16 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
 {
   node.Stat().StaticPruned() = false;
 
+  // Grab information from the parent, if we can.
+  if (node.Parent() != NULL &&
+      node.Parent()->Stat().Pruned() == clusterDistances.n_elem - 1)
+  {
+    node.Stat().UpperBound() = node.Parent()->Stat().UpperBound();
+    node.Stat().LowerBound() = node.Parent()->Stat().LowerBound();
+    node.Stat().Pruned() = node.Parent()->Stat().Pruned();
+    node.Stat().Owner() = node.Parent()->Stat().Owner();
+  }
+
   if ((node.Stat().Pruned() == clusterDistances.n_elem - 1) &&
       (node.Stat().Owner() < clusterDistances.n_elem - 1))
   {
@@ -199,6 +210,65 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
       }
     }
   }
+  else
+  {
+    node.Stat().LowerBound() -= clusterDistances[clusterDistances.n_elem - 1];
+  }
+
+  if (!node.Stat().StaticPruned())
+  {
+    // Try to prune individual points.
+    for (size_t i = 0; i < node.NumPoints(); ++i)
+    {
+      const size_t index = node.Point(i);
+      if (!visited[index] && !prunedPoints[index])
+        continue; // We didn't visit it and we don't have valid bounds -- so we
+                  // can't prune it.
+
+      prunedPoints[index] = false;
+      const size_t owner = assignments[node.Point(i)];
+      const double lowerBound = std::min(lowerBounds[index] -
+          clusterDistances[newCentroids.n_cols], node.Stat().LowerBound());
+      if (upperBounds[index] + clusterDistances[owner] < lowerBound)
+      {
+        prunedPoints[index] = true;
+        upperBounds[index] += clusterDistances[owner];
+        lowerBounds[index] = lowerBound;
+      }
+      else
+      {
+        // Attempt to tighten the bound.
+        upperBounds[index] = metric.Evaluate(dataset.col(index),
+                                             newCentroids.col(owner));
+        ++distanceCalculations;
+        if (upperBounds[index] < lowerBound)
+        {
+          prunedPoints[index] = true;
+          lowerBounds[index] = lowerBound;
+        }
+        else
+        {
+          // Point cannot be pruned.
+          upperBounds[index] = DBL_MAX;
+          lowerBounds[index] = DBL_MAX;
+        }
+      }
+    }
+  }
+  else
+  {
+    // Adjust bounds for individual points.
+    for (size_t i = 0; i < node.NumDescendants(); ++i)
+    {
+      upperBounds[node.Descendant(i)] += clusterDistances[node.Stat().Owner()];
+      lowerBounds[node.Descendant(i)] -=
+          clusterDistances[newCentroids.n_cols - 1];
+    }
+  }
+
+  for (size_t i = 0; i < node.NumChildren(); ++i)
+    UpdateTree(node.Child(i), clusterDistances, oldFromNewCentroids,
+        newCentroids);
 
   if (!node.Stat().StaticPruned())
   {
@@ -208,10 +278,6 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     node.Stat().Owner() = clusterDistances.n_elem - 1;
     node.Stat().StaticPruned() = false;
   }
-
-  for (size_t i = 0; i < node.NumChildren(); ++i)
-    UpdateTree(node.Child(i), clusterDistances, oldFromNewCentroids,
-        newCentroids);
 }
 
 template<typename MetricType, typename MatType, typename TreeType>
@@ -267,6 +333,31 @@ node.Stat().UpperBound() << " and owner " << node.Stat().Owner() << ".\n";
         const size_t owner = assignments[node.Point(i)];
         newCentroids.col(owner) += dataset.col(node.Point(i));
         ++newCounts[owner];
+
+/*
+        const size_t index = node.Point(i);
+        arma::vec trueDistances(centroids.n_cols);
+        for (size_t j = 0; j < centroids.n_cols; ++j)
+        {
+          const double dist = metric.Evaluate(dataset.col(index),
+                                              centroids.col(j));
+          trueDistances[j] = dist;
+        }
+
+        arma::uword minIndex;
+        const double minDist = trueDistances.min(minIndex);
+        if (size_t(minIndex) != owner)
+        {
+          Log::Warn << trueDistances.t();
+          Log::Fatal << "Point " << index << " of node " << node.Point(0) << "c"
+  << node.NumDescendants() << " has true minimum cluster " << minIndex << " with "
+        << "distance " << minDist << " but was assigned to cluster " <<
+assignments[node.Point(0)] << " with ub " << upperBounds[node.Point(0)] <<
+" and lb " << lowerBounds[node.Point(0)] << "; pp " <<
+(prunedPoints[node.Point(0)] ? "true" : "false") << ", visited " << (visited[node.Point(0)] ? "true"
+: "false") << ".\n";
+        }
+*/
       }
     }
 
diff --git a/src/mlpack/tests/kmeans_test.cpp b/src/mlpack/tests/kmeans_test.cpp
index d274e49..562589d 100644
--- a/src/mlpack/tests/kmeans_test.cpp
+++ b/src/mlpack/tests/kmeans_test.cpp
@@ -711,7 +711,9 @@ BOOST_AUTO_TEST_CASE(DualTreeKMeansBaseCaseTest)
   upperBounds.fill(DBL_MAX);
   lowerBounds.fill(DBL_MAX);
   std::vector<bool> visited(points, false); // Fill with false.
-  std::vector<size_t> oldFromNewCentroids; // Not used.
+  std::vector<size_t> oldFromNewCentroids(clusters);
+  for (size_t i = 0; i < clusters; ++i)
+    oldFromNewCentroids[i] = i;
   std::vector<bool> prunedPoints(points, false); // Fill with false.
 
   EuclideanDistance e;
@@ -768,7 +770,9 @@ BOOST_AUTO_TEST_CASE(DualTreeKMeansScoreKDTreeOneLeafTest)
   upperBounds.fill(DBL_MAX);
   lowerBounds.fill(DBL_MAX);
   std::vector<bool> visited(points, false); // Fill with false.
-  std::vector<size_t> oldFromNewCentroids; // Not used.
+  std::vector<size_t> oldFromNewCentroids(clusters);
+  for (size_t i = 0; i < clusters; ++i)
+    oldFromNewCentroids[i] = i;
   std::vector<bool> prunedPoints(points, false); // Fill with false.
 
   EuclideanDistance e;
@@ -876,7 +880,9 @@ BOOST_AUTO_TEST_CASE(DualTreeKMeansScoreKDTreeTest)
   upperBounds.fill(DBL_MAX);
   lowerBounds.fill(DBL_MAX);
   std::vector<bool> visited(points, false); // Fill with false.
-  std::vector<size_t> oldFromNewCentroids; // Not used.
+  std::vector<size_t> oldFromNewCentroids(clusters);
+  for (size_t i = 0; i < clusters; ++i)
+    oldFromNewCentroids[i] = i;
   std::vector<bool> prunedPoints(points, false); // Fill with false.
 
   EuclideanDistance e;



More information about the mlpack-git mailing list