[mlpack-svn] r16926 - mlpack/trunk/src/mlpack/methods/kmeans

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 30 12:33:07 EDT 2014


Author: rcurtin
Date: Wed Jul 30 12:33:07 2014
New Revision: 16926

Log:
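Further refactoring of KMeans.  Fix MaxVarianceNewCluster, and change
EmptyCluster() so that assignments are no longer passed in (the inner KMeans
loop no longer keeps track of them); the metric is passed instead, and
MaxVarianceNewCluster recomputes the assignments it needs internally.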


Modified:
   mlpack/trunk/src/mlpack/methods/kmeans/allow_empty_clusters.hpp
   mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
   mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp
   mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp
   mlpack/trunk/src/mlpack/methods/kmeans/naive_kmeans_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/kmeans/allow_empty_clusters.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/allow_empty_clusters.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/kmeans/allow_empty_clusters.hpp	Wed Jul 30 12:33:07 2014
@@ -36,12 +36,12 @@
    *
    * @return Number of points changed (0).
    */
-  template<typename MatType>
+  template<typename MetricType, typename MatType>
   static size_t EmptyCluster(const MatType& /* data */,
                              const size_t /* emptyCluster */,
                              const arma::mat& /* centroids */,
                              arma::Col<size_t>& /* clusterCounts */,
-                             arma::Col<size_t>& /* assignments */)
+                             MetricType& /* assignments */)
   {
     // Empty clusters are okay!  Do nothing.
     return 0;

Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp	Wed Jul 30 12:33:07 2014
@@ -122,19 +122,6 @@
       Log::Fatal << "KMeans::Cluster(): initial cluster assignments (length "
           << assignments.n_elem << ") not the same size as the dataset (size "
           << data.n_cols << ")!" << std::endl;
-
-    // Calculate initial centroids.
-    counts.zeros(actualClusters);
-    centroids.zeros(data.n_rows, actualClusters);
-    for (size_t i = 0; i < data.n_cols; ++i)
-    {
-      centroids.col(assignments[i]) += data.col(i);
-      counts[assignments[i]]++;
-    }
-
-    for (size_t i = 0; i < actualClusters; ++i)
-      if (counts[i] != 0)
-        centroids.col(i) /= counts[i];
   }
   else if (initialCentroidGuess)
   {
@@ -147,36 +134,16 @@
       Log::Fatal << "KMeans::Cluster(): initial cluster centroids have wrong "
         << " dimensionality (" << centroids.n_rows << ", should be "
         << data.n_rows << ")!" << std::endl;
-
-    // If there were no problems, construct the initial assignments from the
-    // given centroids.
-    assignments.set_size(data.n_cols);
-    for (size_t i = 0; i < data.n_cols; ++i)
-    {
-      // Find the closest centroid to this point.
-      double minDistance = std::numeric_limits<double>::infinity();
-      size_t closestCluster = clusters; // Invalid value.
-
-      for (size_t j = 0; j < clusters; j++)
-      {
-        double distance = metric.Evaluate(data.col(i), centroids.col(j));
-
-        if (distance < minDistance)
-        {
-          minDistance = distance;
-          closestCluster = j;
-        }
-      }
-
-      // Assign the point to the closest cluster that we found.
-      assignments[i] = closestCluster;
-    }
   }
   else
   {
     // Use the partitioner to come up with the partition assignments.
     partitioner.Cluster(data, actualClusters, assignments);
+  }
 
+  // Calculate the initial centroids, if we need to.
+  if (!initialCentroidGuess || (initialAssignmentGuess && initialCentroidGuess))
+  {
     // Calculate initial centroids.
     counts.zeros(actualClusters);
     centroids.zeros(data.n_rows, actualClusters);
@@ -191,7 +158,6 @@
         centroids.col(i) /= counts[i];
   }
 
-  size_t changedAssignments = 0;
   size_t iteration = 0;
 
   LloydStepType<MetricType, MatType> lloydStep(data, metric);
@@ -210,9 +176,17 @@
     // If we are not allowing empty clusters, then check that all of our
     // clusters have points.
     for (size_t i = 0; i < actualClusters; i++)
+    {
       if (counts[i] == 0)
-        changedAssignments += emptyClusterAction.EmptyCluster(data, i,
-            centroids, counts, assignments);
+      {
+        Log::Debug << "Cluster " << i << " is empty.\n";
+        if (iteration % 2 == 0)
+          emptyClusterAction.EmptyCluster(data, i, centroidsOther, counts,
+              metric);
+        else
+          emptyClusterAction.EmptyCluster(data, i, centroids, counts, metric);
+      }
+    }
 
     // Calculate cluster distortion for this iteration.
     cNorm = 0.0;
@@ -228,9 +202,11 @@
 
   } while (cNorm > 1e-5 && iteration != maxIterations);
 
-  // Unfortunate copy that is sometimes necessary.
+  // If we ended on an even iteration, then the centroids are in the
+  // centroidsOther matrix, and we need to steal its memory (steal_mem() avoids
+  // a copy if possible).
   if (iteration % 2 == 0)
-    centroids = centroidsOther;
+    centroids.steal_mem(centroidsOther);
 
   if (iteration != maxIterations)
   {
@@ -244,6 +220,7 @@
   }
 
   // Calculate final assignments.
+  assignments.set_size(data.n_cols);
   for (size_t i = 0; i < data.n_cols; ++i)
   {
     // Find the closest centroid to this point.

Modified: mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp	Wed Jul 30 12:33:07 2014
@@ -37,12 +37,12 @@
    *
    * @return Number of points changed.
    */
-  template<typename MatType>
+  template<typename MetricType, typename MatType>
   static size_t EmptyCluster(const MatType& data,
                              const size_t emptyCluster,
                              arma::mat& centroids,
                              arma::Col<size_t>& clusterCounts,
-                             arma::Col<size_t>& assignments);
+                             MetricType& metric);
 };
 
 }; // namespace kmeans

Modified: mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp	Wed Jul 30 12:33:07 2014
@@ -16,24 +16,41 @@
 /**
  * Take action about an empty cluster.
  */
-template<typename MatType>
+template<typename MetricType, typename MatType>
 size_t MaxVarianceNewCluster::EmptyCluster(const MatType& data,
                                            const size_t emptyCluster,
                                            arma::mat& centroids,
                                            arma::Col<size_t>& clusterCounts,
-                                           arma::Col<size_t>& assignments)
+                                           MetricType& metric)
 {
   // First, we need to find the cluster with maximum variance (by which I mean
   // the sum of the covariance matrices).
   arma::vec variances;
   variances.zeros(clusterCounts.n_elem); // Start with 0.
+  arma::Col<size_t> assignments(data.n_cols);
 
   // Add the variance of each point's distance away from the cluster.  I think
   // this is the sensible thing to do.
   for (size_t i = 0; i < data.n_cols; ++i)
   {
-    variances[assignments[i]] += metric::SquaredEuclideanDistance::Evaluate(
-        data.col(i), centroids.col(assignments[i]));
+    // Find the closest centroid to this point.
+    double minDistance = std::numeric_limits<double>::infinity();
+    size_t closestCluster = centroids.n_cols; // Invalid value.
+
+    for (size_t j = 0; j < centroids.n_cols; j++)
+    { 
+      const double distance = metric.Evaluate(data.col(i), centroids.col(j));
+
+      if (distance < minDistance)
+      {
+        minDistance = distance;
+        closestCluster = j;
+      }
+    }
+
+    assignments[i] = closestCluster;
+    variances[closestCluster] += metric.Evaluate(data.col(i),
+        centroids.col(closestCluster));
   }
 
   // Divide by the number of points in the cluster to produce the variance.
@@ -55,8 +72,8 @@
   {
     if (assignments[i] == maxVarCluster)
     {
-      double distance = arma::as_scalar(
-          arma::var(data.col(i) - centroids.col(maxVarCluster)));
+      const double distance = metric.Evaluate(data.col(i),
+          centroids.col(maxVarCluster));
 
       if (distance > maxDistance)
       {
@@ -67,15 +84,15 @@
   }
 
   // Take that point and add it to the empty cluster.
-  centroids.col(maxVarCluster) *= (clusterCounts[maxVarCluster] /
-      --clusterCounts[maxVarCluster]);
-  centroids.col(maxVarCluster) -= (1.0 / clusterCounts[maxVarCluster]) *
+  centroids.col(maxVarCluster) *= (double(clusterCounts[maxVarCluster]) /
+      double(clusterCounts[maxVarCluster] - 1));
+  centroids.col(maxVarCluster) -= (1.0 / (clusterCounts[maxVarCluster] - 1.0)) *
       data.col(furthestPoint);
+  clusterCounts[maxVarCluster]--;
   clusterCounts[emptyCluster]++;
   centroids.col(emptyCluster) = arma::vec(data.col(furthestPoint));
   assignments[furthestPoint] = emptyCluster;
 
-
   // Output some debugging information.
   Log::Debug << "Point " << furthestPoint << " assigned to empty cluster " <<
       emptyCluster << ".\n";

Modified: mlpack/trunk/src/mlpack/methods/kmeans/naive_kmeans_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/naive_kmeans_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/kmeans/naive_kmeans_impl.hpp	Wed Jul 30 12:33:07 2014
@@ -60,6 +60,8 @@
   for (size_t i = 0; i < centroids.n_cols; ++i)
     if (counts(i) != 0)
       newCentroids.col(i) /= counts(i);
+    else
+      newCentroids.col(i).fill(DBL_MAX); // Invalid value.
 }
 
 } // namespace kmeans



More information about the mlpack-svn mailing list