[mlpack-git] master: Fix dual-tree k-means runtime bug. (a0f1dd5)

gitdub at mlpack.org gitdub at mlpack.org
Fri Mar 25 21:17:32 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/c14e81a563d59de06f1a0fe7cab0b841ca9220ec...a0f1dd5632004b26cd3592e6ceaa455792df8c3a

>---------------------------------------------------------------

commit a0f1dd5632004b26cd3592e6ceaa455792df8c3a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Mar 25 21:16:31 2016 -0400

    Fix dual-tree k-means runtime bug.
    
    The residual was being calculated incorrectly, and as a result the algorithm
    would run until the maximum number of iterations.  Thanks to Erich Schubert
    for bringing this one up and pointing it out.
    
    The fascinating thing is that this did not trigger any test failures, because it
    still converged to the correct results---it just took way longer than it should
    have.


>---------------------------------------------------------------

a0f1dd5632004b26cd3592e6ceaa455792df8c3a
 src/mlpack/methods/kmeans/dual_tree_kmeans.hpp     |  2 +-
 .../methods/kmeans/dual_tree_kmeans_impl.hpp       | 34 ++++++++++------------
 2 files changed, 16 insertions(+), 20 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
index 01fd8a1..400251b 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -118,7 +118,7 @@ class DualTreeKMeans
   void ExtractCentroids(Tree& node,
                         arma::mat& newCentroids,
                         arma::Col<size_t>& newCounts,
-                        arma::mat& centroids);
+                        const arma::mat& centroids);
 
   void CoalesceTree(Tree& node, const size_t child = 0);
   void DecoalesceTree(Tree& node);
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index 48322ec..32e7392 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -21,9 +21,9 @@ namespace kmeans {
 //! Call the tree constructor that does mapping.
 template<typename TreeType>
 TreeType* BuildTree(
-    typename TreeType::Mat& dataset,
+    const typename TreeType::Mat& dataset,
     std::vector<size_t>& oldFromNew,
-    typename boost::enable_if_c<
+    const typename boost::enable_if_c<
         tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
     >::type = 0)
 {
@@ -96,11 +96,10 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
     arma::mat& newCentroids,
     arma::Col<size_t>& counts)
 {
-  // Build a tree on the centroids.
-  arma::mat oldCentroids(centroids); // Slow. :(
+  // Build a tree on the centroids.  This will make a copy if necessary, which
+  // is unfortunate, but I don't see a reasonable way around it.
   std::vector<size_t> oldFromNewCentroids;
-  Tree* centroidTree = BuildTree<Tree>(const_cast<MatType&>(centroids),
-      oldFromNewCentroids);
+  Tree* centroidTree = BuildTree<Tree>(centroids, oldFromNewCentroids);
 
   // Reset information in the tree, if we need to.
   if (iteration > 0)
@@ -134,7 +133,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
 
     Timer::Stop("knn");
 
-    UpdateTree(*tree, oldCentroids);
+    UpdateTree(*tree, centroids);
 
     for (size_t i = 0; i < dataset.n_cols; ++i)
       visited[i] = false;
@@ -147,7 +146,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
   }
 
   // We won't use the AllkNN class here because we have our own set of rules.
-  lastIterationCentroids = oldCentroids;
+  lastIterationCentroids = centroids;
   typedef DualTreeKMeansRules<MetricType, Tree> RuleType;
   RuleType rules(centroidTree->Dataset(), dataset, assignments, upperBounds,
       lowerBounds, metric, prunedPoints, oldFromNewCentroids, visited);
@@ -171,27 +170,24 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
   // Now we need to extract the clusters.
   newCentroids.zeros(centroids.n_rows, centroids.n_cols);
   counts.zeros(centroids.n_cols);
-  ExtractCentroids(*tree, newCentroids, counts, oldCentroids);
+  ExtractCentroids(*tree, newCentroids, counts, centroids);
 
   // Now, calculate how far the clusters moved, after normalizing them.
   double residual = 0.0;
   clusterDistances[centroids.n_cols] = 0.0;
   for (size_t c = 0; c < centroids.n_cols; ++c)
   {
-    // Get the mapping to the old cluster, if necessary.
-    const size_t old = (tree::TreeTraits<Tree>::RearrangesDataset) ?
-        oldFromNewCentroids[c] : c;
-    if (counts[old] == 0)
+    if (counts[c] == 0)
     {
-      newCentroids.col(old).fill(DBL_MAX);
-      clusterDistances[old] = 0;
+      newCentroids.col(c).fill(DBL_MAX);
+      clusterDistances[c] = 0;
     }
     else
     {
-      newCentroids.col(old) /= counts(old);
+      newCentroids.col(c) /= counts(c);
       const double movement = metric.Evaluate(centroids.col(c),
-          newCentroids.col(old));
-      clusterDistances[old] = movement;
+          newCentroids.col(c));
+      clusterDistances[c] = movement;
       residual += std::pow(movement, 2.0);
 
       if (movement > clusterDistances[centroids.n_cols])
@@ -481,7 +477,7 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
     Tree& node,
     arma::mat& newCentroids,
     arma::Col<size_t>& newCounts,
-    arma::mat& centroids)
+    const arma::mat& centroids)
 {
   // Does this node own points?
   if ((node.Stat().Pruned() == newCentroids.n_cols) ||




More information about the mlpack-git mailing list