[mlpack-git] master: Refactor to re-use distances to clusters. This is better than setting the distances to DBL_MAX every iteration, and provides reasonable speedup. (76e10e0)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:19 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 76e10e046328e8f24b1db03d80b70e35cbf4512c
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Feb 3 17:22:27 2015 -0500

    Refactor to re-use distances to clusters. This is better than setting the distances to DBL_MAX every iteration, and provides reasonable speedup.


>---------------------------------------------------------------

76e10e046328e8f24b1db03d80b70e35cbf4512c
 src/mlpack/methods/kmeans/dtnn_kmeans.hpp      |  2 ++
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 38 ++++++++++++++++++++++----
 src/mlpack/methods/kmeans/dtnn_rules.hpp       |  5 +++-
 src/mlpack/methods/kmeans/dtnn_rules_impl.hpp  |  9 ++++--
 4 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index b939f32..f9b35bb 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -99,6 +99,8 @@ class DTNNKMeans
   arma::mat distances;
   arma::Mat<size_t> assignments;
 
+  std::vector<bool> visited; // Was the point visited this iteration?
+
   arma::mat lastIterationCentroids; // For sanity checks.
 
   //! Update the bounds in the tree before the next iteration.
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 95546d3..ae2dde7 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -62,6 +62,13 @@ DTNNKMeans<MetricType, MatType, TreeType>::DTNNKMeans(const MatType& dataset,
   lowerSecondBounds.zeros(dataset.n_cols);
   lastOwners.zeros(dataset.n_cols);
 
+  assignments.set_size(2, dataset.n_cols);
+  assignments.fill(size_t(-1));
+  distances.set_size(2, dataset.n_cols);
+  distances.fill(DBL_MAX);
+
+  visited.resize(dataset.n_cols, false);
+
   Timer::Start("tree_building");
 
   // Copy the dataset, if necessary.
@@ -135,11 +142,9 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
 
   // We won't use the AllkNN class here because we have our own set of rules.
   // This is a lot of overhead.  We don't need the distances.
-  distances.fill(DBL_MAX);
-  assignments.fill(size_t(-1));
   typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
   RuleType rules(centroids, dataset, assignments, distances, metric,
-      prunedPoints, oldFromNewCentroids);
+      prunedPoints, oldFromNewCentroids, visited);
 
   // Now construct the traverser ourselves.
   typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
@@ -153,10 +158,12 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   // From the assignments, calculate the new centroids and counts.
   for (size_t i = 0; i < dataset.n_cols; ++i)
   {
-    if (assignments(0, i) != size_t(-1))
+    if (visited[i])
     {
       newCentroids.col(assignments(0, i)) += dataset.col(i);
       ++counts(assignments(0, i));
+      // Reset for next iteration.
+      visited[i] = false;
     }
   }
 
@@ -541,7 +548,9 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
 */
         prunedPoints[index] = true;
         upperBounds[index] += clusterDistances[owner];
+        distances(0, index) += clusterDistances[owner];
         lastOwners[index] = owner;
+        distances(1, index) += clusterDistances[centroids.n_cols];
         lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
         prunedCentroids.col(owner) += dataset.col(index);
         prunedCounts(owner)++;
@@ -551,7 +560,9 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
       {
         prunedPoints[index] = true;
         upperBounds[index] += clusterDistances[owner];
+        distances(0, index) += clusterDistances[owner];
         lastOwners[index] = owner;
+        distances(1, index) += clusterDistances[centroids.n_cols];
         lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
         prunedCentroids.col(owner) += dataset.col(index);
         prunedCounts(owner)++;
@@ -559,24 +570,29 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
       else
       {
         // Attempt to tighten the lower bound.
-        upperBounds[index] = metric.Evaluate(centroids.col(owner),
+        distances(0, index) = metric.Evaluate(centroids.col(owner),
                                              dataset.col(index));
+        upperBounds[index] = distances(0, index);
         ++distanceCalculations;
         if (upperBounds[index] < lowerSecondBounds[index] -
             clusterDistances[centroids.n_cols])
         {
           prunedPoints[index] = true;
           lastOwners[index] = owner;
+          upperBounds[index] += clusterDistances[owner];
           lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
+          distances(1, index) += clusterDistances[centroids.n_cols];
           prunedCentroids.col(owner) += dataset.col(index);
           prunedCounts(owner)++;
         }
         else if (upperBounds[index] < 0.5 *
-                  interclusterDistances[newFromOldCentroids[owner]])
+                 interclusterDistances[newFromOldCentroids[owner]])
         {
           prunedPoints[index] = true;
           lastOwners[index] = owner;
+          upperBounds[index] += clusterDistances[owner];
           lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
+          distances(1, index) += clusterDistances[centroids.n_cols];
           prunedCentroids.col(owner) += dataset.col(index);
           prunedCounts(owner)++;
         }
@@ -584,6 +600,9 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
         {
           prunedPoints[index] = false;
           allPruned = false;
+          // Still update these anyway.
+          distances(0, index) += clusterDistances[owner];
+          distances(1, index) += clusterDistances[centroids.n_cols];
         }
       }
     }
@@ -607,6 +626,13 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
     }
   }
 
+  // Make sure all the point bounds are updated.
+  for (size_t i = 0; i < node.NumPoints(); ++i)
+  {
+    distances(0, node.Point(i)) += clusterDistances[centroids.n_cols];
+    distances(1, node.Point(i)) += clusterDistances[centroids.n_cols];
+  }
+
   if (node.Stat().FirstBound() != DBL_MAX)
     node.Stat().FirstBound() += clusterDistances[centroids.n_cols];
   if (node.Stat().SecondBound() != DBL_MAX)
diff --git a/src/mlpack/methods/kmeans/dtnn_rules.hpp b/src/mlpack/methods/kmeans/dtnn_rules.hpp
index 2ded3ea..0183e9e 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules.hpp
@@ -25,7 +25,8 @@ class DTNNKMeansRules : public neighbor::NeighborSearchRules<
                   arma::mat& distances,
                   MetricType& metric,
                   const std::vector<bool>& prunedPoints,
-                  const std::vector<size_t>& oldFromNewCentroids);
+                  const std::vector<size_t>& oldFromNewCentroids,
+                  std::vector<bool>& visited);
 
   double BaseCase(const size_t queryIndex, const size_t referenceIndex);
 
@@ -43,6 +44,8 @@ class DTNNKMeansRules : public neighbor::NeighborSearchRules<
   const std::vector<bool>& prunedPoints;
 
   const std::vector<size_t>& oldFromNewCentroids;
+
+  std::vector<bool>& visited;
 };
 
 } // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index 11c2c3e..7e4e48b 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -20,11 +20,13 @@ DTNNKMeansRules<MetricType, TreeType>::DTNNKMeansRules(
     arma::mat& distances,
     MetricType& metric,
     const std::vector<bool>& prunedPoints,
-    const std::vector<size_t>& oldFromNewCentroids) :
+    const std::vector<size_t>& oldFromNewCentroids,
+    std::vector<bool>& visited) :
     neighbor::NeighborSearchRules<neighbor::NearestNeighborSort, MetricType,
         TreeType>(centroids, dataset, neighbors, distances, metric),
     prunedPoints(prunedPoints),
-    oldFromNewCentroids(oldFromNewCentroids)
+    oldFromNewCentroids(oldFromNewCentroids),
+    visited(visited)
 {
   // Nothing to do.
 }
@@ -38,6 +40,9 @@ inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
   if (prunedPoints[queryIndex])
     return 0.0; // Returning 0 shouldn't be a problem.
 
+  // Any base cases imply that we will get a result.
+  visited[queryIndex] = true;
+
   // This is basically an inlined NeighborSearchRules::BaseCase(), but it
   // differs in that it applies the mappings to the results automatically.
   // We can also skip a check or two.



More information about the mlpack-git mailing list