[mlpack-svn] r11685 - in mlpack/trunk/src/mlpack: core/tree methods/kmeans

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Mar 1 16:26:44 EST 2012


Author: jcline3
Date: 2012-03-01 16:26:44 -0500 (Thu, 01 Mar 2012)
New Revision: 11685

Modified:
   mlpack/trunk/src/mlpack/core/tree/mrkd_statistic.hpp
   mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp
   mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
   mlpack/trunk/src/mlpack/methods/kmeans/kmeans_main.cpp
Log:
A mostly (hopefully) working implementation of Pelleg and Moore's mrkd-tree k-means.

It does not yet modify the assignments vector, which holds the cluster labels.


Modified: mlpack/trunk/src/mlpack/core/tree/mrkd_statistic.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/mrkd_statistic.hpp	2012-03-01 16:58:44 UTC (rev 11684)
+++ mlpack/trunk/src/mlpack/core/tree/mrkd_statistic.hpp	2012-03-01 21:26:44 UTC (rev 11685)
@@ -43,14 +43,13 @@
       begin(begin),
       count(count)
     {
-      centerOfMass = dataset.row(begin);
+      centerOfMass = dataset.col(begin);
       for(size_t i = begin+1; i < begin+count; ++i)
-        centerOfMass += dataset.row(i);
-      centerOfMass /= count;
+        centerOfMass += dataset.col(i);
 
       sumOfSquaredNorms = 0.0;
       for(size_t i = begin; i < begin+count; ++i)
-        sumOfSquaredNorms += arma::norm(dataset.row(i), 2);
+        sumOfSquaredNorms += arma::norm(dataset.col(i), 2);
     }
 
     /**
@@ -78,9 +77,12 @@
     {
       sumOfSquaredNorms = leftStat.sumOfSquaredNorms + rightStat.sumOfSquaredNorms;
 
+      /*
       centerOfMass = ((leftStat.centerOfMass * leftStat.count) +
                       (rightStat.centerOfMass * rightStat.count)) /
                       (leftStat.count + rightStat.count);
+      */
+      centerOfMass = leftStat.centerOfMass + rightStat.centerOfMass;
     }
 
     //! The data points this object contains
@@ -96,9 +98,13 @@
 
     // Computed statistics
     //! The center of mass for this dataset
-    arma::rowvec centerOfMass;
+    arma::colvec centerOfMass;
     //! The sum of the squared Euclidian norms for this dataset
     double sumOfSquaredNorms;
+		
+		// There may be a better place to store this -- HRectBound?
+		//! The index of the dominating centroid of the associated hyperrectangle
+		size_t dominatingCentroid;
 };
 
 }; // namespace tree

Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp	2012-03-01 16:58:44 UTC (rev 11684)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp	2012-03-01 21:26:44 UTC (rev 11685)
@@ -102,6 +102,10 @@
   void Cluster(const MatType& data,
                const size_t clusters,
                arma::Col<size_t>& assignments) const;
+  template<typename MatType> // mrkd-tree based variant of Cluster(); defined in kmeans_impl.hpp
+  void FastCluster(MatType& data, // non-const reference, unlike Cluster(): handed to the tree constructor
+               const size_t clusters, // requested number of clusters (scaled by the overclustering factor)
+               arma::Col<size_t>& assignments) const; // NOTE(review): not written by the implementation in this revision
 
   /**
    * Return the overclustering factor.

Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp	2012-03-01 16:58:44 UTC (rev 11684)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp	2012-03-01 21:26:44 UTC (rev 11685)
@@ -7,8 +7,13 @@
  */
 #include "kmeans.hpp"
 
+#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/hrectbound.hpp>
+#include <mlpack/core/tree/mrkd_statistic.hpp>
 #include <mlpack/core/metrics/lmetric.hpp>
 
+#include <stack>
+
 namespace mlpack {
 namespace kmeans {
 
@@ -45,6 +50,210 @@
   }
 }
 
+template<typename DistanceMetric,
+         typename InitialPartitionPolicy,
+         typename EmptyClusterPolicy>
+template<typename MatType>
+void KMeans<
+    DistanceMetric,
+    InitialPartitionPolicy,
+    EmptyClusterPolicy>::
+FastCluster(MatType& data,
+            const size_t clusters,
+            arma::Col<size_t>& assignments) const
+{
+  size_t actualClusters = size_t(overclusteringFactor * clusters);
+  if (actualClusters > data.n_cols)
+  {
+    Log::Warn << "KMeans::Cluster(): overclustering factor is too large.  No "
+        << "overclustering will be done." << std::endl;
+    actualClusters = clusters;
+  }
+
+  // TODO: remove.  Dead code: the if(0) guard below means this never runs.
+  // It would rescale the data to [0,1] and assert every entry is in range.
+  if(0){
+    arma::rowvec min = arma::min(data, 0);
+    data = (data - arma::ones<arma::colvec>(data.n_rows) * min) / (arma::ones<arma::colvec>(data.n_rows) * (arma::max(data,0) - min));
+    for(size_t i = 0; i < data.n_cols; ++i)
+      for(size_t j = 0; j < data.n_rows; ++j)
+        assert(data(j,i) >= 0 && data(j,i) <= 1);
+  }
+
+  // Centroids of each cluster.  Each column corresponds to a centroid.
+  MatType centroids(data.n_rows, actualClusters);
+
+  // Counts of points in each cluster.
+  arma::Col<size_t> counts(actualClusters);
+  counts.zeros(); // NOTE(review): zeroed only here, not at the start of each iteration
+
+  // Build the mrkd-tree on this dataset (the argument 1 is presumably the leaf size).
+  tree::BinarySpaceTree<typename bound::HRectBound<2>, tree::MRKDStatistic> tree(data, 1);
+  // A pointer for traversing the mrkd-tree.
+  tree::BinarySpaceTree<typename bound::HRectBound<2>, tree::MRKDStatistic>* node;
+
+  // We use this to store the furthest point in a hyperrectangle from a given
+  // vector.
+  arma::colvec p(data.n_rows);
+
+  // Draw centroids with randu(), then clamp them into the root hyperrectangle.
+  {
+    centroids.randu();
+    bound::HRectBound<2>& bound = tree.Bound();
+    size_t dim = bound.Dim();
+    for(size_t i = 0; i < dim; ++i) {
+      double min = bound[i].Lo();
+      double max = bound[i].Hi();
+      for(size_t j = 0; j < centroids.n_cols; ++j)
+      {
+        if(centroids(i,j) < min)
+          centroids(i,j) = min;
+        else if(centroids(i,j) > max)
+          centroids(i,j) = max;
+      }
+    }
+  }
+
+  // Instead of retraversing the tree after an iteration, we will update centroid
+  // positions in this matrix, which also prevents clobbering our centroids from
+  // the previous iteration.
+  MatType newCentroids(centroids.n_rows, centroids.n_cols);
+
+  size_t iteration = 0;
+  size_t changedAssignments = 0;
+  do 
+  {
+    // Keep track of what iteration we are on.
+    ++iteration;
+    changedAssignments = 0;
+    newCentroids.zeros();
+
+    // Create a stack for traversing the mrkd-tree.
+    std::stack<typename tree::BinarySpaceTree<typename bound::HRectBound<2>, 
+                                              tree::MRKDStatistic>* > stack;
+    // Add the root node of the tree to the stack.
+    stack.push(&tree);
+
+    while (!stack.empty())
+    {
+      node = stack.top();
+      stack.pop();
+
+      tree::MRKDStatistic& mrkd = node->Stat();
+
+      size_t minIndex = 0; // NOTE(review): reset once per node, not per point in the leaf loop
+
+      // If this node is a leaf, then we calculate the distance from
+      // the centroids to every point the node contains.
+      if (node->IsLeaf())
+      {
+        for (size_t i = mrkd.begin; i < mrkd.count + mrkd.begin; ++i)
+        {
+          // Seed minDistance with the distance from this point to centroid 0.
+          double minDistance = metric::SquaredEuclideanDistance::Evaluate(
+              data.col(i), centroids.col(0));
+          // Find the minimal distance centroid for this point.
+          for (size_t j = 1; j < centroids.n_cols; ++j)
+          {
+            double distance = metric::SquaredEuclideanDistance::Evaluate(
+                data.col(i), centroids.col(j));
+            if ( minDistance > distance )
+            {
+              minIndex = j;
+              minDistance = distance;
+            }
+          }
+
+          ++counts[minIndex];
+          newCentroids.col(minIndex) += data.col(i);
+        }
+      }
+      // If this node is not a leaf, then we continue trying to find dominant
+      // centroids
+      else
+      {
+        bound::HRectBound<2>& bound = node->Bound();
+
+        bool noDomination = false;
+
+        // Look for a centroid i that dominates this hyperrectangle: one that is
+        // strictly closer than every other centroid j to the test corner p.
+        for(size_t i = 0; i < centroids.n_cols; ++i) 
+        {
+          noDomination = false;
+          for(size_t j = 0; j < centroids.n_cols; ++j)
+          {
+            if(j == i)
+              continue;
+
+            for(size_t k = 0; k < p.n_rows; ++k) // p(k): Hi where centroid j exceeds centroid i, else Lo
+            {
+              p(k) = (centroids(k,j) > centroids(k,i)) ?
+                bound[k].Hi() : bound[k].Lo();
+            }
+
+            double distancei = metric::SquaredEuclideanDistance::Evaluate(
+                p.col(0), centroids.col(i));
+            double distancej = metric::SquaredEuclideanDistance::Evaluate(
+                p.col(0), centroids.col(j));
+
+            if(distancei >= distancej)
+            {
+              noDomination = true;
+              break;
+            }
+
+          }
+
+          // We identified a centroid that dominates this hyperrectangle.
+          if(!noDomination)
+          {
+            mrkd.dominatingCentroid = i;
+            counts[i] += mrkd.count;
+            newCentroids.col(minIndex) += mrkd.centerOfMass; // NOTE(review): minIndex is always 0 on this path; col(i) looks intended
+            break;
+          }
+        }
+
+        // If we did not find a dominating centroid then we fall through to the
+        // default case, where we add the children of this node to the stack.
+        if(noDomination)
+        {
+          stack.push(node->Left());
+          stack.push(node->Right());
+        }
+      }
+
+    }
+
+    for(size_t i = 0; i < centroids.n_cols; ++i)
+    {
+      if(counts(i))
+        // Divide by the number of points assigned to this centroid so that we
+        // have the actual center of mass.
+        newCentroids.col(i) /= counts(i);
+
+      // TODO: switch to faster way of keeping track of changed assignments
+      if(changedAssignments != 0) // NOTE(review): always false (zeroed at loop top), so the do-while exits after one pass
+      {
+        for(size_t j = 0; j < centroids.n_rows; ++j)
+        {
+          if(fabs(newCentroids(j,i) - centroids(j,i)) > 1e-5)
+          {
+            ++changedAssignments;
+            break;
+          }
+        }
+      }
+    }
+
+    // Update the centroids' positions.
+    centroids = newCentroids;
+  } while (changedAssignments > 0 && iteration != maxIterations);
+
+  std::cout << centroids << '\n' << counts << std::endl; // Debug print; 'assignments' is never written in this function.
+}
+
 /**
  * Perform K-Means clustering on the data, returning a list of cluster
  * assignments.

Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans_main.cpp	2012-03-01 16:58:44 UTC (rev 11684)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans_main.cpp	2012-03-01 21:26:44 UTC (rev 11685)
@@ -37,6 +37,7 @@
 PARAM_INT("max_iterations", "Maximum number of iterations before K-Means "
     "terminates.", "m", 1000);
 PARAM_INT("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
+PARAM_FLAG("fast_kmeans", "Use the experimental fast k-means algorithm by Pelleg and Moore", "f")
 
 int main(int argc, char** argv)
 {
@@ -90,13 +91,19 @@
     KMeans<metric::SquaredEuclideanDistance, RandomPartition,
         AllowEmptyClusters> k(maxIterations, overclustering);
 
-    k.Cluster(dataset, clusters, assignments);
+		if(CLI::HasParam("fast_kmeans"))
+			k.FastCluster(dataset, clusters, assignments);
+		else
+			k.Cluster(dataset, clusters, assignments);
   }
   else
   {
     KMeans<> k(maxIterations, overclustering);
 
-    k.Cluster(dataset, clusters, assignments);
+		if(CLI::HasParam("fast_kmeans"))
+			k.FastCluster(dataset, clusters, assignments);
+		else
+			k.Cluster(dataset, clusters, assignments);
   }
 
   // Now figure out what to do with our results.




More information about the mlpack-svn mailing list