[mlpack-git] master: Perform prunes on individual points. Significant speedup with respect to number of calculations, no real speedup for runtime. But there is still time to optimize it. (a0a6bc1)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:05 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit a0a6bc1479ada991246676d1146e9e3cb147ed74
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Feb 2 17:17:59 2015 -0500

    Perform prunes on individual points. Significant speedup with respect to number of calculations, no real speedup for runtime. But there is still time to optimize it.


>---------------------------------------------------------------

a0a6bc1479ada991246676d1146e9e3cb147ed74
 src/mlpack/methods/kmeans/dtnn_kmeans.hpp      |   9 ++
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 135 ++++++++++++++++++++++---
 src/mlpack/methods/kmeans/dtnn_rules.hpp       |   5 +-
 src/mlpack/methods/kmeans/dtnn_rules_impl.hpp  |  23 +++--
 4 files changed, 146 insertions(+), 26 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index 211eda3..abc8236 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -84,6 +84,15 @@ class DTNNKMeans
   //! Counts from pruning.  Not normalized.
   arma::Col<size_t> prunedCounts;
 
+  //! Upper bounds on cluster distances for each point.
+  arma::vec upperBounds;
+  //! Lower bounds on second closest cluster distance for each point.
+  arma::vec lowerSecondBounds;
+  //! Indicator of whether or not the point is pruned.
+  std::vector<bool> prunedPoints;
+  //! The last cluster each point was assigned to.
+  arma::Col<size_t> lastOwners;
+
   //! Update the bounds in the tree before the next iteration.
   void UpdateTree(TreeType& node,
                   const double tolerance,
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 02471be..deb98d1 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -54,6 +54,12 @@ DTNNKMeans<MetricType, MatType, TreeType>::DTNNKMeans(const MatType& dataset,
     distanceCalculations(0),
     iteration(0)
 {
+  prunedPoints.resize(dataset.n_cols, false); // Fill with false.
+  upperBounds.set_size(dataset.n_cols);
+  upperBounds.fill(DBL_MAX);
+  lowerSecondBounds.zeros(dataset.n_cols);
+  lastOwners.zeros(dataset.n_cols);
+
   Timer::Start("tree_building");
 
   // Copy the dataset, if necessary.
@@ -118,7 +124,8 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   distances.fill(DBL_MAX);
   assignments.fill(size_t(-1));
   typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
-  RuleType rules(centroids, dataset, assignments, distances, metric);
+  RuleType rules(centroids, dataset, assignments, distances, metric,
+      prunedPoints);
 
   // Now construct the traverser ourselves.
   typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
@@ -180,13 +187,13 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   clusterDistances[centroids.n_cols] = maxMovement;
   distanceCalculations += centroids.n_cols;
 
+  // Reset centroids and counts for things we will collect during pruning.
+  prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
+  prunedCounts.zeros(centroids.n_cols);
   UpdateTree(*tree, maxMovement, oldCentroids, assignments, distances,
       clusterDistances, oldFromNewCentroids, interclusterDistances,
       newFromOldCentroids);
 
-  // Reset centroids and counts for things we will collect during pruning.
-  prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
-  prunedCounts.zeros(centroids.n_cols);
   PrecalculateCentroids(*tree);
 
   delete centroidTree;
@@ -225,6 +232,8 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
       childrenPruned = false; // Not all children are pruned.
   }
 
+  const bool prunedLastIteration = node.Stat().Pruned();
+
   // Does the node have a single owner?
   // It would be nice if we could do this during the traversal.
   bool singleOwner = true;
@@ -236,19 +245,35 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     for (size_t i = 0; i < node.NumPoints(); ++i)
     {
       // Don't forget to map back from the new cluster index.
-      const size_t c = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
-          oldFromNewCentroids[assignments(0, node.Point(i))] :
-          assignments(0, node.Point(i));
+      size_t c;
+      if (!prunedPoints[node.Point(i)])
+        c = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
+            oldFromNewCentroids[assignments(0, node.Point(i))] :
+            assignments(0, node.Point(i));
+      else
+        c = lastOwners[node.Point(i)];
+
       if (owner == centroids.n_cols + 1)
         owner = c;
       else if (owner != c)
         singleOwner = false;
 
       // Update maximum cluster distance and second cluster bound.
-      if (distances(0, node.Point(i)) > newMaxClusterDistance)
-        newMaxClusterDistance = distances(0, node.Point(i));
-      if (distances(1, node.Point(i)) < newSecondClusterBound)
-        newSecondClusterBound = distances(1, node.Point(i));
+      if (!prunedPoints[node.Point(i)])
+      {
+        if (distances(0, node.Point(i)) > newMaxClusterDistance)
+          newMaxClusterDistance = distances(0, node.Point(i));
+        if (distances(1, node.Point(i)) < newSecondClusterBound)
+          newSecondClusterBound = distances(1, node.Point(i));
+      }
+      else
+      {
+        // Use the cached bounds.
+        if (upperBounds[node.Point(i)] > newMaxClusterDistance)
+          newMaxClusterDistance = upperBounds[node.Point(i)];
+        if (lowerSecondBounds[node.Point(i)] < newSecondClusterBound)
+          newSecondClusterBound = lowerSecondBounds[node.Point(i)];
+      }
     }
 
     for (size_t i = 0; i < node.NumChildren(); ++i)
@@ -309,7 +334,8 @@ oldFromNewCentroids[assignments(0, node.Point(i))] << " " <<
 oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
           }
         }
-      }*/
+      }
+*/
 
       // What is the maximum distance to the closest cluster in the node?
       if (node.Stat().MaxClusterDistance() +
@@ -323,7 +349,9 @@ oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
         // Also do between-cluster prune.
         if (node.Stat().MaxClusterDistance() < 0.5 *
             interclusterDistances[newFromOldCentroids[owner]])
+        {
           node.Stat().Pruned() = true;
+        }
       }
 
       // Adjust for next iteration.
@@ -372,8 +400,7 @@ oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
                 << node.Stat().Owner() << " but has true owner " << trueOwner <<
 "!\n";
         }
-      }
-*/
+      }*/
 
     // Will our bounds still work?
     if (node.Stat().MaxClusterDistance() +
@@ -419,6 +446,86 @@ oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
           clusterDistances[centroids.n_cols]);
   }
 
+  // If the node wasn't pruned, try to prune individual points.
+  if (!node.Stat().Pruned())
+  {
+    for (size_t i = 0; i < node.NumPoints(); ++i)
+    {
+      const size_t index = node.Point(i);
+      size_t owner;
+      if (!prunedLastIteration && !prunedPoints[index])
+      {
+        owner = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
+            oldFromNewCentroids[assignments(0, index)] : assignments(0, index);
+        // Establish bounds, since these points were searched this iteration.
+        upperBounds[index] = distances(0, index);
+        lowerSecondBounds[index] = distances(1, index);
+      }
+      else if (prunedLastIteration)
+      {
+        owner = node.Stat().Owner();
+      }
+      else
+      {
+        owner = lastOwners[index];
+      }
+
+      if (upperBounds[index] + clusterDistances[owner] <
+          lowerSecondBounds[index] - clusterDistances[centroids.n_cols])
+      {
+/*
+        // Sanity check.
+        size_t trueOwner;
+        double trueDist = DBL_MAX;
+        arma::vec distances(centroids.n_cols);
+        for (size_t j = 0; j < centroids.n_cols; ++j)
+        {
+          const double dist = metric.Evaluate(centroids.col(j),
+                                              dataset.col(index));
+          distances(j) = dist;
+          if (dist < trueDist)
+          {
+            trueOwner = j;
+            trueDist = dist;
+          }
+        }
+
+        if (trueOwner != owner)
+        {
+          Log::Warn << "Point " << index << ", ub " << upperBounds[index] << ","
+              << " lb " << lowerSecondBounds[index] << ", pruned " <<
+prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
+"owner!\n";
+          Log::Warn << distances.t();
+          Log::Fatal << "Assigned owner " << owner << " but true owner is "
+              << trueOwner << "!\n";
+        }*/
+
+        prunedPoints[index] = true;
+        upperBounds[index] += clusterDistances[owner];
+        lastOwners[index] = owner;
+        lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
+        prunedCentroids.col(owner) += dataset.col(index);
+        prunedCounts(owner)++;
+      }
+      else
+      {
+        prunedPoints[index] = false;
+      }
+    }
+  }
+
+  if (node.Stat().Pruned())
+  {
+    // Update bounds.
+    for (size_t i = 0; i < node.NumPoints(); ++i)
+    {
+      const size_t index = node.Point(i);
+      upperBounds[index] += clusterDistances[node.Stat().Owner()];
+      lowerSecondBounds[index] -= clusterDistances[node.Stat().Owner()];
+    }
+  }
+
   if (node.Stat().FirstBound() != DBL_MAX)
     node.Stat().FirstBound() += tolerance;
   if (node.Stat().SecondBound() != DBL_MAX)
diff --git a/src/mlpack/methods/kmeans/dtnn_rules.hpp b/src/mlpack/methods/kmeans/dtnn_rules.hpp
index 44647ce..da2a63d 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules.hpp
@@ -22,7 +22,8 @@ class DTNNKMeansRules
                       const arma::mat& dataset,
                       arma::Mat<size_t>& neighbors,
                       arma::mat& distances,
-                      MetricType& metric);
+                      MetricType& metric,
+                      const std::vector<bool>& prunedPoints);
 
   double BaseCase(const size_t queryIndex, const size_t referenceIndex);
 
@@ -50,6 +51,8 @@ class DTNNKMeansRules
 
   typename neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
       MetricType, TreeType> rules;
+
+  const std::vector<bool>& prunedPoints;
 };
 
 } // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index e3d0c55..bce7d47 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -18,8 +18,10 @@ DTNNKMeansRules<MetricType, TreeType>::DTNNKMeansRules(
     const arma::mat& dataset,
     arma::Mat<size_t>& neighbors,
     arma::mat& distances,
-    MetricType& metric) :
-    rules(centroids, dataset, neighbors, distances, metric)
+    MetricType& metric,
+    const std::vector<bool>& prunedPoints) :
+    rules(centroids, dataset, neighbors, distances, metric),
+    prunedPoints(prunedPoints)
 {
   // Nothing to do.
 }
@@ -29,10 +31,10 @@ inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
     const size_t queryIndex,
     const size_t referenceIndex)
 {
-  // We'll check if the query point has been Hamerly pruned.  If so, don't
-  // continue.
-//  if (queryIndex == 27040)
-//    Log::Warn << "Visit point 27040 with cluster " << referenceIndex << ".\n";
+  // We'll check if the query point has been pruned.  If so, don't continue.
+  if (prunedPoints[queryIndex])
+    return 0.0; // Returning 0 shouldn't be a problem.
+
   return rules.BaseCase(queryIndex, referenceIndex);
 }
 
@@ -41,6 +43,10 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
     const size_t queryIndex,
     TreeType& referenceNode)
 {
+  // If the query point has already been pruned, then don't recurse further.
+  if (prunedPoints[queryIndex])
+    return DBL_MAX;
+
   return rules.Score(queryIndex, referenceNode);
 }
 
@@ -49,11 +55,6 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
     TreeType& queryNode,
     TreeType& referenceNode)
 {
-//  if (queryNode.Point(0) == 27040)
-//    Log::Warn << "Visit q27040c1 r" << referenceNode.Point(0) << "c" <<
-//referenceNode.NumDescendants() << ", " << queryNode.Stat().Pruned() << ", " <<
-//queryNode.Stat() << ", " << queryNode.Stat().FirstBound() << "," <<
-//queryNode.Stat().SecondBound() << ", " << queryNode.Stat().Bound() << ".\n";
   if (queryNode.Stat().Pruned())
     return DBL_MAX;
 



More information about the mlpack-git mailing list