[mlpack-svn] r10421 - mlpack/trunk/src/mlpack/methods/kmeans

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Nov 26 22:22:06 EST 2011


Author: rcurtin
Date: 2011-11-26 22:22:05 -0500 (Sat, 26 Nov 2011)
New Revision: 10421

Added:
   mlpack/trunk/src/mlpack/methods/kmeans/allow_empty_clusters.hpp
   mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
   mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster.cpp
   mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp
   mlpack/trunk/src/mlpack/methods/kmeans/random_partition.hpp
Removed:
   mlpack/trunk/src/mlpack/methods/kmeans/kmeans.cpp
Modified:
   mlpack/trunk/src/mlpack/methods/kmeans/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp
Log:
Abstract some features of KMeans into template policy classes.


Modified: mlpack/trunk/src/mlpack/methods/kmeans/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/CMakeLists.txt	2011-11-26 23:21:32 UTC (rev 10420)
+++ mlpack/trunk/src/mlpack/methods/kmeans/CMakeLists.txt	2011-11-27 03:22:05 UTC (rev 10421)
@@ -3,8 +3,12 @@
 # Define the files we need to compile.
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
+  allow_empty_clusters.hpp
   kmeans.hpp
-  kmeans.cpp
+  kmeans_impl.hpp
+  max_variance_new_cluster.hpp
+  max_variance_new_cluster.cpp
+  random_partition.hpp
 )
 
 # Add directory name to sources.

Added: mlpack/trunk/src/mlpack/methods/kmeans/allow_empty_clusters.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/allow_empty_clusters.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/kmeans/allow_empty_clusters.hpp	2011-11-27 03:22:05 UTC (rev 10421)
@@ -0,0 +1,52 @@
+/**
+ * @file allow_empty_clusters.hpp
+ * @author Ryan Curtin
+ *
+ * This very simple policy is used when K-Means is allowed to return empty
+ * clusters.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_ALLOW_EMPTY_CLUSTERS_HPP
+#define __MLPACK_METHODS_KMEANS_ALLOW_EMPTY_CLUSTERS_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * EmptyClusterPolicy implementation which lets K-Means keep empty clusters: it
+ * leaves the clustering untouched and reports no error.
+ */
+class AllowEmptyClusters
+{
+ public:
+  //! Default constructor, required of every EmptyClusterPolicy class.
+  AllowEmptyClusters() { }
+
+  /**
+   * Intentionally a no-op: no argument is read or modified.  K-Means calls
+   * this whenever it detects a cluster with no points.
+   *
+   * @param data Dataset on which clustering is being performed (unused).
+   * @param emptyCluster Index of cluster which is empty (unused).
+   * @param centroids Centroids of each cluster (one per column; unused).
+   * @param clusterCounts Number of points in each cluster (left unchanged).
+   * @param assignments Cluster assignments of each point (left unchanged).
+   *
+   * @return Number of points changed (always 0).
+   */
+  static size_t EmptyCluster(const arma::mat& data,
+                             const size_t emptyCluster,
+                             const arma::mat& centroids,
+                             arma::Col<size_t>& clusterCounts,
+                             arma::Col<size_t>& assignments)
+  {
+    // Empty clusters are okay!  Do nothing.
+    return 0;
+  }
+};
+
+}; // namespace kmeans
+}; // namespace mlpack
+
+#endif

Deleted: mlpack/trunk/src/mlpack/methods/kmeans/kmeans.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans.cpp	2011-11-26 23:21:32 UTC (rev 10420)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans.cpp	2011-11-27 03:22:05 UTC (rev 10421)
@@ -1,316 +0,0 @@
-/**
- * @file kmeans.cpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- * @author Ryan Curtin
- *
- * Implementation for the K-means method for getting an initial point.
- */
-#include "kmeans.hpp"
-
-#include <mlpack/core/metrics/lmetric.hpp>
-
-namespace mlpack {
-namespace kmeans {
-
-/**
- * Construct the K-Means object.
- */
-KMeans::KMeans(const double overclusteringFactor,
-               const bool allowEmptyClusters,
-               const size_t maxIterations) :
-    allowEmptyClusters(allowEmptyClusters),
-    maxIterations(maxIterations)
-{
-  // Validate overclustering factor.
-  if (overclusteringFactor < 1.0)
-  {
-    Log::Warn << "KMeans::KMeans(): overclustering factor must be >= 1.0 ("
-        << overclusteringFactor << " given). Setting factor to 1.0.\n";
-    this->overclusteringFactor = 1.0;
-  }
-  else
-  {
-    this->overclusteringFactor = overclusteringFactor;
-  }
-}
-
-/**
- * Perform K-Means clustering on the data, returning a list of cluster
- * assignments.
- */
-void KMeans::Cluster(const arma::mat& data,
-                     const size_t clusters,
-                     arma::Col<size_t>& assignments) const
-{
-  // Make sure we have more points than clusters.
-  if (clusters > data.n_cols)
-  {
-    if (allowEmptyClusters)
-      Log::Warn << "KMeans::Cluster(): more clusters requested than points "
-          << "given.  Empty clusters may result." << std::endl;
-    else
-      Log::Fatal << "KMeans::Cluster(): more clusters requested than points "
-          << "given, and empty clusters not allowed.  Terminating.\n";
-  }
-
-  // Make sure our overclustering factor is valid.
-  size_t actualClusters = size_t(overclusteringFactor * clusters);
-  if (actualClusters > data.n_cols)
-  {
-    Log::Warn << "KMeans::Cluster(): overclustering factor is too large.  No "
-        << "overclustering will be done." << std::endl;
-    actualClusters = clusters;
-  }
-
-  // Now, the initial assignments.  First determine if they are necessary.
-  if (assignments.n_elem != data.n_cols)
-  {
-    // No guesses were given.  Generate random assignments.  Each cluster will
-    // have the same number of points.
-    assignments = arma::shuffle(arma::linspace<arma::Col<size_t> >(0,
-        actualClusters - 1, data.n_cols));
-  }
-
-  // Centroids of each cluster.  Each column corresponds to a centroid.
-  arma::mat centroids(data.n_rows, actualClusters);
-  // Counts of points in each cluster.
-  arma::Col<size_t> counts(actualClusters);
-
-  // Set counts correctly.
-  for (size_t i = 0; i < actualClusters; i++)
-    counts[i] = accu(assignments == i);
-
-  size_t changedAssignments = 0;
-  size_t iteration = 0;
-  do
-  {
-    // Update step.
-    // Calculate centroids based on given assignments.
-    centroids.zeros();
-
-    for (size_t i = 0; i < data.n_cols; i++)
-      centroids.col(assignments[i]) += data.col(i);
-
-    for (size_t i = 0; i < actualClusters; i++)
-      centroids.col(i) /= counts[i];
-
-    // Assignment step.
-    // Find the closest centroid to each point.  We will keep track of how many
-    // assignments change.  When no assignments change, we are done.
-    changedAssignments = 0;
-    for (size_t i = 0; i < data.n_cols; i++)
-    {
-      // Find the closest centroid to this point.
-      double minDistance = std::numeric_limits<double>::infinity();
-      size_t closestCluster = actualClusters; // Invalid value.
-
-      for (size_t j = 0; j < actualClusters; j++)
-      {
-        double distance = metric::SquaredEuclideanDistance::Evaluate(
-            data.unsafe_col(i), centroids.unsafe_col(j));
-
-        if (distance < minDistance)
-        {
-          minDistance = distance;
-          closestCluster = j;
-        }
-      }
-
-      // Reassign this point to the closest cluster.
-      if (assignments[i] != closestCluster)
-      {
-        // Update counts.
-        counts[assignments[i]]--;
-        counts[closestCluster]++;
-        // Update assignment.
-        assignments[i] = closestCluster;
-        changedAssignments++;
-      }
-    }
-
-    // If we are not allowing empty clusters, then check that all of our
-    // clusters have points.
-    if (!allowEmptyClusters)
-    {
-      for (size_t i = 0; i < actualClusters; i++)
-      {
-        if (counts[i] == 0)
-        {
-          // Strategy: take the furthest point from the cluster with highest
-          // variance.  So, we need the variance of each cluster.
-          arma::vec variances;
-          variances.zeros(actualClusters);
-          for (size_t j = 0; j < data.n_cols; j++)
-            variances[assignments[j]] += var(data.col(j));
-
-          size_t cluster;
-          double maxVar = 0;
-          for (size_t j = 0; j < actualClusters; j++)
-          {
-            if (variances[j] > maxVar)
-            {
-              cluster = j;
-              maxVar = variances[j];
-            }
-          }
-
-          // Now find the furthest point.
-          size_t point = data.n_cols; // Invalid.
-          double distance = 0;
-          for (size_t j = 0; j < data.n_cols; j++)
-          {
-            if (assignments[j] == cluster)
-            {
-              double d = metric::SquaredEuclideanDistance::Evaluate(
-                  data.unsafe_col(j), centroids.unsafe_col(cluster));
-
-              if (d >= distance)
-              {
-                distance = d;
-                point = j;
-              }
-            }
-          }
-
-          // Take that point and add it to the empty cluster.
-          counts[cluster]--;
-          counts[i]++;
-          assignments[point] = i;
-          changedAssignments++;
-        }
-      }
-    }
-    iteration++;
-
-  } while (changedAssignments > 0 && iteration != maxIterations);
-
-  // If we have overclustered, we need to merge the nearest clusters.
-  if (actualClusters != clusters)
-  {
-    // Generate a list of all the clusters' distances from each other.  This
-    // list will become mangled and unused as the number of clusters decreases.
-    size_t numDistances = ((actualClusters - 1) * actualClusters) / 2;
-    size_t clustersLeft = actualClusters;
-    arma::vec distances(numDistances);
-    arma::Col<size_t> firstCluster(numDistances);
-    arma::Col<size_t> secondCluster(numDistances);
-
-    // Keep the mappings of clusters that we are changing.
-    arma::Col<size_t> mappings = arma::linspace<arma::Col<size_t> >(0,
-        actualClusters - 1, actualClusters);
-
-    size_t i = 0;
-    for (size_t first = 0; first < actualClusters; first++)
-    {
-      for (size_t second = first + 1; second < actualClusters; second++)
-      {
-        distances(i) = metric::SquaredEuclideanDistance::Evaluate(
-            centroids.unsafe_col(first), centroids.unsafe_col(second));
-        firstCluster(i) = first;
-        secondCluster(i) = second;
-        i++;
-      }
-    }
-
-    while (clustersLeft != clusters)
-    {
-      arma::u32 minIndex;
-      distances.min(minIndex);
-
-      // Now we merge the clusters which that distance belongs to.
-      size_t first = firstCluster(minIndex);
-      size_t second = secondCluster(minIndex);
-      for (size_t j = 0; j < assignments.n_elem; j++)
-        if (assignments(j) == second)
-          assignments(j) = first;
-
-      // Now merge the centroids.
-      centroids.col(first) *= counts[first];
-      centroids.col(first) += (counts[second] * centroids.col(second));
-      centroids.col(first) /= (counts[first] + counts[second]);
-
-      // Update the counts.
-      counts[first] += counts[second];
-      counts[second] = 0;
-
-      // Now update all the relevant distances.
-      // First the distances where either cluster is the second cluster.
-      for (size_t cluster = 0; cluster < second; cluster++)
-      {
-        // The offset is sum^n i - sum^(n - m) i, where n is actualClusters and
-        // m is the cluster we are trying to offset to.
-        size_t offset = (((actualClusters - 1) * cluster)
-            + (cluster - pow(cluster, 2)) / 2) - 1;
-
-        // See if we need to update the distance from this cluster to the first
-        // cluster.
-        if (cluster < first)
-        {
-          // Make sure it isn't already DBL_MAX.
-          if (distances(offset + (first - cluster)) != DBL_MAX)
-            distances(offset + (first - cluster)) =
-                metric::SquaredEuclideanDistance::Evaluate(
-                centroids.unsafe_col(first), centroids.unsafe_col(cluster));
-        }
-
-        distances(offset + (second - cluster)) = DBL_MAX;
-      }
-
-      // Now the distances where the first cluster is the first cluster.
-      size_t offset = (((actualClusters - 1) * first)
-          + (first - pow(first, 2)) / 2) - 1;
-      for (size_t cluster = first + 1; cluster < actualClusters; cluster++)
-      {
-        // Make sure it isn't already DBL_MAX.
-        if (distances(offset + (cluster - first)) != DBL_MAX)
-        {
-          distances(offset + (cluster - first)) =
-              metric::SquaredEuclideanDistance::Evaluate(
-              centroids.unsafe_col(first), centroids.unsafe_col(cluster));
-        }
-      }
-
-      // Max the distance between the first and second clusters.
-      distances(offset + (second - first)) = DBL_MAX;
-
-      // Now max the distances for the second cluster (which no longer has
-      // anything in it).
-      offset = (((actualClusters - 1) * second) + (second - pow(second, 2)) / 2)
-          - 1;
-      for (size_t cluster = second + 1; cluster < actualClusters; cluster++)
-        distances(offset + (cluster - second)) = DBL_MAX;
-
-      clustersLeft--;
-
-      // Update the cluster mappings.
-      mappings(second) = first;
-      // Also update any mappings that were pointed at the previous cluster.
-      for (size_t cluster = 0; cluster < actualClusters; cluster++)
-        if (mappings(cluster) == second)
-          mappings(cluster) = first;
-    }
-
-    // Now remap the mappings down to the smallest possible numbers.
-    // Could this process be sped up?
-    arma::Col<size_t> remappings(actualClusters);
-    remappings.fill(actualClusters);
-    size_t remap = 0; // Counter variable.
-    for (size_t cluster = 0; cluster < actualClusters; cluster++)
-    {
-      // If the mapping of the current cluster has not been assigned a value
-      // yet, we will assign it a cluster number.
-      if (remappings(mappings(cluster)) == actualClusters)
-      {
-        remappings(mappings(cluster)) = remap;
-        remap++;
-      }
-    }
-
-    // Fix the assignments using the mappings we created.
-    for (size_t j = 0; j < assignments.n_elem; j++)
-      assignments(j) = remappings(mappings(assignments(j)));
-  }
-}
-
-}; // namespace gmm
-}; // namespace mlpack

Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp	2011-11-26 23:21:32 UTC (rev 10420)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp	2011-11-27 03:22:05 UTC (rev 10421)
@@ -9,12 +9,35 @@
 
 #include <mlpack/core.hpp>
 
+#include <mlpack/core/metrics/lmetric.hpp>
+#include "random_partition.hpp"
+#include "max_variance_new_cluster.hpp"
+
 namespace mlpack {
 namespace kmeans {
 
 /**
- * This class implements K-Means clustering.
+ * This class implements K-Means clustering.  This implementation supports
+ * overclustering, which means that more clusters than are requested will be
+ * found; then, those clusters will be merged together to produce the desired
+ * number of clusters.
+ *
+ * Three template parameters can (optionally) be supplied: the distance metric
+ * to be used, the policy for how to find the initial partition of the data,
+ * and the actions to be taken when an empty cluster is encountered.
+ *
+ * @tparam DistanceMetric The distance metric to use for this KMeans; see
+ *     metric::LMetric for an example.
+ * @tparam InitialPartitionPolicy Initial partitioning policy; must implement a
+ *     default constructor and 'void Cluster(const arma::mat&, const size_t,
+ *     arma::Col<size_t>&)'.  @see RandomPartition for an example.
+ * @tparam EmptyClusterPolicy Policy for what to do on an empty cluster; must
+ *     implement a default constructor and a five-argument 'size_t
+ *     EmptyCluster()'.  @see AllowEmptyClusters and MaxVarianceNewCluster.
  */
+template<typename DistanceMetric = metric::SquaredEuclideanDistance,
+         typename InitialPartitionPolicy = RandomPartition,
+         typename EmptyClusterPolicy = MaxVarianceNewCluster>
 class KMeans
 {
  public:
@@ -29,17 +52,24 @@
    * K-Means is run to find 3 clusters, it will actually find 12, then merge the
    * nearest clusters until only 3 are left.
    *
+   * @param maxIterations Maximum number of iterations allowed before giving up
+   *     (0 is valid, but the algorithm may never terminate).
    * @param overclusteringFactor Factor controlling how many extra clusters are
    *     found and then merged to get the desired number of clusters.
-   * @param allowEmptyClusters If false, then clustering will fail instead of
-   *     returning an empty cluster.
-   * @param maxIterations Maximum number of iterations allowed before giving up
-   *     (0 is valid, but the algorithm may never terminate).
+   * @param metric Optional DistanceMetric object; for when the metric has state
+   *     it needs to store.
+   * @param partitioner Optional InitialPartitionPolicy object; for when a
+   *     specially initialized partitioning policy is required.
+   * @param emptyClusterAction Optional EmptyClusterPolicy object; for when a
+   *     specially initialized empty cluster policy is required.
    */
-  KMeans(const double overclusteringFactor = 4.0,
-         const bool allowEmptyClusters = false,
-         const size_t maxIterations = 1000);
+  KMeans(const size_t maxIterations = 1000,
+         const double overclusteringFactor = 1.0,
+         const DistanceMetric metric = DistanceMetric(),
+         const InitialPartitionPolicy partitioner = InitialPartitionPolicy(),
+         const EmptyClusterPolicy emptyClusterAction = EmptyClusterPolicy());
 
+
   /**
    * Perform K-Means clustering on the data, returning a list of cluster
    * assignments.  Optionally, the vector of assignments can be set to an
@@ -77,19 +107,6 @@
   }
 
   /**
-   * Return whether or not empty clusters are allowed.
-   */
-  bool AllowEmptyClusters() const { return allowEmptyClusters; }
-
-  /**
-   * Set whether or not empty clusters are allowed.
-   */
-  void AllowEmptyClusters(bool allowEmptyClusters)
-  {
-    this->allowEmptyClusters = allowEmptyClusters;
-  }
-
-  /**
    * Get the maximum number of iterations.
    */
   size_t MaxIterations() const { return maxIterations; }
@@ -97,21 +114,46 @@
   /**
    * Set the maximum number of iterations.
    */
-  void MaxIterations(size_t maxIterations)
+  void MaxIterations(const size_t maxIterations)
   {
     this->maxIterations = maxIterations;
   }
 
+  //! Get the distance metric.
+  const DistanceMetric& Metric() const { return metric; }
+  //! Modify the distance metric.
+  DistanceMetric& Metric() { return metric; }
+
+  //! Get the initial partitioning policy.
+  const InitialPartitionPolicy& Partitioner() const { return partitioner; }
+  //! Modify the initial partitioning policy.
+  InitialPartitionPolicy& Partitioner() { return partitioner; }
+
+  //! Get the empty cluster policy.
+  const EmptyClusterPolicy& EmptyClusterAction() const
+  {
+    return emptyClusterAction;
+  }
+  //! Modify the empty cluster policy.
+  EmptyClusterPolicy& EmptyClusterAction() { return emptyClusterAction; }
+
  private:
   //! Factor controlling how many clusters are actually found.
   double overclusteringFactor;
-  //! Whether or not to allow empty clusters to be returned.
-  bool allowEmptyClusters;
   //! Maximum number of iterations before giving up.
   size_t maxIterations;
+  //! Instantiated distance metric.
+  DistanceMetric metric;
+  //! Instantiated initial partitioning policy.
+  InitialPartitionPolicy partitioner;
+  //! Instantiated empty cluster policy.
+  EmptyClusterPolicy emptyClusterAction;
 };
 
 }; // namespace kmeans
 }; // namespace mlpack
 
+// Include implementation.
+#include "kmeans_impl.hpp"
+
 #endif // __MLPACK_METHODS_MOG_KMEANS_HPP

Copied: mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp (from rev 10398, mlpack/trunk/src/mlpack/methods/kmeans/kmeans.cpp)
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp	2011-11-27 03:22:05 UTC (rev 10421)
@@ -0,0 +1,281 @@
+/**
+ * @file kmeans_impl.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ * @author Ryan Curtin
+ *
+ * Implementation of the KMeans class template declared in kmeans.hpp.
+ */
+#include "kmeans.hpp"
+
+#include <mlpack/core/metrics/lmetric.hpp>
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * Construct the K-Means object.  A factor below 1.0 is replaced by 1.0 (with a
+ * warning); every other argument is stored unchanged.
+ */
+template<typename DistanceMetric,
+         typename InitialPartitionPolicy,
+         typename EmptyClusterPolicy>
+KMeans<
+    DistanceMetric,
+    InitialPartitionPolicy,
+    EmptyClusterPolicy>::
+KMeans(const size_t maxIterations,
+       const double overclusteringFactor,
+       const DistanceMetric metric,
+       const InitialPartitionPolicy partitioner,
+       const EmptyClusterPolicy emptyClusterAction) :
+    overclusteringFactor(
+        (overclusteringFactor < 1.0) ? 1.0 : overclusteringFactor),
+    maxIterations(maxIterations),
+    metric(metric),
+    partitioner(partitioner),
+    emptyClusterAction(emptyClusterAction)
+{
+  // The member already holds the clamped value (see the initializer list);
+  // here 'overclusteringFactor' names the constructor argument, so this only
+  // reports that the caller's value was rejected.
+  if (overclusteringFactor < 1.0)
+  {
+    Log::Warn << "KMeans::KMeans(): overclustering factor must be >= 1.0 ("
+        << overclusteringFactor << " given). Setting factor to 1.0.\n";
+  }
+}
+
+/**
+ * Perform K-Means clustering on the data, returning a list of cluster
+ * assignments.  If assignments already holds one entry per point, it is used
+ * as the initial partition; otherwise the partitioner generates one.
+ *
+ * @param data Dataset to cluster (one point per column).
+ * @param clusters Number of clusters to find.
+ * @param assignments Cluster index of each point (initial guess and result).
+ */
+template<typename DistanceMetric,
+         typename InitialPartitionPolicy,
+         typename EmptyClusterPolicy>
+void KMeans<DistanceMetric, InitialPartitionPolicy, EmptyClusterPolicy>::
+Cluster(const arma::mat& data,
+        const size_t clusters,
+        arma::Col<size_t>& assignments) const
+{
+  // Make sure we have more points than clusters.
+  if (clusters > data.n_cols)
+    Log::Warn << "KMeans::Cluster(): more clusters requested than points given."
+        << std::endl;
+
+  // Make sure our overclustering factor is valid.
+  size_t actualClusters = size_t(overclusteringFactor * clusters);
+  if (actualClusters > data.n_cols)
+  {
+    Log::Warn << "KMeans::Cluster(): overclustering factor is too large.  No "
+        << "overclustering will be done." << std::endl;
+    actualClusters = clusters;
+  }
+
+  // Now, the initial assignments.  First determine if they are necessary.
+  if (assignments.n_elem != data.n_cols)
+  {
+    // Use the partitioner to come up with the partition assignments.
+    partitioner.Cluster(data, actualClusters, assignments);
+  }
+
+  // Centroids of each cluster.  Each column corresponds to a centroid.
+  arma::mat centroids(data.n_rows, actualClusters);
+  // Counts of points in each cluster; zeroed because the tally only increments.
+  arma::Col<size_t> counts(actualClusters);
+  counts.zeros();
+
+  // Set counts correctly.
+  for (size_t i = 0; i < assignments.n_elem; i++)
+    counts[assignments[i]]++;
+
+  size_t changedAssignments = 0;
+  size_t iteration = 0;
+  do
+  {
+    // Update step.
+    // Calculate centroids based on given assignments.
+    centroids.zeros();
+
+    for (size_t i = 0; i < data.n_cols; i++)
+      centroids.col(assignments[i]) += data.col(i);
+
+    for (size_t i = 0; i < actualClusters; i++)
+      centroids.col(i) /= counts[i];
+
+    // Assignment step.
+    // Find the closest centroid to each point.  We will keep track of how many
+    // assignments change.  When no assignments change, we are done.
+    changedAssignments = 0;
+    for (size_t i = 0; i < data.n_cols; i++)
+    {
+      // Find the closest centroid to this point.
+      double minDistance = std::numeric_limits<double>::infinity();
+      size_t closestCluster = actualClusters; // Invalid value.
+
+      for (size_t j = 0; j < actualClusters; j++)
+      {
+        const double distance = metric.Evaluate(data.unsafe_col(i),
+            centroids.unsafe_col(j));
+
+        if (distance < minDistance)
+        {
+          minDistance = distance;
+          closestCluster = j;
+        }
+      }
+
+      // Reassign this point to the closest cluster.
+      if (assignments[i] != closestCluster)
+      {
+        // Update counts.
+        counts[assignments[i]]--;
+        counts[closestCluster]++;
+        // Update assignment.
+        assignments[i] = closestCluster;
+        changedAssignments++;
+      }
+    }
+
+    // Hand every cluster that ended up without points to the empty cluster
+    // policy, which reports how many assignments it changed.
+    for (size_t i = 0; i < actualClusters; i++)
+      if (counts[i] == 0)
+        changedAssignments += emptyClusterAction.EmptyCluster(data, i,
+            centroids, counts, assignments);
+
+    iteration++;
+
+  } while (changedAssignments > 0 && iteration != maxIterations);
+
+  // If we have overclustered, we need to merge the nearest clusters.
+  if (actualClusters != clusters)
+  {
+    // Generate a list of all the clusters' distances from each other.  This
+    // list will become mangled and unused as the number of clusters decreases.
+    size_t numDistances = ((actualClusters - 1) * actualClusters) / 2;
+    size_t clustersLeft = actualClusters;
+    arma::vec distances(numDistances);
+    arma::Col<size_t> firstCluster(numDistances);
+    arma::Col<size_t> secondCluster(numDistances);
+
+    // Keep the mappings of clusters that we are changing.
+    arma::Col<size_t> mappings = arma::linspace<arma::Col<size_t> >(0,
+        actualClusters - 1, actualClusters);
+
+    size_t i = 0;
+    for (size_t first = 0; first < actualClusters; first++)
+    {
+      for (size_t second = first + 1; second < actualClusters; second++)
+      {
+        distances(i) = metric.Evaluate(centroids.unsafe_col(first),
+            centroids.unsafe_col(second));
+        firstCluster(i) = first;
+        secondCluster(i) = second;
+        i++;
+      }
+    }
+
+    while (clustersLeft != clusters)
+    {
+      arma::u32 minIndex;
+      distances.min(minIndex);
+
+      // Now we merge the clusters which that distance belongs to.
+      size_t first = firstCluster(minIndex);
+      size_t second = secondCluster(minIndex);
+      for (size_t j = 0; j < assignments.n_elem; j++)
+        if (assignments(j) == second)
+          assignments(j) = first;
+
+      // Now merge the centroids.
+      centroids.col(first) *= counts[first];
+      centroids.col(first) += (counts[second] * centroids.col(second));
+      centroids.col(first) /= (counts[first] + counts[second]);
+
+      // Update the counts.
+      counts[first] += counts[second];
+      counts[second] = 0;
+
+      // Now update all the relevant distances.
+      // First the distances where either cluster is the second cluster.
+      for (size_t cluster = 0; cluster < second; cluster++)
+      {
+        // Index in 'distances' of the pair (cluster, cluster + 1); the pair
+        // (cluster, other) is stored at rowStart + (other - cluster - 1).
+        const size_t rowStart = cluster * actualClusters -
+            (cluster * (cluster + 1)) / 2;
+
+        // See if we need to update the distance from this cluster to the first
+        // cluster.
+        if (cluster < first)
+        {
+          // Make sure it isn't already DBL_MAX.
+          if (distances(rowStart + (first - cluster - 1)) != DBL_MAX)
+            distances(rowStart + (first - cluster - 1)) = metric.Evaluate(
+                centroids.unsafe_col(first), centroids.unsafe_col(cluster));
+        }
+
+        distances(rowStart + (second - cluster - 1)) = DBL_MAX;
+      }
+
+      // Now the distances where the first cluster is the first cluster.
+      size_t rowStart = first * actualClusters -
+          (first * (first + 1)) / 2;
+      for (size_t cluster = first + 1; cluster < actualClusters; cluster++)
+      {
+        // Make sure it isn't already DBL_MAX.
+        if (distances(rowStart + (cluster - first - 1)) != DBL_MAX)
+        {
+          distances(rowStart + (cluster - first - 1)) = metric.Evaluate(
+              centroids.unsafe_col(first), centroids.unsafe_col(cluster));
+        }
+      }
+
+      // Max the distance between the first and second clusters.
+      distances(rowStart + (second - first - 1)) = DBL_MAX;
+
+      // Now max the distances for the second cluster (which no longer has
+      // anything in it).
+      rowStart = second * actualClusters - (second * (second + 1)) / 2;
+      for (size_t cluster = second + 1; cluster < actualClusters; cluster++)
+        distances(rowStart + (cluster - second - 1)) = DBL_MAX;
+
+      clustersLeft--;
+
+      // Update the cluster mappings.
+      mappings(second) = first;
+      // Also update any mappings that were pointed at the previous cluster.
+      for (size_t cluster = 0; cluster < actualClusters; cluster++)
+        if (mappings(cluster) == second)
+          mappings(cluster) = first;
+    }
+
+    // Now remap the mappings down to the smallest possible numbers.
+    // Could this process be sped up?
+    arma::Col<size_t> remappings(actualClusters);
+    remappings.fill(actualClusters);
+    size_t remap = 0; // Counter variable.
+    for (size_t cluster = 0; cluster < actualClusters; cluster++)
+    {
+      // If the mapping of the current cluster has not been assigned a value
+      // yet, we will assign it a cluster number.
+      if (remappings(mappings(cluster)) == actualClusters)
+      {
+        remappings(mappings(cluster)) = remap;
+        remap++;
+      }
+    }
+
+    // Fix the assignments using the mappings we created.
+    for (size_t j = 0; j < assignments.n_elem; j++)
+      assignments(j) = remappings(mappings(assignments(j)));
+  }
+}
+
+}; // namespace kmeans
+}; // namespace mlpack

Added: mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster.cpp	2011-11-27 03:22:05 UTC (rev 10421)
@@ -0,0 +1,62 @@
+/**
+ * @file max_variance_new_cluster.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of MaxVarianceNewCluster class.
+ */
+#include "max_variance_new_cluster.hpp"
+
+using namespace mlpack;
+using namespace kmeans;
+
+/**
+ * Move the point furthest from the centroid of the maximum-variance cluster
+ * into the empty cluster.  Returns the number of points reassigned (0 or 1).
+ */
+size_t MaxVarianceNewCluster::EmptyCluster(const arma::mat& data,
+                                           const size_t emptyCluster,
+                                           const arma::mat& centroids,
+                                           arma::Col<size_t>& clusterCounts,
+                                           arma::Col<size_t>& assignments)
+{
+  // First, we need to find the cluster with maximum variance, measured here as
+  // the sum of squared distances between its points and its centroid.
+  arma::vec variances;
+  variances.zeros(clusterCounts.n_elem); // Start with 0.
+  for (size_t i = 0; i < data.n_cols; i++)
+  {
+    const arma::vec diff = data.col(i) - centroids.col(assignments[i]);
+    variances[assignments[i]] += arma::dot(diff, diff);
+  }
+
+  // Now find the cluster with maximum variance.
+  arma::u32 maxVarCluster;
+  variances.max(maxVarCluster);
+
+  // Now, inside this cluster, find the point which is furthest away.  Use >=
+  // so that a point is still selected when every distance is zero.
+  size_t furthestPoint = data.n_cols; // Invalid value: nothing found yet.
+  double maxDistance = 0;
+  for (size_t i = 0; i < data.n_cols; i++)
+  {
+    if (assignments[i] != maxVarCluster)
+      continue;
+    const arma::vec diff = data.col(i) - centroids.col(maxVarCluster);
+    const double distance = arma::dot(diff, diff);
+    if (distance >= maxDistance)
+    {
+      maxDistance = distance;
+      furthestPoint = i;
+    }
+  }
+
+  // If the chosen cluster holds no points, there is nothing to move.
+  if (furthestPoint == data.n_cols)
+    return 0;
+
+  // Take that point and add it to the empty cluster.
+  clusterCounts[maxVarCluster]--;
+  clusterCounts[emptyCluster]++;
+  assignments[furthestPoint] = emptyCluster;
+  return 1; // We only changed one point.
+}

Added: mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp	2011-11-27 03:22:05 UTC (rev 10421)
@@ -0,0 +1,49 @@
+/**
+ * @file max_variance_new_cluster.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of the EmptyClusterPolicy policy class for K-Means.  When
+ * an empty cluster is detected, the point furthest from the centroid of the
+ * cluster with maximum variance is taken to be a new cluster.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_MAX_VARIANCE_NEW_CLUSTER_HPP
+#define __MLPACK_METHODS_KMEANS_MAX_VARIANCE_NEW_CLUSTER_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * EmptyClusterPolicy implementation: when an empty cluster is detected, the
+ * point furthest from the centroid of the maximum-variance cluster is moved.
+ */
+class MaxVarianceNewCluster
+{
+ public:
+  //! Default constructor required by EmptyClusterPolicy.
+  MaxVarianceNewCluster() { }
+
+  /**
+   * Take the point furthest from the centroid of the cluster with maximum
+   * variance and reassign it to the empty cluster.
+   *
+   * @param data Dataset on which clustering is being performed.
+   * @param emptyCluster Index of cluster which is empty.
+   * @param centroids Centroids of each cluster (one per column).
+   * @param clusterCounts Number of points in each cluster; updated in place.
+   * @param assignments Cluster assignment of each point; updated in place.
+   *
+   * @return Number of points changed.
+   */
+  static size_t EmptyCluster(const arma::mat& data,
+                             const size_t emptyCluster,
+                             const arma::mat& centroids,
+                             arma::Col<size_t>& clusterCounts,
+                             arma::Col<size_t>& assignments);
+};
+
+} // namespace kmeans
+} // namespace mlpack
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/kmeans/random_partition.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/random_partition.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/kmeans/random_partition.hpp	2011-11-27 03:22:05 UTC (rev 10421)
@@ -0,0 +1,50 @@
+/**
+ * @file random_partition.hpp
+ * @author Ryan Curtin
+ *
+ * Very simple partitioner which partitions the data randomly into the number of
+ * desired clusters.  Used as the default InitialPartitionPolicy for KMeans.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_RANDOM_PARTITION_HPP
+#define __MLPACK_METHODS_KMEANS_RANDOM_PARTITION_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * A very simple partitioner which partitions the data randomly into the number
+ * of desired clusters.  It has no parameters, and so an instance of the class
+ * is not even necessary.
+ */
+class RandomPartition
+{
+ public:
+  //! Empty constructor, required by the InitialPartitionPolicy policy.
+  RandomPartition() { }
+
+  /**
+   * Randomly assign each column of the dataset to one of the clusters: labels
+   * are spaced with linspace over [0, clusters - 1] and then shuffled.
+   * NOTE(review): the labels are integer-typed, so confirm that the resulting
+   * cluster sizes really are (approximately) equal as intended.
+   *
+   * @param data Dataset to partition; only its column count is used.
+   * @param clusters Number of clusters; (clusters - 1) wraps if this is 0.
+   * @param assignments Output: one label in [0, clusters - 1] per column.
+   */
+  inline static void Cluster(const arma::mat& data,
+                             const size_t clusters,
+                             arma::Col<size_t>& assignments)
+  {
+    // Evenly spaced labels, one per data column, placed in random order.
+    assignments = arma::shuffle(arma::linspace<arma::Col<size_t> >(0,
+        (clusters - 1), data.n_cols));
+  }
+};
+
+} // namespace kmeans
+} // namespace mlpack
+
+#endif




More information about the mlpack-svn mailing list