[mlpack-svn] r10398 - mlpack/trunk/src/mlpack/methods/kmeans

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Nov 25 01:57:13 EST 2011


Author: rcurtin
Date: 2011-11-25 01:57:12 -0500 (Fri, 25 Nov 2011)
New Revision: 10398

Modified:
   mlpack/trunk/src/mlpack/methods/kmeans/kmeans.cpp
   mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp
Log:
Make KMeans into an actual class.  It now supports "overclustering".


Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans.cpp	2011-11-24 21:53:08 UTC (rev 10397)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans.cpp	2011-11-25 06:57:12 UTC (rev 10398)
@@ -1,6 +1,7 @@
 /**
  * @file kmeans.cpp
  * @author Parikshit Ram (pram at cc.gatech.edu)
+ * @author Ryan Curtin
  *
  * Implementation for the K-means method for getting an initial point.
  */
@@ -11,32 +12,76 @@
 namespace mlpack {
 namespace kmeans {
 
-void KMeans(const arma::mat& data,
-            const size_t value_of_k,
-            std::vector<arma::vec>& means,
-            std::vector<arma::mat>& covars,
-            arma::vec& weights) {
+/**
+ * Construct the K-Means object (invalid overclustering factors become 1.0).
+ */
+KMeans::KMeans(const double overclusteringFactor,
+               const bool allowEmptyClusters,
+               const size_t maxIterations) :
+    allowEmptyClusters(allowEmptyClusters),
+    maxIterations(maxIterations)
+{
+  // Reject factors below 1.0; the negated comparison also catches NaN.
+  if (!(overclusteringFactor >= 1.0))
+  {
+    Log::Warn << "KMeans::KMeans(): overclustering factor must be >= 1.0 ("
+        << overclusteringFactor << " given). Setting factor to 1.0.\n";
+    this->overclusteringFactor = 1.0;
+  }
+  else
+  {
+    this->overclusteringFactor = overclusteringFactor;
+  }
+}
+
+/**
+ * Perform K-Means clustering on the data, returning a list of cluster
+ * assignments.
+ */
+void KMeans::Cluster(const arma::mat& data,
+                     const size_t clusters,
+                     arma::Col<size_t>& assignments) const
+{
   // Make sure we have more points than clusters.
-  if (value_of_k > data.n_cols)
-    Log::Warn << "k-means: more clusters requested than points given.  Empty"
-        << " clusters may result." << std::endl;
+  if (clusters > data.n_cols)
+  {
+    if (allowEmptyClusters)
+      Log::Warn << "KMeans::Cluster(): more clusters requested than points "
+          << "given.  Empty clusters may result." << std::endl;
+    else
+      Log::Fatal << "KMeans::Cluster(): more clusters requested than points "
+          << "given, and empty clusters not allowed.  Terminating.\n";
+  }
 
-  // Assignment of cluster of each point.
-  arma::Col<size_t> assignments(data.n_cols); // Col used so we have shuffle().
+  // Make sure our overclustering factor is valid.
+  size_t actualClusters = size_t(overclusteringFactor * clusters);
+  if (actualClusters > data.n_cols)
+  {
+    Log::Warn << "KMeans::Cluster(): overclustering factor is too large.  No "
+        << "overclustering will be done." << std::endl;
+    actualClusters = clusters;
+  }
+
+  // Now, the initial assignments.  First determine if they are necessary.
+  if (assignments.n_elem != data.n_cols)
+  {
+    // No guesses were given.  Generate random assignments.  Each cluster will
+    // have the same number of points.
+    assignments = arma::shuffle(arma::linspace<arma::Col<size_t> >(0,
+        actualClusters - 1, data.n_cols));
+  }
+
   // Centroids of each cluster.  Each column corresponds to a centroid.
-  arma::mat centroids(data.n_rows, value_of_k);
+  arma::mat centroids(data.n_rows, actualClusters);
   // Counts of points in each cluster.
-  arma::Col<size_t> counts(value_of_k);
+  arma::Col<size_t> counts(actualClusters);
 
-  // First we must randomly partition the dataset.
-  assignments = arma::shuffle(arma::linspace<arma::Col<size_t> >(0,
-      value_of_k - 1, data.n_cols));
-
   // Set counts correctly.
-  for (size_t i = 0; i < value_of_k; i++)
+  for (size_t i = 0; i < actualClusters; i++)
     counts[i] = accu(assignments == i);
 
-  size_t changed_assignments = 0;
+  size_t changedAssignments = 0;
+  size_t iteration = 0;
   do
   {
     // Update step.
@@ -46,123 +91,224 @@
     for (size_t i = 0; i < data.n_cols; i++)
       centroids.col(assignments[i]) += data.col(i);
 
-    for (size_t i = 0; i < value_of_k; i++)
+    for (size_t i = 0; i < actualClusters; i++)
       centroids.col(i) /= counts[i];
 
     // Assignment step.
     // Find the closest centroid to each point.  We will keep track of how many
     // assignments change.  When no assignments change, we are done.
-    changed_assignments = 0;
+    changedAssignments = 0;
     for (size_t i = 0; i < data.n_cols; i++)
     {
       // Find the closest centroid to this point.
-      double min_distance = std::numeric_limits<double>::infinity();
-      size_t closest_cluster = value_of_k; // Invalid value.
+      double minDistance = std::numeric_limits<double>::infinity();
+      size_t closestCluster = actualClusters; // Invalid value.
 
-      for (size_t j = 0; j < value_of_k; j++)
+      for (size_t j = 0; j < actualClusters; j++)
       {
         double distance = metric::SquaredEuclideanDistance::Evaluate(
             data.unsafe_col(i), centroids.unsafe_col(j));
 
-        if (distance < min_distance)
+        if (distance < minDistance)
         {
-          min_distance = distance;
-          closest_cluster = j;
+          minDistance = distance;
+          closestCluster = j;
         }
       }
 
       // Reassign this point to the closest cluster.
-      if (assignments[i] != closest_cluster)
+      if (assignments[i] != closestCluster)
       {
         // Update counts.
         counts[assignments[i]]--;
-        counts[closest_cluster]++;
+        counts[closestCluster]++;
         // Update assignment.
-        assignments[i] = closest_cluster;
-        changed_assignments++;
+        assignments[i] = closestCluster;
+        changedAssignments++;
       }
     }
 
-    // Keep-bad-things-from-happening step.
-    // Ensure that no cluster is empty, and if so, take corrective action.
-    for (size_t i = 0; i < value_of_k; i++)
+    // If we are not allowing empty clusters, then check that all of our
+    // clusters have points.
+    if (!allowEmptyClusters)
     {
-      if (counts[i] == 0)
+      for (size_t i = 0; i < actualClusters; i++)
       {
-        // Strategy: take the furthest point from the cluster with highest
-        // variance.  So, we need the variance of each cluster.
-        arma::vec variances;
-        variances.zeros(value_of_k);
-        for (size_t j = 0; j < data.n_cols; j++)
-          variances[assignments[j]] += var(data.col(j));
+        if (counts[i] == 0)
+        {
+          // Strategy: take the furthest point from the cluster with highest
+          // variance.  So, we need the variance of each cluster.
+          arma::vec variances;
+          variances.zeros(actualClusters);
+          for (size_t j = 0; j < data.n_cols; j++)
+            variances[assignments[j]] += var(data.col(j));
 
-        size_t cluster;
-        double max_var = 0;
-        for (size_t j = 0; j < value_of_k; j++)
-        {
-          if (variances[j] > max_var)
+          size_t cluster;
+          double maxVar = 0;
+          for (size_t j = 0; j < actualClusters; j++)
           {
-            cluster = j;
-            max_var = variances[j];
+            if (variances[j] > maxVar)
+            {
+              cluster = j;
+              maxVar = variances[j];
+            }
           }
-        }
 
-        // Now find the furthest point.
-        size_t point = data.n_cols; // Invalid.
-        double distance = 0;
-        for (size_t j = 0; j < data.n_cols; j++)
-        {
-          if (assignments[j] == cluster)
+          // Now find the furthest point.
+          size_t point = data.n_cols; // Invalid.
+          double distance = 0;
+          for (size_t j = 0; j < data.n_cols; j++)
           {
-            double d = metric::SquaredEuclideanDistance::Evaluate(
-                data.unsafe_col(j), centroids.unsafe_col(cluster));
+            if (assignments[j] == cluster)
+            {
+              double d = metric::SquaredEuclideanDistance::Evaluate(
+                  data.unsafe_col(j), centroids.unsafe_col(cluster));
 
-            if (d >= distance)
-            {
-              distance = d;
-              point = j;
+              if (d >= distance)
+              {
+                distance = d;
+                point = j;
+              }
             }
           }
+
+          // Take that point and add it to the empty cluster.
+          counts[cluster]--;
+          counts[i]++;
+          assignments[point] = i;
+          changedAssignments++;
         }
-
-        // Take that point and add it to the empty cluster.
-        counts[cluster]--;
-        counts[i]++;
-        assignments[point] = i;
-        changed_assignments++;
       }
     }
+    iteration++;
 
-  } while (changed_assignments > 0);
+  } while (changedAssignments > 0 && iteration != maxIterations);
 
-  // Now, with the centroids final, we need to find the covariance matrix of
-  // each cluster and then the a priori weight.  We also need to assign the
-  // means to be the centroids.  First, we must make sure the size of the
-  // vectors is correct.
-  means.resize(value_of_k);
-  covars.resize(value_of_k);
-  weights.set_size(value_of_k);
-  for (size_t i = 0; i < value_of_k; i++)
+  // If we have overclustered, we need to merge the nearest clusters.
+  if (actualClusters != clusters)
   {
-    // Assign mean.
-    means[i] = centroids.col(i);
+    // Generate a list of all the clusters' distances from each other.  This
+    // list will become mangled and unused as the number of clusters decreases.
+    size_t numDistances = ((actualClusters - 1) * actualClusters) / 2;
+    size_t clustersLeft = actualClusters;
+    arma::vec distances(numDistances);
+    arma::Col<size_t> firstCluster(numDistances);
+    arma::Col<size_t> secondCluster(numDistances);
 
-    // Calculate covariance.
-    arma::mat data_subset(data.n_rows, accu(assignments == i));
-    size_t position = 0;
-    for (size_t j = 0; j < data.n_cols; j++)
+    // Keep the mappings of clusters that we are changing.
+    arma::Col<size_t> mappings = arma::linspace<arma::Col<size_t> >(0,
+        actualClusters - 1, actualClusters);
+
+    size_t i = 0;
+    for (size_t first = 0; first < actualClusters; first++)
     {
-      if (assignments[j] == i)
+      for (size_t second = first + 1; second < actualClusters; second++)
       {
-        data_subset.col(position) = data.col(j);
-        position++;
+        distances(i) = metric::SquaredEuclideanDistance::Evaluate(
+            centroids.unsafe_col(first), centroids.unsafe_col(second));
+        firstCluster(i) = first;
+        secondCluster(i) = second;
+        i++;
       }
     }
 
-    covars[i] = ccov(data_subset);
+    while (clustersLeft != clusters)
+    {
+      arma::u32 minIndex;
+      distances.min(minIndex);
 
-    // Assign weight.
-    weights[i] = (double) accu(assignments == i) / (double) data.n_cols;
+      // Now we merge the clusters which that distance belongs to.
+      size_t first = firstCluster(minIndex);
+      size_t second = secondCluster(minIndex);
+      for (size_t j = 0; j < assignments.n_elem; j++)
+        if (assignments(j) == second)
+          assignments(j) = first;
+
+      // Now merge the centroids.
+      centroids.col(first) *= counts[first];
+      centroids.col(first) += (counts[second] * centroids.col(second));
+      centroids.col(first) /= (counts[first] + counts[second]);
+
+      // Update the counts.
+      counts[first] += counts[second];
+      counts[second] = 0;
+
+      // Now update all the relevant distances.
+      // First the distances where either cluster is the second cluster.
+      for (size_t cluster = 0; cluster < second; cluster++)
+      {
+        // The offset is sum^n i - sum^(n - m) i, where n is actualClusters and
+        // m is the cluster we are trying to offset to.
+        size_t offset = (((actualClusters - 1) * cluster)
+            + (cluster - pow(cluster, 2)) / 2) - 1;
+
+        // See if we need to update the distance from this cluster to the first
+        // cluster.
+        if (cluster < first)
+        {
+          // Make sure it isn't already DBL_MAX.
+          if (distances(offset + (first - cluster)) != DBL_MAX)
+            distances(offset + (first - cluster)) =
+                metric::SquaredEuclideanDistance::Evaluate(
+                centroids.unsafe_col(first), centroids.unsafe_col(cluster));
+        }
+
+        distances(offset + (second - cluster)) = DBL_MAX;
+      }
+
+      // Now the distances where the first cluster is the first cluster.
+      size_t offset = (((actualClusters - 1) * first)
+          + (first - pow(first, 2)) / 2) - 1;
+      for (size_t cluster = first + 1; cluster < actualClusters; cluster++)
+      {
+        // Make sure it isn't already DBL_MAX.
+        if (distances(offset + (cluster - first)) != DBL_MAX)
+        {
+          distances(offset + (cluster - first)) =
+              metric::SquaredEuclideanDistance::Evaluate(
+              centroids.unsafe_col(first), centroids.unsafe_col(cluster));
+        }
+      }
+
+      // Max the distance between the first and second clusters.
+      distances(offset + (second - first)) = DBL_MAX;
+
+      // Now max the distances for the second cluster (which no longer has
+      // anything in it).
+      offset = (((actualClusters - 1) * second) + (second - pow(second, 2)) / 2)
+          - 1;
+      for (size_t cluster = second + 1; cluster < actualClusters; cluster++)
+        distances(offset + (cluster - second)) = DBL_MAX;
+
+      clustersLeft--;
+
+      // Update the cluster mappings.
+      mappings(second) = first;
+      // Also update any mappings that were pointed at the previous cluster.
+      for (size_t cluster = 0; cluster < actualClusters; cluster++)
+        if (mappings(cluster) == second)
+          mappings(cluster) = first;
+    }
+
+    // Now remap the mappings down to the smallest possible numbers.
+    // Could this process be sped up?
+    arma::Col<size_t> remappings(actualClusters);
+    remappings.fill(actualClusters);
+    size_t remap = 0; // Counter variable.
+    for (size_t cluster = 0; cluster < actualClusters; cluster++)
+    {
+      // If the mapping of the current cluster has not been assigned a value
+      // yet, we will assign it a cluster number.
+      if (remappings(mappings(cluster)) == actualClusters)
+      {
+        remappings(mappings(cluster)) = remap;
+        remap++;
+      }
+    }
+
+    // Fix the assignments using the mappings we created.
+    for (size_t j = 0; j < assignments.n_elem; j++)
+      assignments(j) = remappings(mappings(assignments(j)));
   }
 }
 

Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp	2011-11-24 21:53:08 UTC (rev 10397)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp	2011-11-25 06:57:12 UTC (rev 10398)
@@ -13,19 +13,105 @@
 namespace kmeans {
 
 /**
- * This function computes the k-means of the data and stores the calculated
- * means and covariances in the std::vector of vectors and matrices passed to
- * it.  It sets the weights uniformly.
- *
- * This function is used to obtain a starting point for the optimization.
+ * This class implements K-Means clustering.
  */
-void KMeans(const arma::mat& data,
-            const size_t value_of_k,
-            std::vector<arma::vec>& means,
-            std::vector<arma::mat>& covars,
-            arma::vec& weights);
+class KMeans
+{
+ public:
+  /**
+   * Create a K-Means object and (optionally) set the parameters which K-Means
+   * will be run with.  This implementation allows a few strategies to improve
+   * the performance of K-Means, including "overclustering" and disallowing
+   * empty clusters.
+   *
+   * The overclustering factor controls how many clusters are
+   * actually found; for instance, with an overclustering factor of 4, if
+   * K-Means is run to find 3 clusters, it will actually find 12, then merge the
+   * nearest clusters until only 3 are left.
+   *
+   * @param overclusteringFactor Factor controlling how many extra clusters are
+   *     found and then merged to get the desired number of clusters.
+   * @param allowEmptyClusters If false, empty clusters are refilled by moving
+   *     points into them; requesting more clusters than points is fatal.
+   * @param maxIterations Maximum number of iterations allowed before giving up
+   *     (0 is valid, but the algorithm may never terminate).
+   */
+  KMeans(const double overclusteringFactor = 4.0,
+         const bool allowEmptyClusters = false,
+         const size_t maxIterations = 1000);

-}; // namespace gmm
+  /**
+   * Perform K-Means clustering on the data, returning a list of cluster
+   * assignments.  Optionally, the vector of assignments can be set to an
+   * initial guess of the cluster assignments; to do this, the number of
+   * elements in the list of assignments must be equal to the number of points
+   * (columns) in the dataset.
+   *
+   * @param data Dataset to cluster.
+   * @param clusters Number of clusters to compute.
+   * @param assignments Vector to store cluster assignments in.  Can contain an
+   *     initial guess at cluster assignments.
+   */
+  void Cluster(const arma::mat& data,
+               const size_t clusters,
+               arma::Col<size_t>& assignments) const;
+
+  /**
+   * Return the overclustering factor.
+   */
+  double OverclusteringFactor() const { return overclusteringFactor; }
+
+  /**
+   * Set the overclustering factor (values below 1.0, or NaN, are ignored).
+   */
+  void OverclusteringFactor(const double overclusteringFactor)
+  {
+    if (!(overclusteringFactor >= 1.0))
+    {
+      Log::Warn << "KMeans::OverclusteringFactor(): invalid value (< 1.0) "
+          "ignored." << std::endl;
+      return;
+    }
+
+    this->overclusteringFactor = overclusteringFactor;
+  }
+
+  /**
+   * Return whether or not empty clusters are allowed.
+   */
+  bool AllowEmptyClusters() const { return allowEmptyClusters; }
+
+  /**
+   * Set whether or not empty clusters are allowed.
+   */
+  void AllowEmptyClusters(bool allowEmptyClusters)
+  {
+    this->allowEmptyClusters = allowEmptyClusters;
+  }
+
+  /**
+   * Get the maximum number of iterations.
+   */
+  size_t MaxIterations() const { return maxIterations; }
+
+  /**
+   * Set the maximum number of iterations.
+   */
+  void MaxIterations(size_t maxIterations)
+  {
+    this->maxIterations = maxIterations;
+  }
+
+ private:
+  //! Factor controlling how many clusters are actually found.
+  double overclusteringFactor;
+  //! Whether or not to allow empty clusters to be returned.
+  bool allowEmptyClusters;
+  //! Maximum number of iterations before giving up.
+  size_t maxIterations;
+};
+
+}; // namespace kmeans
 }; // namespace mlpack
 
 #endif // __MLPACK_METHODS_MOG_KMEANS_HPP




More information about the mlpack-svn mailing list