[mlpack-git] master: Refactor KMeans so that the actual Lloyd iteration step is separate, since there are many ways to do a Lloyd iteration. (ae2ddb1)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:55:20 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit ae2ddb13a7b1070d6933e5728c17a34a87654d74
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Jul 29 22:04:27 2014 +0000

    Refactor KMeans so that the actual Lloyd iteration step is separate, since there
    are many ways to do a Lloyd iteration.


>---------------------------------------------------------------

ae2ddb13a7b1070d6933e5728c17a34a87654d74
 src/mlpack/methods/kmeans/CMakeLists.txt           |   2 +
 src/mlpack/methods/kmeans/allow_empty_clusters.hpp |   2 +-
 src/mlpack/methods/kmeans/kmeans.hpp               |  17 +-
 src/mlpack/methods/kmeans/kmeans_impl.hpp          | 183 ++++++++++++---------
 .../methods/kmeans/max_variance_new_cluster.hpp    |   2 +-
 .../kmeans/max_variance_new_cluster_impl.hpp       |   9 +-
 src/mlpack/methods/kmeans/naive_kmeans.hpp         |  45 +++++
 src/mlpack/methods/kmeans/naive_kmeans_impl.hpp    |  68 ++++++++
 8 files changed, 238 insertions(+), 90 deletions(-)

diff --git a/src/mlpack/methods/kmeans/CMakeLists.txt b/src/mlpack/methods/kmeans/CMakeLists.txt
index 85995a0..1e61c30 100644
--- a/src/mlpack/methods/kmeans/CMakeLists.txt
+++ b/src/mlpack/methods/kmeans/CMakeLists.txt
@@ -6,6 +6,8 @@ set(SOURCES
   kmeans_impl.hpp
   max_variance_new_cluster.hpp
   max_variance_new_cluster_impl.hpp
+  naive_kmeans.hpp
+  naive_kmeans_impl.hpp
   random_partition.hpp
   refined_start.hpp
   refined_start_impl.hpp
diff --git a/src/mlpack/methods/kmeans/allow_empty_clusters.hpp b/src/mlpack/methods/kmeans/allow_empty_clusters.hpp
index ac956b3..ae476b7 100644
--- a/src/mlpack/methods/kmeans/allow_empty_clusters.hpp
+++ b/src/mlpack/methods/kmeans/allow_empty_clusters.hpp
@@ -39,7 +39,7 @@ class AllowEmptyClusters
   template<typename MatType>
   static size_t EmptyCluster(const MatType& /* data */,
                              const size_t /* emptyCluster */,
-                             const MatType& /* centroids */,
+                             const arma::mat& /* centroids */,
                              arma::Col<size_t>& /* clusterCounts */,
                              arma::Col<size_t>& /* assignments */)
   {
diff --git a/src/mlpack/methods/kmeans/kmeans.hpp b/src/mlpack/methods/kmeans/kmeans.hpp
index e060905..06c3e5e 100644
--- a/src/mlpack/methods/kmeans/kmeans.hpp
+++ b/src/mlpack/methods/kmeans/kmeans.hpp
@@ -12,6 +12,7 @@
 #include <mlpack/core/metrics/lmetric.hpp>
 #include "random_partition.hpp"
 #include "max_variance_new_cluster.hpp"
+#include "naive_kmeans.hpp"
 
 #include <mlpack/core/tree/binary_space_tree.hpp>
 
@@ -51,12 +52,16 @@ namespace kmeans /** K-Means clustering. */ {
  * @tparam EmptyClusterPolicy Policy for what to do on an empty cluster; must
  *     implement a default constructor and 'void EmptyCluster(const arma::mat&,
  *     arma::Col<size_t&)'.
+ * @tparam LloydStepType Implementation of single Lloyd step to use.
  *
- * @see RandomPartition, RefinedStart, AllowEmptyClusters, MaxVarianceNewCluster
+ * @see RandomPartition, RefinedStart, AllowEmptyClusters,
+ *      MaxVarianceNewCluster, NaiveKMeans
  */
 template<typename MetricType = metric::EuclideanDistance,
          typename InitialPartitionPolicy = RandomPartition,
-         typename EmptyClusterPolicy = MaxVarianceNewCluster>
+         typename EmptyClusterPolicy = MaxVarianceNewCluster,
+         template<class, class> class LloydStepType = NaiveKMeans,
+         typename MatType = arma::mat>
 class KMeans
 {
  public:
@@ -102,11 +107,10 @@ class KMeans
    * @param initialGuess If true, then it is assumed that assignments has a list
    *      of initial cluster assignments.
    */
-  template<typename MatType>
   void Cluster(const MatType& data,
                const size_t clusters,
                arma::Col<size_t>& assignments,
-               const bool initialGuess = false) const;
+               const bool initialGuess = false);
 
   /**
    * Perform k-means clustering on the data, returning a list of cluster
@@ -134,13 +138,12 @@ class KMeans
    * @param initialCentroidGuess If true, then it is assumed that centroids
    *      contains the initial centroids of each cluster.
    */
-  template<typename MatType>
   void Cluster(const MatType& data,
                const size_t clusters,
                arma::Col<size_t>& assignments,
-               MatType& centroids,
+               arma::mat& centroids,
                const bool initialAssignmentGuess = false,
-               const bool initialCentroidGuess = false) const;
+               const bool initialCentroidGuess = false);
 
   //! Return the overclustering factor.
   double OverclusteringFactor() const { return overclusteringFactor; }
diff --git a/src/mlpack/methods/kmeans/kmeans_impl.hpp b/src/mlpack/methods/kmeans/kmeans_impl.hpp
index 1ef8c6f..936831c 100644
--- a/src/mlpack/methods/kmeans/kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/kmeans_impl.hpp
@@ -10,9 +10,6 @@
 #include <mlpack/core/tree/mrkd_statistic.hpp>
 #include <mlpack/core/metrics/lmetric.hpp>
 
-#include <stack>
-#include <limits>
-
 namespace mlpack {
 namespace kmeans {
 
@@ -21,11 +18,15 @@ namespace kmeans {
  */
 template<typename MetricType,
          typename InitialPartitionPolicy,
-         typename EmptyClusterPolicy>
+         typename EmptyClusterPolicy,
+         template<class, class> class LloydStepType,
+         typename MatType>
 KMeans<
     MetricType,
     InitialPartitionPolicy,
-    EmptyClusterPolicy>::
+    EmptyClusterPolicy,
+    LloydStepType,
+    MatType>::
 KMeans(const size_t maxIterations,
        const double overclusteringFactor,
        const MetricType metric,
@@ -57,18 +58,21 @@ KMeans(const size_t maxIterations,
  */
 template<typename MetricType,
          typename InitialPartitionPolicy,
-         typename EmptyClusterPolicy>
-template<typename MatType>
+         typename EmptyClusterPolicy,
+         template<class, class> class LloydStepType,
+         typename MatType>
 inline void KMeans<
     MetricType,
     InitialPartitionPolicy,
-    EmptyClusterPolicy>::
+    EmptyClusterPolicy,
+    LloydStepType,
+    MatType>::
 Cluster(const MatType& data,
         const size_t clusters,
         arma::Col<size_t>& assignments,
-        const bool initialGuess) const
+        const bool initialGuess)
 {
-  MatType centroids(data.n_rows, clusters);
+  arma::mat centroids(data.n_rows, clusters);
   Cluster(data, clusters, assignments, centroids, initialGuess);
 }
 
@@ -78,18 +82,21 @@ Cluster(const MatType& data,
  */
 template<typename MetricType,
          typename InitialPartitionPolicy,
-         typename EmptyClusterPolicy>
-template<typename MatType>
+         typename EmptyClusterPolicy,
+         template<class, class> class LloydStepType,
+         typename MatType>
 void KMeans<
     MetricType,
     InitialPartitionPolicy,
-    EmptyClusterPolicy>::
+    EmptyClusterPolicy,
+    LloydStepType,
+    MatType>::
 Cluster(const MatType& data,
         const size_t clusters,
         arma::Col<size_t>& assignments,
-        MatType& centroids,
+        arma::mat& centroids,
         const bool initialAssignmentGuess,
-        const bool initialCentroidGuess) const
+        const bool initialCentroidGuess)
 {
   // Make sure we have more points than clusters.
   if (clusters > data.n_cols)
@@ -105,6 +112,9 @@ Cluster(const MatType& data,
     actualClusters = clusters;
   }
 
+  // Counts of points in each cluster.
+  arma::Col<size_t> counts(actualClusters);
+
   // Now, the initial assignments.  First determine if they are necessary.
   if (initialAssignmentGuess)
   {
@@ -112,6 +122,19 @@ Cluster(const MatType& data,
       Log::Fatal << "KMeans::Cluster(): initial cluster assignments (length "
           << assignments.n_elem << ") not the same size as the dataset (size "
           << data.n_cols << ")!" << std::endl;
+
+    // Calculate initial centroids.
+    counts.zeros(actualClusters);
+    centroids.zeros(data.n_rows, actualClusters);
+    for (size_t i = 0; i < data.n_cols; ++i)
+    {
+      centroids.col(assignments[i]) += data.col(i);
+      counts[assignments[i]]++;
+    }
+
+    for (size_t i = 0; i < actualClusters; ++i)
+      if (counts[i] != 0)
+        centroids.col(i) /= counts[i];
   }
   else if (initialCentroidGuess)
   {
@@ -153,65 +176,36 @@ Cluster(const MatType& data,
   {
     // Use the partitioner to come up with the partition assignments.
     partitioner.Cluster(data, actualClusters, assignments);
-  }
-
-  // Counts of points in each cluster.
-  arma::Col<size_t> counts(actualClusters);
-  counts.zeros();
 
-  // Resize to correct size.
-  centroids.set_size(data.n_rows, actualClusters);
+    // Calculate initial centroids.
+    counts.zeros(actualClusters);
+    centroids.zeros(data.n_rows, actualClusters);
+    for (size_t i = 0; i < data.n_cols; ++i)
+    {
+      centroids.col(assignments[i]) += data.col(i);
+      counts[assignments[i]]++;
+    }
 
-  // Set counts correctly.
-  for (size_t i = 0; i < assignments.n_elem; i++)
-    counts[assignments[i]]++;
+    for (size_t i = 0; i < actualClusters; ++i)
+      if (counts[i] != 0)
+        centroids.col(i) /= counts[i];
+  }
 
   size_t changedAssignments = 0;
   size_t iteration = 0;
-  do
-  {
-    // Update step.
-    // Calculate centroids based on given assignments.
-    centroids.zeros();
-
-    for (size_t i = 0; i < data.n_cols; i++)
-      centroids.col(assignments[i]) += data.col(i);
-
-    for (size_t i = 0; i < actualClusters; i++)
-      centroids.col(i) /= counts[i];
-
-    // Assignment step.
-    // Find the closest centroid to each point.  We will keep track of how many
-    // assignments change.  When no assignments change, we are done.
-    changedAssignments = 0;
-    for (size_t i = 0; i < data.n_cols; i++)
-    {
-      // Find the closest centroid to this point.
-      double minDistance = std::numeric_limits<double>::infinity();
-      size_t closestCluster = actualClusters; // Invalid value.
-
-      for (size_t j = 0; j < actualClusters; j++)
-      {
-        double distance = metric.Evaluate(data.col(i), centroids.col(j));
 
-        if (distance < minDistance)
-        {
-          minDistance = distance;
-          closestCluster = j;
-        }
-      }
+  LloydStepType<MetricType, MatType> lloydStep(data, metric);
+  arma::mat centroidsOther;
+  double cNorm;
 
-      // Reassign this point to the closest cluster.
-      if (assignments[i] != closestCluster)
-      {
-        // Update counts.
-        counts[assignments[i]]--;
-        counts[closestCluster]++;
-        // Update assignment.
-        assignments[i] = closestCluster;
-        changedAssignments++;
-      }
-    }
+  do
+  {
+    // We have two centroid matrices.  We don't want to copy anything, so,
+    // depending on the iteration number, we use a different centroid matrix...
+    if (iteration % 2 == 0)
+      lloydStep.Iterate(centroids, centroidsOther, counts);
+    else
+      lloydStep.Iterate(centroidsOther, centroids, counts);
 
     // If we are not allowing empty clusters, then check that all of our
     // clusters have points.
@@ -220,9 +214,23 @@ Cluster(const MatType& data,
         changedAssignments += emptyClusterAction.EmptyCluster(data, i,
             centroids, counts, assignments);
 
+    // Calculate cluster distortion for this iteration.
+    cNorm = 0.0;
+    for (size_t i = 0; i < centroids.n_cols; ++i)
+    {
+      const double dist = metric.Evaluate(centroids.col(i),
+          centroidsOther.col(i));
+      cNorm += std::pow(dist, 2.0);
+    }
+    cNorm = sqrt(cNorm);
+
     iteration++;
 
-  } while (changedAssignments > 0 && iteration != maxIterations);
+  } while (cNorm > 1e-5 && iteration != maxIterations);
+
+  // Unfortunate copy that is sometimes necessary.
+  if (iteration % 2 == 0)
+    centroids = centroidsOther;
 
   if (iteration != maxIterations)
   {
@@ -233,15 +241,28 @@ Cluster(const MatType& data,
   {
     Log::Debug << "KMeans::Cluster(): terminated after limit of " << iteration
         << " iterations." << std::endl;
+  }
+
+  // Calculate final assignments.
+  for (size_t i = 0; i < data.n_cols; ++i)
+  {
+    // Find the closest centroid to this point.
+    double minDistance = std::numeric_limits<double>::infinity();
+    size_t closestCluster = centroids.n_cols; // Invalid value.
 
-    // Recalculate final clusters.
-    centroids.zeros();
+    for (size_t j = 0; j < centroids.n_cols; j++)
+    {
+      const double distance = metric.Evaluate(data.col(i), centroids.col(j));
 
-    for (size_t i = 0; i < data.n_cols; i++)
-      centroids.col(assignments[i]) += data.col(i);
+      if (distance < minDistance)
+      {
+        minDistance = distance;
+        closestCluster = j;
+      }
+    }
 
-    for (size_t i = 0; i < actualClusters; i++)
-      centroids.col(i) /= counts[i];
+    Log::Assert(closestCluster != centroids.n_cols);
+    assignments[i] = closestCluster;
   }
 
   // If we have overclustered, we need to merge the nearest clusters.
@@ -372,17 +393,21 @@ Cluster(const MatType& data,
 
 template<typename MetricType,
          typename InitialPartitionPolicy,
-         typename EmptyClusterPolicy>
+         typename EmptyClusterPolicy,
+         template<class, class> class LloydStepType,
+         typename MatType>
 std::string KMeans<MetricType,
     InitialPartitionPolicy,
-    EmptyClusterPolicy>::ToString() const
+    EmptyClusterPolicy,
+    LloydStepType,
+    MatType>::ToString() const
 {
   std::ostringstream convert;
   convert << "KMeans [" << this << "]" << std::endl;
-  convert << "  Overclustering Factor: " << overclusteringFactor <<std::endl;
-  convert << "  Max Iterations: " << maxIterations <<std::endl;
+  convert << "  Overclustering Factor: " << overclusteringFactor << std::endl;
+  convert << "  Max Iterations: " << maxIterations << std::endl;
   convert << "  Metric: " << std::endl;
-  convert << mlpack::util::Indent(metric.ToString(),2);
+  convert << mlpack::util::Indent(metric.ToString(), 2);
   convert << std::endl;
   return convert.str();
 }
diff --git a/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp b/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp
index 0715d01..af14ca1 100644
--- a/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp
+++ b/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp
@@ -40,7 +40,7 @@ class MaxVarianceNewCluster
   template<typename MatType>
   static size_t EmptyCluster(const MatType& data,
                              const size_t emptyCluster,
-                             const MatType& centroids,
+                             arma::mat& centroids,
                              arma::Col<size_t>& clusterCounts,
                              arma::Col<size_t>& assignments);
 };
diff --git a/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp b/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp
index c97ef71..91b40f5 100644
--- a/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp
+++ b/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp
@@ -19,7 +19,7 @@ namespace kmeans {
 template<typename MatType>
 size_t MaxVarianceNewCluster::EmptyCluster(const MatType& data,
                                            const size_t emptyCluster,
-                                           const MatType& centroids,
+                                           arma::mat& centroids,
                                            arma::Col<size_t>& clusterCounts,
                                            arma::Col<size_t>& assignments)
 {
@@ -67,10 +67,15 @@ size_t MaxVarianceNewCluster::EmptyCluster(const MatType& data,
   }
 
   // Take that point and add it to the empty cluster.
-  clusterCounts[maxVarCluster]--;
+  centroids.col(maxVarCluster) *= (clusterCounts[maxVarCluster] /
+      --clusterCounts[maxVarCluster]);
+  centroids.col(maxVarCluster) -= (1.0 / clusterCounts[maxVarCluster]) *
+      data.col(furthestPoint);
   clusterCounts[emptyCluster]++;
+  centroids.col(emptyCluster) = arma::vec(data.col(furthestPoint));
   assignments[furthestPoint] = emptyCluster;
 
+
   // Output some debugging information.
   Log::Debug << "Point " << furthestPoint << " assigned to empty cluster " <<
       emptyCluster << ".\n";
diff --git a/src/mlpack/methods/kmeans/naive_kmeans.hpp b/src/mlpack/methods/kmeans/naive_kmeans.hpp
new file mode 100644
index 0000000..d0b986a
--- /dev/null
+++ b/src/mlpack/methods/kmeans/naive_kmeans.hpp
@@ -0,0 +1,45 @@
+/**
+ * @file naive_kmeans.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of a naively-implemented step of the Lloyd algorithm for
+ * k-means clustering.  This may still be the best choice for small datasets or
+ * datasets with very high dimensionality.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_NAIVE_KMEANS_HPP
+#define __MLPACK_METHODS_KMEANS_NAIVE_KMEANS_HPP
+
+namespace mlpack {
+namespace kmeans {
+
+template<typename MetricType, typename MatType>
+class NaiveKMeans
+{
+ public:
+  NaiveKMeans(const MatType& dataset, MetricType& metric);
+
+  /**
+   * Run a single iteration of the Lloyd algorithm, updating the given centroids
+   * into the newCentroids matrix.
+   *
+   * @param centroids Current cluster centroids.
+   * @param newCentroids New cluster centroids.
+   */
+  void Iterate(const arma::mat& centroids,
+               arma::mat& newCentroids,
+               arma::Col<size_t>& counts);
+
+ private:
+  //! The dataset.
+  const MatType& dataset;
+  //! The instantiated metric.
+  MetricType& metric;
+};
+
+} // namespace kmeans
+} // namespace mlpack
+
+// Include implementation.
+#include "naive_kmeans_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/kmeans/naive_kmeans_impl.hpp b/src/mlpack/methods/kmeans/naive_kmeans_impl.hpp
new file mode 100644
index 0000000..4e9163c
--- /dev/null
+++ b/src/mlpack/methods/kmeans/naive_kmeans_impl.hpp
@@ -0,0 +1,68 @@
+/**
+ * @file naive_kmeans_impl.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of a naively-implemented step of the Lloyd algorithm for
+ * k-means clustering.  This may still be the best choice for small datasets or
+ * datasets with very high dimensionality.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_NAIVE_KMEANS_IMPL_HPP
+#define __MLPACK_METHODS_KMEANS_NAIVE_KMEANS_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "naive_kmeans.hpp"
+
+namespace mlpack {
+namespace kmeans {
+
+template<typename MetricType, typename MatType>
+NaiveKMeans<MetricType, MatType>::NaiveKMeans(const MatType& dataset,
+                                              MetricType& metric) :
+    dataset(dataset),
+    metric(metric)
+{ /* Nothing to do. */ }
+
+// Run a single iteration.
+template<typename MetricType, typename MatType>
+void NaiveKMeans<MetricType, MatType>::Iterate(const arma::mat& centroids,
+                                               arma::mat& newCentroids,
+                                               arma::Col<size_t>& counts)
+{
+  newCentroids.zeros(centroids.n_rows, centroids.n_cols);
+  counts.zeros(centroids.n_cols);
+
+  // Find the closest centroid to each point and update the new centroids.
+  for (size_t i = 0; i < dataset.n_cols; i++)
+  {
+    // Find the closest centroid to this point.
+    double minDistance = std::numeric_limits<double>::infinity();
+    size_t closestCluster = centroids.n_cols; // Invalid value.
+
+    for (size_t j = 0; j < centroids.n_cols; j++)
+    {
+      const double distance = metric.Evaluate(dataset.col(i), centroids.col(j));
+
+      if (distance < minDistance)
+      {
+        minDistance = distance;
+        closestCluster = j;
+      }
+    }
+
+    Log::Assert(closestCluster != centroids.n_cols);
+
+    // We now have the minimum distance centroid index.  Update that centroid.
+    newCentroids.col(closestCluster) += dataset.col(i);
+    counts(closestCluster)++;
+  }
+
+  // Now normalize the centroid.
+  for (size_t i = 0; i < centroids.n_cols; ++i)
+    if (counts(i) != 0)
+      newCentroids.col(i) /= counts(i);
+}
+
+} // namespace kmeans
+} // namespace mlpack
+
+#endif



More information about the mlpack-git mailing list