[mlpack-git] master: Refactor k-means significantly. Remove overclustering since I think nobody is using it (I don't think it's a very interesting technique) and it may be buggy. Speedups for the situation where only cluster centroids are desired. (695da4c)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:00:39 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 695da4c329476bf3f1835bfa457605c3a31a1985
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Oct 9 18:56:06 2014 +0000

    Refactor k-means significantly.  Remove overclustering since I think nobody is
    using it (I don't think it's a very interesting technique) and it may be buggy.
    Speedups for the situation where only cluster centroids are desired.


>---------------------------------------------------------------

695da4c329476bf3f1835bfa457605c3a31a1985
 HISTORY.txt                                        |  11 +
 src/mlpack/methods/kmeans/CMakeLists.txt           |   2 +
 src/mlpack/methods/kmeans/allow_empty_clusters.hpp |  11 +-
 src/mlpack/methods/kmeans/kmeans.hpp               |  65 +++--
 src/mlpack/methods/kmeans/kmeans_impl.hpp          | 267 ++++++---------------
 src/mlpack/methods/kmeans/kmeans_main.cpp          | 243 ++++++++++---------
 src/mlpack/methods/kmeans/naive_kmeans.hpp         |  11 +-
 src/mlpack/methods/kmeans/naive_kmeans_impl.hpp    |  21 +-
 8 files changed, 283 insertions(+), 348 deletions(-)

diff --git a/HISTORY.txt b/HISTORY.txt
index b89635c..a8779bf 100644
--- a/HISTORY.txt
+++ b/HISTORY.txt
@@ -1,3 +1,14 @@
+????-??-??    mlpack 1.1.0
+
+  * Removed overclustering support from k-means because it is not well-tested,
+    may be buggy, and is (I think) unused.  If this was support you were using,
+    open a bug or get in touch with us; it would not be hard for us to
+    reimplement it.
+
+  * Refactored KMeans to allow different types of Lloyd iterations.
+
+  * Added an implementation of Elkan's k-means algorithm.
+
 2014-08-29    mlpack 1.0.10
 
   * Bugfix for NeighborSearch regression which caused very slow allknn/allkfn.
diff --git a/src/mlpack/methods/kmeans/CMakeLists.txt b/src/mlpack/methods/kmeans/CMakeLists.txt
index 1e61c30..99bc83b 100644
--- a/src/mlpack/methods/kmeans/CMakeLists.txt
+++ b/src/mlpack/methods/kmeans/CMakeLists.txt
@@ -2,6 +2,8 @@
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
   allow_empty_clusters.hpp
+  elkan_kmeans.hpp
+  elkan_kmeans_impl.hpp
   kmeans.hpp
   kmeans_impl.hpp
   max_variance_new_cluster.hpp
diff --git a/src/mlpack/methods/kmeans/allow_empty_clusters.hpp b/src/mlpack/methods/kmeans/allow_empty_clusters.hpp
index f2af635..a8b0422 100644
--- a/src/mlpack/methods/kmeans/allow_empty_clusters.hpp
+++ b/src/mlpack/methods/kmeans/allow_empty_clusters.hpp
@@ -37,11 +37,12 @@ class AllowEmptyClusters
    * @return Number of points changed (0).
    */
   template<typename MetricType, typename MatType>
-  static size_t EmptyCluster(const MatType& /* data */,
-                             const size_t /* emptyCluster */,
-                             const arma::mat& /* centroids */,
-                             arma::Col<size_t>& /* clusterCounts */,
-                             MetricType& /* assignments */)
+  static inline force_inline size_t EmptyCluster(
+      const MatType& /* data */,
+      const size_t /* emptyCluster */,
+      const arma::mat& /* centroids */,
+      arma::Col<size_t>& /* clusterCounts */,
+      MetricType& /* metric */)
   {
     // Empty clusters are okay!  Do nothing.
     return 0;
diff --git a/src/mlpack/methods/kmeans/kmeans.hpp b/src/mlpack/methods/kmeans/kmeans.hpp
index 06c3e5e..866f13d 100644
--- a/src/mlpack/methods/kmeans/kmeans.hpp
+++ b/src/mlpack/methods/kmeans/kmeans.hpp
@@ -20,28 +20,28 @@ namespace mlpack {
 namespace kmeans /** K-Means clustering. */ {
 
 /**
- * This class implements K-Means clustering.  This implementation supports
- * overclustering, which means that more clusters than are requested will be
- * found; then, those clusters will be merged together to produce the desired
- * number of clusters.
+ * This class implements K-Means clustering, using a variety of possible
+ * implementations of Lloyd's algorithm.
  *
- * Two template parameters can (optionally) be supplied: the policy for how to
- * find the initial partition of the data, and the actions to be taken when an
- * empty cluster is encountered, as well as the distance metric to be used.
+ * Four template parameters can (optionally) be supplied: the distance metric to
+ * use, the policy for how to find the initial partition of the data, the
+ * actions to be taken when an empty cluster is encountered, and the
+ * implementation of a single Lloyd step to use.
  *
  * A simple example of how to run K-Means clustering is shown below.
  *
  * @code
  * extern arma::mat data; // Dataset we want to run K-Means on.
  * arma::Col<size_t> assignments; // Cluster assignments.
+ * arma::mat centroids; // Cluster centroids.
  *
  * KMeans<> k; // Default options.
- * k.Cluster(data, 3, assignments); // 3 clusters.
+ * k.Cluster(data, 3, assignments, centroids); // 3 clusters.
  *
- * // Cluster using the Manhattan distance, 100 iterations maximum, and an
- * // overclustering factor of 4.0.
- * KMeans<metric::ManhattanDistance> k(100, 4.0);
- * k.Cluster(data, 6, assignments); // 6 clusters.
+ * // Cluster using the Manhattan distance, 100 iterations maximum, saving only
+ * // the centroids.
+ * KMeans<metric::ManhattanDistance> k(100);
+ * k.Cluster(data, 6, centroids); // 6 clusters.
  * @endcode
  *
  * @tparam MetricType The distance metric to use for this KMeans; see
@@ -55,7 +55,7 @@ namespace kmeans /** K-Means clustering. */ {
  * @tparam LloydStepType Implementation of single Lloyd step to use.
  *
  * @see RandomPartition, RefinedStart, AllowEmptyClusters,
- *      MaxVarianceNewCluster, NaiveKMeans
+ *      MaxVarianceNewCluster, NaiveKMeans, ElkanKMeans
  */
 template<typename MetricType = metric::EuclideanDistance,
          typename InitialPartitionPolicy = RandomPartition,
@@ -71,15 +71,8 @@ class KMeans
    * the performance of K-Means, including "overclustering" and disallowing
    * empty clusters.
    *
-   * The overclustering factor controls how many clusters are
-   * actually found; for instance, with an overclustering factor of 4, if
-   * K-Means is run to find 3 clusters, it will actually find 12, then merge the
-   * nearest clusters until only 3 are left.
-   *
    * @param maxIterations Maximum number of iterations allowed before giving up
    *     (0 is valid, but the algorithm may never terminate).
-   * @param overclusteringFactor Factor controlling how many extra clusters are
-   *     found and then merged to get the desired number of clusters.
    * @param metric Optional MetricType object; for when the metric has state
    *     it needs to store.
    * @param partitioner Optional InitialPartitionPolicy object; for when a
@@ -88,7 +81,6 @@ class KMeans
    *     specially initialized empty cluster policy is required.
    */
   KMeans(const size_t maxIterations = 1000,
-         const double overclusteringFactor = 1.0,
          const MetricType metric = MetricType(),
          const InitialPartitionPolicy partitioner = InitialPartitionPolicy(),
          const EmptyClusterPolicy emptyClusterAction = EmptyClusterPolicy());
@@ -113,6 +105,24 @@ class KMeans
                const bool initialGuess = false);
 
   /**
+   * Perform k-means clustering on the data, returning the centroids of each
+   * cluster in the centroids matrix.  Optionally, the initial centroids can be
+   * specified by filling the centroids matrix with the initial centroids and
+   * specifying initialGuess = true.
+   *
+   * @tparam MatType Type of matrix (arma::mat or arma::sp_mat).
+   * @param data Dataset to cluster.
+   * @param clusters Number of clusters to compute.
+   * @param centroids Matrix in which centroids are stored.
+   * @param initialGuess If true, then it is assumed that centroids contains the
+   *      initial cluster centroids.
+   */
+  void Cluster(const MatType& data,
+               const size_t clusters,
+               arma::mat& centroids,
+               const bool initialGuess = false);
+
+  /**
    * Perform k-means clustering on the data, returning a list of cluster
    * assignments and also the centroids of each cluster.  Optionally, the vector
    * of assignments can be set to an initial guess of the cluster assignments;
@@ -122,12 +132,6 @@ class KMeans
    * supersedes initialCentroidGuess, so if both are set to true, the
    * assignments vector is used.
    *
-   * Note that if the overclustering factor is greater than 1, the centroids
-   * matrix will be resized in the method.  Regardless of the overclustering
-   * factor, the centroid guess matrix (if initialCentroidGuess is set to true)
-   * should have the same number of rows as the data matrix, and number of
-   * columns equal to 'clusters'.
-   *
    * @tparam MatType Type of matrix (arma::mat or arma::sp_mat).
    * @param data Dataset to cluster.
    * @param clusters Number of clusters to compute.
@@ -145,11 +149,6 @@ class KMeans
                const bool initialAssignmentGuess = false,
                const bool initialCentroidGuess = false);
 
-  //! Return the overclustering factor.
-  double OverclusteringFactor() const { return overclusteringFactor; }
-  //! Set the overclustering factor.  Must be greater than 1.
-  double& OverclusteringFactor() { return overclusteringFactor; }
-
   //! Get the maximum number of iterations.
   size_t MaxIterations() const { return maxIterations; }
   //! Set the maximum number of iterations.
@@ -175,8 +174,6 @@ class KMeans
   std::string ToString() const;
 
  private:
-  //! Factor controlling how many clusters are actually found.
-  double overclusteringFactor;
   //! Maximum number of iterations before giving up.
   size_t maxIterations;
   //! Instantiated distance metric.
diff --git a/src/mlpack/methods/kmeans/kmeans_impl.hpp b/src/mlpack/methods/kmeans/kmeans_impl.hpp
index 4f0d164..87b42b2 100644
--- a/src/mlpack/methods/kmeans/kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/kmeans_impl.hpp
@@ -28,7 +28,6 @@ KMeans<
     LloydStepType,
     MatType>::
 KMeans(const size_t maxIterations,
-       const double overclusteringFactor,
        const MetricType metric,
        const InitialPartitionPolicy partitioner,
        const EmptyClusterPolicy emptyClusterAction) :
@@ -37,17 +36,7 @@ KMeans(const size_t maxIterations,
     partitioner(partitioner),
     emptyClusterAction(emptyClusterAction)
 {
-  // Validate overclustering factor.
-  if (overclusteringFactor < 1.0)
-  {
-    Log::Warn << "KMeans::KMeans(): overclustering factor must be >= 1.0 ("
-        << overclusteringFactor << " given). Setting factor to 1.0.\n";
-    this->overclusteringFactor = 1.0;
-  }
-  else
-  {
-    this->overclusteringFactor = overclusteringFactor;
-  }
+  // Nothing to do.
 }
 
 /**
@@ -93,37 +82,16 @@ void KMeans<
     MatType>::
 Cluster(const MatType& data,
         const size_t clusters,
-        arma::Col<size_t>& assignments,
         arma::mat& centroids,
-        const bool initialAssignmentGuess,
-        const bool initialCentroidGuess)
+        const bool initialGuess)
 {
   // Make sure we have more points than clusters.
   if (clusters > data.n_cols)
     Log::Warn << "KMeans::Cluster(): more clusters requested than points given."
         << std::endl;
 
-  // Make sure our overclustering factor is valid.
-  size_t actualClusters = size_t(overclusteringFactor * clusters);
-  if (actualClusters > data.n_cols && overclusteringFactor != 1.0)
-  {
-    Log::Warn << "KMeans::Cluster(): overclustering factor is too large.  No "
-        << "overclustering will be done." << std::endl;
-    actualClusters = clusters;
-  }
-
-  // Counts of points in each cluster.
-  arma::Col<size_t> counts(actualClusters);
-
-  // Now, the initial assignments.  First determine if they are necessary.
-  if (initialAssignmentGuess)
-  {
-    if (assignments.n_elem != data.n_cols)
-      Log::Fatal << "KMeans::Cluster(): initial cluster assignments (length "
-          << assignments.n_elem << ") not the same size as the dataset (size "
-          << data.n_cols << ")!" << std::endl;
-  }
-  else if (initialCentroidGuess)
+  // Check validity of initial guess.
+  if (initialGuess)
   {
     if (centroids.n_cols != clusters)
       Log::Fatal << "KMeans::Cluster(): wrong number of initial cluster "
@@ -135,29 +103,35 @@ Cluster(const MatType& data,
         << " dimensionality (" << centroids.n_rows << ", should be "
         << data.n_rows << ")!" << std::endl;
   }
-  else
-  {
-    // Use the partitioner to come up with the partition assignments.
-    partitioner.Cluster(data, actualClusters, assignments);
-  }
 
-  // Calculate the initial centroids, if we need to.
-  if (!initialCentroidGuess || (initialAssignmentGuess && initialCentroidGuess))
+  // Use the partitioner to come up with the partition assignments and calculate
+  // the initial centroids.
+  if (!initialGuess)
   {
+    // The partitioner gives assignments, so we need to calculate centroids from
+    // those assignments.  This is probably not the most efficient way to do
+    // this, so maybe refactoring should be considered in the future.
+    arma::Col<size_t> assignments;
+    partitioner.Cluster(data, clusters, assignments);
+
     // Calculate initial centroids.
-    counts.zeros(actualClusters);
-    centroids.zeros(data.n_rows, actualClusters);
+    arma::Col<size_t> counts;
+    counts.zeros(clusters);
+    centroids.zeros(data.n_rows, clusters);
     for (size_t i = 0; i < data.n_cols; ++i)
     {
       centroids.col(assignments[i]) += arma::vec(data.col(i));
       counts[assignments[i]]++;
     }
 
-    for (size_t i = 0; i < actualClusters; ++i)
+    for (size_t i = 0; i < clusters; ++i)
       if (counts[i] != 0)
         centroids.col(i) /= counts[i];
   }
 
+  // Counts of points in each cluster.
+  arma::Col<size_t> counts(clusters);
+
   size_t iteration = 0;
 
   LloydStepType<MetricType, MatType> lloydStep(data, metric);
@@ -169,17 +143,17 @@ Cluster(const MatType& data,
     // We have two centroid matrices.  We don't want to copy anything, so,
     // depending on the iteration number, we use a different centroid matrix...
     if (iteration % 2 == 0)
-      lloydStep.Iterate(centroids, centroidsOther, counts);
+      cNorm = lloydStep.Iterate(centroids, centroidsOther, counts);
     else
-      lloydStep.Iterate(centroidsOther, centroids, counts);
+      cNorm = lloydStep.Iterate(centroidsOther, centroids, counts);
 
     // If we are not allowing empty clusters, then check that all of our
     // clusters have points.
-    for (size_t i = 0; i < actualClusters; i++)
+    for (size_t i = 0; i < clusters; i++)
     {
       if (counts[i] == 0)
       {
-        Log::Debug << "Cluster " << i << " is empty.\n";
+        Log::Info << "Cluster " << i << " is empty.\n";
         if (iteration % 2 == 0)
           emptyClusterAction.EmptyCluster(data, i, centroidsOther, counts,
               metric);
@@ -188,17 +162,9 @@ Cluster(const MatType& data,
       }
     }
 
-    // Calculate cluster distortion for this iteration.
-    cNorm = 0.0;
-    for (size_t i = 0; i < centroids.n_cols; ++i)
-    {
-      const double dist = metric.Evaluate(centroids.col(i),
-          centroidsOther.col(i));
-      cNorm += std::pow(dist, 2.0);
-    }
-    cNorm = sqrt(cNorm);
-
     iteration++;
+    Log::Info << "KMeans::Cluster(): iteration " << iteration << ", residual "
+        << cNorm << ".\n";
 
   } while (cNorm > 1e-5 && iteration != maxIterations);
 
@@ -210,14 +176,65 @@ Cluster(const MatType& data,
 
   if (iteration != maxIterations)
   {
-    Log::Debug << "KMeans::Cluster(): converged after " << iteration
+    Log::Info << "KMeans::Cluster(): converged after " << iteration
         << " iterations." << std::endl;
   }
   else
   {
-    Log::Debug << "KMeans::Cluster(): terminated after limit of " << iteration
+    Log::Info << "KMeans::Cluster(): terminated after limit of " << iteration
         << " iterations." << std::endl;
   }
+  Log::Info << lloydStep.DistanceCalculations() << " distance calculations."
+      << std::endl;
+}
+
+/**
+ * Perform k-means clustering on the data, returning a list of cluster
+ * assignments and the centroids of each cluster.
+ */
+template<typename MetricType,
+         typename InitialPartitionPolicy,
+         typename EmptyClusterPolicy,
+         template<class, class> class LloydStepType,
+         typename MatType>
+void KMeans<
+    MetricType,
+    InitialPartitionPolicy,
+    EmptyClusterPolicy,
+    LloydStepType,
+    MatType>::
+Cluster(const MatType& data,
+        const size_t clusters,
+        arma::Col<size_t>& assignments,
+        arma::mat& centroids,
+        const bool initialAssignmentGuess,
+        const bool initialCentroidGuess)
+{
+  // Now, the initial assignments.  First determine if they are necessary.
+  if (initialAssignmentGuess)
+  {
+    if (assignments.n_elem != data.n_cols)
+      Log::Fatal << "KMeans::Cluster(): initial cluster assignments (length "
+          << assignments.n_elem << ") not the same size as the dataset (size "
+          << data.n_cols << ")!" << std::endl;
+
+    // Calculate initial centroids.
+    arma::Col<size_t> counts;
+    counts.zeros(clusters);
+    centroids.zeros(data.n_rows, clusters);
+    for (size_t i = 0; i < data.n_cols; ++i)
+    {
+      centroids.col(assignments[i]) += arma::vec(data.col(i));
+      counts[assignments[i]]++;
+    }
+
+    for (size_t i = 0; i < clusters; ++i)
+      if (counts[i] != 0)
+        centroids.col(i) /= counts[i];
+  }
+
+  Cluster(data, clusters, centroids,
+      initialAssignmentGuess || initialCentroidGuess);
 
   // Calculate final assignments.
   assignments.set_size(data.n_cols);
@@ -241,131 +258,6 @@ Cluster(const MatType& data,
     Log::Assert(closestCluster != centroids.n_cols);
     assignments[i] = closestCluster;
   }
-
-  // If we have overclustered, we need to merge the nearest clusters.
-  if (actualClusters != clusters)
-  {
-    // Generate a list of all the clusters' distances from each other.  This
-    // list will become mangled and unused as the number of clusters decreases.
-    size_t numDistances = ((actualClusters - 1) * actualClusters) / 2;
-    size_t clustersLeft = actualClusters;
-    arma::vec distances(numDistances);
-    arma::Col<size_t> firstCluster(numDistances);
-    arma::Col<size_t> secondCluster(numDistances);
-
-    // Keep the mappings of clusters that we are changing.
-    arma::Col<size_t> mappings = arma::linspace<arma::Col<size_t> >(0,
-        actualClusters - 1, actualClusters);
-
-    size_t i = 0;
-    for (size_t first = 0; first < actualClusters; first++)
-    {
-      for (size_t second = first + 1; second < actualClusters; second++)
-      {
-        distances(i) = metric.Evaluate(centroids.col(first),
-                                       centroids.col(second));
-        firstCluster(i) = first;
-        secondCluster(i) = second;
-        i++;
-      }
-    }
-
-    while (clustersLeft != clusters)
-    {
-      arma::uword minIndex;
-      distances.min(minIndex);
-
-      // Now we merge the clusters which that distance belongs to.
-      size_t first = firstCluster(minIndex);
-      size_t second = secondCluster(minIndex);
-      for (size_t j = 0; j < assignments.n_elem; j++)
-        if (assignments(j) == second)
-          assignments(j) = first;
-
-      // Now merge the centroids.
-      centroids.col(first) *= counts[first];
-      centroids.col(first) += (counts[second] * centroids.col(second));
-      centroids.col(first) /= (counts[first] + counts[second]);
-
-      // Update the counts.
-      counts[first] += counts[second];
-      counts[second] = 0;
-
-      // Now update all the relevant distances.
-      // First the distances where either cluster is the second cluster.
-      for (size_t cluster = 0; cluster < second; cluster++)
-      {
-        // The offset is sum^n i - sum^(n - m) i, where n is actualClusters and
-        // m is the cluster we are trying to offset to.
-        size_t offset = (size_t) (((actualClusters - 1) * cluster)
-            + (cluster - pow(cluster, 2.0)) / 2) - 1;
-
-        // See if we need to update the distance from this cluster to the first
-        // cluster.
-        if (cluster < first)
-        {
-          // Make sure it isn't already DBL_MAX.
-          if (distances(offset + (first - cluster)) != DBL_MAX)
-            distances(offset + (first - cluster)) = metric.Evaluate(
-                centroids.col(first), centroids.col(cluster));
-        }
-
-        distances(offset + (second - cluster)) = DBL_MAX;
-      }
-
-      // Now the distances where the first cluster is the first cluster.
-      size_t offset = (size_t) (((actualClusters - 1) * first)
-          + (first - pow(first, 2.0)) / 2) - 1;
-      for (size_t cluster = first + 1; cluster < actualClusters; cluster++)
-      {
-        // Make sure it isn't already DBL_MAX.
-        if (distances(offset + (cluster - first)) != DBL_MAX)
-        {
-          distances(offset + (cluster - first)) = metric.Evaluate(
-              centroids.col(first), centroids.col(cluster));
-        }
-      }
-
-      // Max the distance between the first and second clusters.
-      distances(offset + (second - first)) = DBL_MAX;
-
-      // Now max the distances for the second cluster (which no longer has
-      // anything in it).
-      offset = (size_t) (((actualClusters - 1) * second)
-          + (second - pow(second, 2.0)) / 2) - 1;
-      for (size_t cluster = second + 1; cluster < actualClusters; cluster++)
-        distances(offset + (cluster - second)) = DBL_MAX;
-
-      clustersLeft--;
-
-      // Update the cluster mappings.
-      mappings(second) = first;
-      // Also update any mappings that were pointed at the previous cluster.
-      for (size_t cluster = 0; cluster < actualClusters; cluster++)
-        if (mappings(cluster) == second)
-          mappings(cluster) = first;
-    }
-
-    // Now remap the mappings down to the smallest possible numbers.
-    // Could this process be sped up?
-    arma::Col<size_t> remappings(actualClusters);
-    remappings.fill(actualClusters);
-    size_t remap = 0; // Counter variable.
-    for (size_t cluster = 0; cluster < actualClusters; cluster++)
-    {
-      // If the mapping of the current cluster has not been assigned a value
-      // yet, we will assign it a cluster number.
-      if (remappings(mappings(cluster)) == actualClusters)
-      {
-        remappings(mappings(cluster)) = remap;
-        remap++;
-      }
-    }
-
-    // Fix the assignments using the mappings we created.
-    for (size_t j = 0; j < assignments.n_elem; j++)
-      assignments(j) = remappings(mappings(assignments(j)));
-  }
 }
 
 template<typename MetricType,
@@ -381,7 +273,6 @@ std::string KMeans<MetricType,
 {
   std::ostringstream convert;
   convert << "KMeans [" << this << "]" << std::endl;
-  convert << "  Overclustering Factor: " << overclusteringFactor << std::endl;
   convert << "  Max Iterations: " << maxIterations << std::endl;
   convert << "  Metric: " << std::endl;
   convert << mlpack::util::Indent(metric.ToString(), 2);
diff --git a/src/mlpack/methods/kmeans/kmeans_main.cpp b/src/mlpack/methods/kmeans/kmeans_main.cpp
index ef0b7bb..15003da 100644
--- a/src/mlpack/methods/kmeans/kmeans_main.cpp
+++ b/src/mlpack/methods/kmeans/kmeans_main.cpp
@@ -9,6 +9,7 @@
 #include "kmeans.hpp"
 #include "allow_empty_clusters.hpp"
 #include "refined_start.hpp"
+#include "elkan_kmeans.hpp"
 
 using namespace mlpack;
 using namespace mlpack::kmeans;
@@ -28,27 +29,28 @@ PROGRAM_INFO("K-Means Clustering", "This program performs K-Means clustering "
     " random samples of the dataset; to specify the number of samples, the "
     "--samples parameter is used, and to specify the percentage of the dataset "
     "to be used in each sample, the --percentage parameter is used (it should "
-    "be a value between 0.0 and 1.0).\n");
+    "be a value between 0.0 and 1.0)."
+    "\n\n"
+    "As of October 2014, the --overclustering option has been removed.  If you "
+    "want this support back, let us know -- file a bug at "
+    "http://www.mlpack.org/trac/ or get in touch through another means.");
 
 // Required options.
 PARAM_STRING_REQ("inputFile", "Input dataset to perform clustering on.", "i");
 PARAM_INT_REQ("clusters", "Number of clusters to find.", "c");
 
 // Output options.
-PARAM_FLAG("in_place", "If specified, a column of the learned cluster "
+PARAM_FLAG("in_place", "If specified, a column containing the learned cluster "
     "assignments will be added to the input dataset file.  In this case, "
-    "--outputFile is not necessary.", "P");
+    "--outputFile is overridden.", "P");
 PARAM_STRING("output_file", "File to write output labels or labeled data to.",
-    "o", "output.csv");
+    "o", "");
 PARAM_STRING("centroid_file", "If specified, the centroids of each cluster will"
     " be written to the given file.", "C", "");
 
 // k-means configuration options.
 PARAM_FLAG("allow_empty_clusters", "Allow empty clusters to be created.", "e");
 PARAM_FLAG("labels_only", "Only output labels into output file.", "l");
-PARAM_DOUBLE("overclustering", "Finds (overclustering * clusters) clusters, "
-    "then merges them together until only the desired number of clusters are "
-    "left.", "O", 1.0);
 PARAM_INT("max_iterations", "Maximum number of iterations before K-Means "
     "terminates.", "m", 1000);
 PARAM_INT("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
@@ -62,7 +64,23 @@ PARAM_INT("samplings", "Number of samplings to perform for refined start (use "
     "when --refined_start is specified).", "S", 100);
 PARAM_DOUBLE("percentage", "Percentage of dataset to use for each refined start"
     " sampling (use when --refined_start is specified).", "p", 0.02);
+PARAM_FLAG("elkan", "Use Elkan's algorithm.", "E");
+
+// Given the type of initial partition policy, figure out the empty cluster
+// policy and run k-means.
+template<typename InitialPartitionPolicy>
+void FindEmptyClusterPolicy(const InitialPartitionPolicy& ipp);
 
+// Given the initial partitioning policy and empty cluster policy, figure out
+// the Lloyd iteration step type and run k-means.
+template<typename InitialPartitionPolicy, typename EmptyClusterPolicy>
+void FindLloydStepType(const InitialPartitionPolicy& ipp);
+
+// Given the template parameters, sanitize/load input and run k-means.
+template<typename InitialPartitionPolicy,
+         typename EmptyClusterPolicy,
+         template<class, class> class LloydStepType>
+void RunKMeans(const InitialPartitionPolicy& ipp);
 
 int main(int argc, char** argv)
 {
@@ -74,7 +92,58 @@ int main(int argc, char** argv)
   else
     math::RandomSeed((size_t) std::time(NULL));
 
-  // Now do validation of options.
+  // Now, start building the KMeans type that we'll be using.  Start with the
+  // initial partition policy.  The call to FindEmptyClusterPolicy<> results in
+  // a call to RunKMeans<> and the algorithm is completed.
+  if (CLI::HasParam("refined_start"))
+  {
+    const int samplings = CLI::GetParam<int>("samplings");
+    const double percentage = CLI::GetParam<double>("percentage");
+
+    if (samplings < 0)
+      Log::Fatal << "Number of samplings (" << samplings << ") must be "
+          << "greater than 0!" << endl;
+    if (percentage <= 0.0 || percentage > 1.0)
+      Log::Fatal << "Percentage for sampling (" << percentage << ") must be "
+          << "greater than 0.0 and less than or equal to 1.0!" << endl;
+
+    FindEmptyClusterPolicy<RefinedStart>(RefinedStart(samplings, percentage));
+  }
+  else
+  {
+    FindEmptyClusterPolicy<RandomPartition>(RandomPartition());
+  }
+}
+
+// Given the type of initial partition policy, figure out the empty cluster
+// policy and run k-means.
+template<typename InitialPartitionPolicy>
+void FindEmptyClusterPolicy(const InitialPartitionPolicy& ipp)
+{
+  if (CLI::HasParam("allow_empty_clusters"))
+    FindLloydStepType<InitialPartitionPolicy, AllowEmptyClusters>(ipp);
+  else
+    FindLloydStepType<InitialPartitionPolicy, MaxVarianceNewCluster>(ipp);
+}
+
+// Given the initial partitioning policy and empty cluster policy, figure out
+// the Lloyd iteration step type and run k-means.
+template<typename InitialPartitionPolicy, typename EmptyClusterPolicy>
+void FindLloydStepType(const InitialPartitionPolicy& ipp)
+{
+  if (CLI::HasParam("elkan"))
+    RunKMeans<InitialPartitionPolicy, EmptyClusterPolicy, ElkanKMeans>(ipp);
+  else
+    RunKMeans<InitialPartitionPolicy, EmptyClusterPolicy, NaiveKMeans>(ipp);
+}
+
+// Given the template parameters, sanitize/load input and run k-means.
+template<typename InitialPartitionPolicy,
+         typename EmptyClusterPolicy,
+         template<class, class> class LloydStepType>
+void RunKMeans(const InitialPartitionPolicy& ipp)
+{
+  // Now, do validation of input options.
   const string inputFile = CLI::GetParam<string>("inputFile");
   const int clusters = CLI::GetParam<int>("clusters");
   if (clusters < 1)
@@ -90,27 +159,18 @@ int main(int argc, char** argv)
         ")! Must be greater than or equal to 0." << endl;
   }
 
-  const double overclustering = CLI::GetParam<double>("overclustering");
-  if (overclustering < 1)
-  {
-    Log::Fatal << "Invalid value for overclustering (" << overclustering <<
-        ")! Must be greater than or equal to 1." << endl;
-  }
-
   // Make sure we have an output file if we're not doing the work in-place.
-  if (!CLI::HasParam("in_place") && !CLI::HasParam("output_file"))
+  if (!CLI::HasParam("in_place") && !CLI::HasParam("output_file") &&
+      !CLI::HasParam("centroid_file"))
   {
-    Log::Fatal << "--outputFile not specified (and --in_place not set)."
-        << endl;
+    Log::Warn << "--output_file, --in_place, and --centroid_file are not set; "
+        << "no results will be saved." << std::endl;
   }
 
   // Load our dataset.
   arma::mat dataset;
   data::Load(inputFile, dataset, true); // Fatal upon failure.
 
-  // Now create the KMeans object.  Because we could be using different types,
-  // it gets a little weird...
-  arma::Col<size_t> assignments;
   arma::mat centroids;
 
   const bool initialCentroidGuess = CLI::HasParam("initial_centroids");
@@ -128,112 +188,67 @@ int main(int argc, char** argv)
           initialCentroidsFile << "'." << endl;
   }
 
-  if (CLI::HasParam("allow_empty_clusters"))
-  {
-    if (CLI::HasParam("refined_start"))
-    {
-      const int samplings = CLI::GetParam<int>("samplings");
-      const double percentage = CLI::GetParam<double>("percentage");
-
-      if (samplings < 0)
-        Log::Fatal << "Number of samplings (" << samplings << ") must be "
-            << "greater than 0!" << endl;
-      if (percentage <= 0.0 || percentage > 1.0)
-        Log::Fatal << "Percentage for sampling (" << percentage << ") must be "
-            << "greater than 0.0 and less than or equal to 1.0!" << endl;
-
-      KMeans<metric::SquaredEuclideanDistance, RefinedStart, AllowEmptyClusters>
-          k(maxIterations, overclustering, metric::SquaredEuclideanDistance(),
-          RefinedStart(samplings, percentage));
-
-      Timer::Start("clustering");
-      k.Cluster(dataset, clusters, assignments, centroids);
-      Timer::Stop("clustering");
-    }
-    else
-    {
-      KMeans<metric::SquaredEuclideanDistance, RandomPartition,
-          AllowEmptyClusters> k(maxIterations, overclustering);
-
-      Timer::Start("clustering");
-      k.Cluster(dataset, clusters, assignments, centroids, false,
-          initialCentroidGuess);
-      Timer::Stop("clustering");
-    }
-  }
-  else
-  {
-    if (CLI::HasParam("refined_start"))
-    {
-      const int samplings = CLI::GetParam<int>("samplings");
-      const double percentage = CLI::GetParam<double>("percentage");
-
-      if (samplings < 0)
-        Log::Fatal << "Number of samplings (" << samplings << ") must be "
-            << "greater than 0!" << endl;
-      if (percentage <= 0.0 || percentage > 1.0)
-        Log::Fatal << "Percentage for sampling (" << percentage << ") must be "
-            << "greater than 0.0 and less than or equal to 1.0!" << endl;
-
-      KMeans<metric::SquaredEuclideanDistance, RefinedStart, AllowEmptyClusters>
-          k(maxIterations, overclustering, metric::SquaredEuclideanDistance(),
-          RefinedStart(samplings, percentage));
-
-      Timer::Start("clustering");
-      k.Cluster(dataset, clusters, assignments, centroids);
-      Timer::Stop("clustering");
-    }
-    else
-    {
-      KMeans<> k(maxIterations, overclustering);
-
-      Timer::Start("clustering");
-      k.Cluster(dataset, clusters, assignments, centroids, false,
-          initialCentroidGuess);
-      Timer::Stop("clustering");
-    }
-  }
-
-  // Now figure out what to do with our results.
-  if (CLI::HasParam("in_place"))
-  {
-    // Add the column of assignments to the dataset; but we have to convert them
-    // to type double first.
-    arma::vec converted(assignments.n_elem);
-    for (size_t i = 0; i < assignments.n_elem; i++)
-      converted(i) = (double) assignments(i);
-
-    dataset.insert_rows(dataset.n_rows, trans(converted));
+  KMeans<metric::EuclideanDistance,
+         InitialPartitionPolicy,
+         EmptyClusterPolicy,
+         LloydStepType> kmeans(maxIterations, metric::EuclideanDistance(), ipp);
 
-    // Save the dataset.
-    data::Save(inputFile, dataset);
-  }
-  else
+  if (CLI::HasParam("output_file") || CLI::HasParam("in_place"))
   {
-    if (CLI::HasParam("labels_only"))
-    {
-      // Save only the labels.
-      string outputFile = CLI::GetParam<string>("output_file");
-      arma::Mat<size_t> output = trans(assignments);
-      data::Save(outputFile, output);
-    }
-    else
+    // We need to get the assignments.
+    arma::Col<size_t> assignments;
+    Timer::Start("clustering");
+    kmeans.Cluster(dataset, clusters, assignments, centroids,
+        false, initialCentroidGuess);
+    Timer::Stop("clustering");
+
+    // Now figure out what to do with our results.
+    if (CLI::HasParam("in_place"))
     {
-      // Convert the assignments to doubles.
+      // Add the column of assignments to the dataset; but we have to convert
+      // them to type double first.
       arma::vec converted(assignments.n_elem);
       for (size_t i = 0; i < assignments.n_elem; i++)
         converted(i) = (double) assignments(i);
 
       dataset.insert_rows(dataset.n_rows, trans(converted));
 
-      // Now save, in the different file.
-      string outputFile = CLI::GetParam<string>("output_file");
-      data::Save(outputFile, dataset);
+      // Save the dataset.
+      data::Save(inputFile, dataset);
+    }
+    else
+    {
+      if (CLI::HasParam("labels_only"))
+      {
+        // Save only the labels.
+        string outputFile = CLI::GetParam<string>("output_file");
+        arma::Mat<size_t> output = trans(assignments);
+        data::Save(outputFile, output);
+      }
+      else
+      {
+        // Convert the assignments to doubles.
+        arma::vec converted(assignments.n_elem);
+        for (size_t i = 0; i < assignments.n_elem; i++)
+          converted(i) = (double) assignments(i);
+
+        dataset.insert_rows(dataset.n_rows, trans(converted));
+
+        // Now save, in the different file.
+        string outputFile = CLI::GetParam<string>("output_file");
+        data::Save(outputFile, dataset);
+      }
     }
   }
+  else
+  {
+    // Just save the centroids.
+    Timer::Start("clustering");
+    kmeans.Cluster(dataset, clusters, centroids, initialCentroidGuess);
+    Timer::Stop("clustering");
+  }
 
   // Should we write the centroids to a file?
   if (CLI::HasParam("centroid_file"))
     data::Save(CLI::GetParam<std::string>("centroid_file"), centroids);
 }
-
diff --git a/src/mlpack/methods/kmeans/naive_kmeans.hpp b/src/mlpack/methods/kmeans/naive_kmeans.hpp
index cf4eda2..be6f637 100644
--- a/src/mlpack/methods/kmeans/naive_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/naive_kmeans.hpp
@@ -40,15 +40,20 @@ class NaiveKMeans
    * @param centroids Current cluster centroids.
    * @param newCentroids New cluster centroids.
    */
-  void Iterate(const arma::mat& centroids,
-               arma::mat& newCentroids,
-               arma::Col<size_t>& counts);
+  double Iterate(const arma::mat& centroids,
+                 arma::mat& newCentroids,
+                 arma::Col<size_t>& counts);
+
+  size_t DistanceCalculations() const { return distanceCalculations; }
 
  private:
   //! The dataset.
   const MatType& dataset;
   //! The instantiated metric.
   MetricType& metric;
+
+  //! Number of distance calculations.
+  size_t distanceCalculations;
 };
 
 } // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/naive_kmeans_impl.hpp b/src/mlpack/methods/kmeans/naive_kmeans_impl.hpp
index 92d26b5..80c5569 100644
--- a/src/mlpack/methods/kmeans/naive_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/naive_kmeans_impl.hpp
@@ -19,14 +19,15 @@ template<typename MetricType, typename MatType>
 NaiveKMeans<MetricType, MatType>::NaiveKMeans(const MatType& dataset,
                                               MetricType& metric) :
     dataset(dataset),
-    metric(metric)
+    metric(metric),
+    distanceCalculations(0)
 { /* Nothing to do. */ }
 
 // Run a single iteration.
 template<typename MetricType, typename MatType>
-void NaiveKMeans<MetricType, MatType>::Iterate(const arma::mat& centroids,
-                                               arma::mat& newCentroids,
-                                               arma::Col<size_t>& counts)
+double NaiveKMeans<MetricType, MatType>::Iterate(const arma::mat& centroids,
+                                                 arma::mat& newCentroids,
+                                                 arma::Col<size_t>& counts)
 {
   newCentroids.zeros(centroids.n_rows, centroids.n_cols);
   counts.zeros(centroids.n_cols);
@@ -62,6 +63,18 @@ void NaiveKMeans<MetricType, MatType>::Iterate(const arma::mat& centroids,
       newCentroids.col(i) /= counts(i);
     else
       newCentroids.col(i).fill(DBL_MAX); // Invalid value.
+
+  distanceCalculations += centroids.n_cols * dataset.n_cols;
+
+  // Calculate cluster distortion for this iteration.
+  double cNorm = 0.0;
+  for (size_t i = 0; i < centroids.n_cols; ++i)
+  {
+    const double dist = std::pow(
+        metric.Evaluate(centroids.col(i), newCentroids.col(i)), 2.0);
+    cNorm += std::pow(dist, 2.0);
+  }
+  return sqrt(cNorm);
 }
 
 } // namespace kmeans



More information about the mlpack-git mailing list