[mlpack-svn] r17235 - mlpack/trunk/src/mlpack/methods/kmeans

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Oct 10 16:08:15 EDT 2014


Author: rcurtin
Date: Fri Oct 10 16:08:14 2014
New Revision: 17235

Log:
Add implementation of Hamerly's algorithm.


Added:
   mlpack/trunk/src/mlpack/methods/kmeans/hamerly_kmeans.hpp
   mlpack/trunk/src/mlpack/methods/kmeans/hamerly_kmeans_impl.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/kmeans/CMakeLists.txt

Modified: mlpack/trunk/src/mlpack/methods/kmeans/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/methods/kmeans/CMakeLists.txt	Fri Oct 10 16:08:14 2014
@@ -4,6 +4,8 @@
   allow_empty_clusters.hpp
   elkan_kmeans.hpp
   elkan_kmeans_impl.hpp
+  hamerly_kmeans.hpp
+  hamerly_kmeans_impl.hpp
   kmeans.hpp
   kmeans_impl.hpp
   max_variance_new_cluster.hpp

Added: mlpack/trunk/src/mlpack/methods/kmeans/hamerly_kmeans.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/kmeans/hamerly_kmeans.hpp	Fri Oct 10 16:08:14 2014
@@ -0,0 +1,63 @@
+/**
+ * @file hamerly_kmeans.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of Greg Hamerly's algorithm for k-means clustering.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_HAMERLY_KMEANS_HPP
+#define __MLPACK_METHODS_KMEANS_HAMERLY_KMEANS_HPP
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * An implementation of Hamerly's algorithm for running k-means iterations.
+ * For every point, an upper bound on the distance to its assigned centroid and
+ * a lower bound on the distance to any other centroid are kept between calls
+ * to Iterate().  When those bounds show that a point's assignment cannot
+ * change, the distance computations for that point are skipped.
+ *
+ * The bound tests rely on the triangle inequality, so MetricType should
+ * presumably be a true metric (e.g. Euclidean, not squared Euclidean).
+ *
+ * @tparam MetricType Metric type providing Evaluate(a, b).
+ * @tparam MatType Type of the dataset matrix.
+ */
+template<typename MetricType, typename MatType>
+class HamerlyKMeans
+{
+ public:
+  /**
+   * Construct the HamerlyKMeans object, which must store several sets of
+   * bounds.  The dataset and metric are held by reference, so both must
+   * outlive this object.  The bounds are not allocated here; Iterate() sets
+   * them up on its first call.
+   *
+   * @param dataset Dataset to cluster (one point per column).
+   * @param metric Instantiated metric used for all distance evaluations.
+   */
+  HamerlyKMeans(const MatType& dataset, MetricType& metric);
+
+  /**
+   * Run a single iteration of Hamerly's algorithm, updating the given centroids
+   * into the newCentroids matrix.  The per-point bounds are kept across calls
+   * and are only reinitialized when the number of centroids changes.  They are
+   * adjusted for the movement from centroids to newCentroids, so the next call
+   * should presumably be given this call's newCentroids as its centroids.
+   *
+   * A centroid with no assigned points is filled with DBL_MAX in newCentroids.
+   *
+   * @param centroids Current cluster centroids (one centroid per column).
+   * @param newCentroids Matrix to store the new cluster centroids in.
+   * @param counts Current counts, to be overwritten with new counts.
+   * @return Square root of the sum of squared distances (under the metric)
+   *     moved by each centroid.
+   */
+  double Iterate(const arma::mat& centroids,
+                 arma::mat& newCentroids,
+                 arma::Col<size_t>& counts);
+
+  //! Get the number of distance evaluations performed so far.
+  size_t DistanceCalculations() const { return distanceCalculations; }
+
+ private:
+  //! The dataset.
+  const MatType& dataset;
+  //! The instantiated metric.
+  MetricType& metric;
+
+  //! For each cluster, the distance from its centroid to the nearest other
+  //! centroid (recomputed on every iteration).
+  arma::vec minClusterDistances;
+
+  //! Upper bound on the distance from each point to its assigned centroid.
+  arma::vec upperBounds;
+  //! Lower bound on the distance from each point to every centroid other than
+  //! its assigned one.
+  arma::vec lowerBounds;
+  //! Index of the centroid each point is currently assigned to.
+  arma::Col<size_t> assignments;
+
+  //! Track distance calculations.
+  size_t distanceCalculations;
+};
+
+} // namespace kmeans
+} // namespace mlpack
+
+// Include implementation.
+#include "hamerly_kmeans_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/kmeans/hamerly_kmeans_impl.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/kmeans/hamerly_kmeans_impl.hpp	Fri Oct 10 16:08:14 2014
@@ -0,0 +1,164 @@
+/**
+ * @file hamerly_kmeans_impl.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of Greg Hamerly's algorithm for k-means clustering.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_HAMERLY_KMEANS_IMPL_HPP
+#define __MLPACK_METHODS_KMEANS_HAMERLY_KMEANS_IMPL_HPP
+
+namespace mlpack {
+namespace kmeans {
+
+template<typename MetricType, typename MatType>
+HamerlyKMeans<MetricType, MatType>::HamerlyKMeans(const MatType& dataset,
+                                                  MetricType& metric) :
+    dataset(dataset),
+    metric(metric),
+    minClusterDistances(),
+    upperBounds(),
+    lowerBounds(),
+    assignments(),
+    distanceCalculations(0)
+{
+  // Only the references and the counter are set up here; the bound vectors
+  // start out empty and are sized by the first call to Iterate().
+}
+
+/**
+ * Run one iteration of Hamerly's algorithm.  Each point is assigned to its
+ * nearest centroid; the bound tests let most points skip the full search.
+ * Points are summed into newCentroids and counted in counts, and the stored
+ * per-point bounds are then shifted by how far each centroid moved.
+ *
+ * Returns the square root of the sum of squared centroid movements, as
+ * measured by the metric.
+ */
+template<typename MetricType, typename MatType>
+double HamerlyKMeans<MetricType, MatType>::Iterate(const arma::mat& centroids,
+                                                   arma::mat& newCentroids,
+                                                   arma::Col<size_t>& counts)
+{
+  // If this is the first iteration, we need to set all the bounds.  An upper
+  // bound of DBL_MAX forces the distance to the initial assignment (cluster 0)
+  // to be computed for every point.
+  if (minClusterDistances.n_elem != centroids.n_cols)
+  {
+    upperBounds.set_size(dataset.n_cols);
+    upperBounds.fill(DBL_MAX);
+    lowerBounds.zeros(dataset.n_cols);
+    assignments.zeros(dataset.n_cols);
+    minClusterDistances.set_size(centroids.n_cols);
+  }
+
+  // Reset new centroids.
+  newCentroids.zeros(centroids.n_rows, centroids.n_cols);
+  counts.zeros(centroids.n_cols);
+
+  // For each cluster, calculate the distance from its centroid to the nearest
+  // other centroid.  Each pair of centroids is evaluated once.
+  minClusterDistances.fill(DBL_MAX);
+  for (size_t i = 0; i < centroids.n_cols; ++i)
+  {
+    for (size_t j = i + 1; j < centroids.n_cols; ++j)
+    {
+      const double dist = metric.Evaluate(centroids.col(i), centroids.col(j));
+      ++distanceCalculations;
+
+      // Update bounds, if this intra-cluster distance is smaller.
+      if (dist < minClusterDistances(i))
+        minClusterDistances(i) = dist;
+      if (dist < minClusterDistances(j))
+        minClusterDistances(j) = dist;
+    }
+  }
+
+  for (size_t i = 0; i < dataset.n_cols; ++i)
+  {
+    // If the distance to the assigned centroid is at most m, no other centroid
+    // can be closer.  m is the larger of two values: half the distance from the
+    // assigned centroid to its nearest other centroid, and the lower bound on
+    // the distance to any other centroid.
+    const double m = std::max(minClusterDistances(assignments(i)) / 2.0,
+                              lowerBounds(i));
+
+    // First bound test, using the (possibly loose) upper bound.
+    if (upperBounds(i) > m)
+    {
+      // Tighten the upper bound to the exact distance and test again.
+      upperBounds(i) = metric.Evaluate(dataset.col(i),
+                                       centroids.col(assignments(i)));
+      ++distanceCalculations;
+
+      if (upperBounds(i) > m)
+      {
+        // The bounds failed.  So test against all other clusters.  This is
+        // Hamerly's Point-All-Ctrs() function.  Afterwards, upperBounds(i) is
+        // the exact distance to the closest centroid.  lowerBounds(i) is the
+        // exact distance to the second-closest one (DBL_MAX if there is no
+        // other centroid).  The lower bound must be reset before the search.
+        // Otherwise a stale bound (zero on the first iteration) would survive
+        // the search whenever the assignment does not change, and would stay
+        // needlessly loose.
+        const size_t oldAssignment = assignments(i);
+        lowerBounds(i) = DBL_MAX;
+        for (size_t c = 0; c < centroids.n_cols; ++c)
+        {
+          // The distance to the original assignment is already known (it is
+          // the tightened upper bound), so never evaluate it again, even
+          // after the point has been reassigned.
+          if (c == oldAssignment)
+            continue;
+
+          const double dist = metric.Evaluate(dataset.col(i),
+                                              centroids.col(c));
+          ++distanceCalculations;
+
+          if (dist < upperBounds(i))
+          {
+            // New closest cluster; the previous closest becomes the
+            // second-closest.
+            lowerBounds(i) = upperBounds(i);
+            upperBounds(i) = dist;
+            assignments(i) = c;
+          }
+          else if (dist < lowerBounds(i))
+          {
+            // This is a closer second-closest cluster.
+            lowerBounds(i) = dist;
+          }
+        }
+      }
+    }
+
+    // Add the point to its (possibly new) cluster.
+    newCentroids.col(assignments(i)) += dataset.col(i);
+    ++counts(assignments(i));
+  }
+
+  // Normalize centroids and calculate cluster movement (contains parts of
+  // Move-Centers() and Update-Bounds()).  The largest and second-largest
+  // movements are tracked for the lower bound update below.
+  double furthestMovement = 0.0;
+  double secondFurthestMovement = 0.0;
+  size_t furthestMovingCluster = 0;
+  arma::vec centroidMovements(centroids.n_cols);
+  double centroidMovement = 0.0;
+  for (size_t c = 0; c < centroids.n_cols; ++c)
+  {
+    if (counts(c) > 0)
+      newCentroids.col(c) /= counts(c);
+    else
+      newCentroids.col(c).fill(DBL_MAX); // Empty cluster.
+
+    // Calculate movement.
+    const double movement = metric.Evaluate(centroids.col(c),
+                                            newCentroids.col(c));
+    centroidMovements(c) = movement;
+    centroidMovement += movement * movement;
+    ++distanceCalculations;
+
+    if (movement > furthestMovement)
+    {
+      secondFurthestMovement = furthestMovement;
+      furthestMovement = movement;
+      furthestMovingCluster = c;
+    }
+    else if (movement > secondFurthestMovement)
+    {
+      secondFurthestMovement = movement;
+    }
+  }
+
+  // Now update bounds (lines 3-8 of Update-Bounds()).  Each upper bound grows
+  // by the movement of the point's own centroid.  Each lower bound shrinks by
+  // the largest movement of any other centroid.
+  for (size_t i = 0; i < dataset.n_cols; ++i)
+  {
+    upperBounds(i) += centroidMovements(assignments(i));
+    if (assignments(i) == furthestMovingCluster)
+      lowerBounds(i) -= secondFurthestMovement;
+    else
+      lowerBounds(i) -= furthestMovement;
+  }
+
+  return std::sqrt(centroidMovement);
+}
+
+} // namespace kmeans
+} // namespace mlpack
+
+#endif



More information about the mlpack-svn mailing list