[mlpack-svn] r11793 - mlpack/trunk/src/mlpack/methods/kmeans

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Mar 8 00:20:59 EST 2012


Author: jcline3
Date: 2012-03-08 00:20:58 -0500 (Thu, 08 Mar 2012)
New Revision: 11793

Modified:
   mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
Log:
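Something's wrong with this (work-in-progress debugging commit): resizes
assignments to data.n_cols, zeroes counts at the start of each iteration,
records per-point assignments in both the leaf and dominated-node cases, adds
std::cout debug output of the counts, and temporarily replaces the convergence
loop condition with while(0) so that only a single iteration runs.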


Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp	2012-03-08 04:56:41 UTC (rev 11792)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp	2012-03-08 05:20:58 UTC (rev 11793)
@@ -80,6 +80,9 @@
         assert(data(j,i) >= 0 && data(j,i) <= 1);
   }
 
+  if (assignments.n_rows != data.n_cols)
+    assignments.resize(data.n_cols);
+
   // Centroids of each cluster.  Each column corresponds to a centroid.
   MatType centroids(data.n_rows, actualClusters);
 
@@ -101,14 +104,14 @@
     centroids.randu();
     bound::HRectBound<2>& bound = tree.Bound();
     size_t dim = bound.Dim();
-    for(size_t i = 0; i < dim; ++i) {
+    for (size_t i = 0; i < dim; ++i) {
       double min = bound[i].Lo();
       double max = bound[i].Hi();
-      for(size_t j = 0; j < centroids.n_cols; ++j)
+      for (size_t j = 0; j < centroids.n_cols; ++j)
       {
-        if(centroids(i,j) < min)
+        if (centroids(i,j) < min)
           centroids(i,j) = min;
-        else if(centroids(i,j) > max)
+        else if (centroids(i,j) > max)
           centroids(i,j) = max;
       }
     }
@@ -119,6 +122,7 @@
   // the previous iteration.
   MatType newCentroids(centroids.n_rows, centroids.n_cols);
 
+  std::cout << data.n_cols << std::endl;
   size_t iteration = 0;
   size_t changedAssignments = 0;
   do 
@@ -127,6 +131,7 @@
     ++iteration;
     changedAssignments = 0;
     newCentroids.zeros();
+    counts.zeros();
 
     // Create a stack for traversing the mrkd-tree
     std::stack<typename tree::BinarySpaceTree<typename bound::HRectBound<2>, 
@@ -147,6 +152,7 @@
       // the centroids to every point the node contains.
       if (node->IsLeaf())
       {
+        //std::cout << "Leaf\t";
         for (size_t i = mrkd.begin; i < mrkd.count + mrkd.begin; ++i)
         {
           // Initialize minDistance to be nonzero.
@@ -164,29 +170,40 @@
             }
           }
 
-          ++counts[minIndex];
           newCentroids.col(minIndex) += data.col(i);
+          ++counts(minIndex);
+          //std::cout << counts(minIndex) << "\t";
+          if (assignments(i) != minIndex)
+          {
+            ++changedAssignments;
+            // TODO: this if should be removed
+            //if(counts(assignments(i)))
+              //--counts(assignments(i));
+            assignments(i) = minIndex;
+          }
         }
+        //std::cout << std::endl;
       }
       // If this node is not a leaf, then we continue trying to find dominant
       // centroids
       else
       {
+        //std::cout << "Parent\t";
         bound::HRectBound<2>& bound = node->Bound();
 
         bool noDomination = false;
 
         // There was no centroid inside this hyperrectangle.
         // We must determine if an external centroid dominates it.
-        for(size_t i = 0; i < centroids.n_cols; ++i) 
+        for (size_t i = 0; i < centroids.n_cols; ++i) 
         {
           noDomination = false;
-          for(size_t j = 0; j < centroids.n_cols; ++j)
+          for (size_t j = 0; j < centroids.n_cols; ++j)
           {
-            if(j == i)
+            if (j == i)
               continue;
 
-            for(size_t k = 0; k < p.n_rows; ++k)
+            for (size_t k = 0; k < p.n_rows; ++k)
             {
               p(k) = (centroids(k,j) > centroids(k,i)) ?
                 bound[k].Hi() : bound[k].Lo();
@@ -197,7 +214,7 @@
             double distancej = metric::SquaredEuclideanDistance::Evaluate(
                 p.col(0), centroids.col(j));
 
-            if(distancei >= distancej)
+            if (distancei >= distancej)
             {
               noDomination = true;
               break;
@@ -206,19 +223,36 @@
           }
 
           // We identified a centroid that dominates this hyperrectangle.
-          if(!noDomination)
+          if (!noDomination)
           {
+            //std::cout << "Domination\t";
+            newCentroids.col(minIndex) += mrkd.centerOfMass;
+            counts(i) += mrkd.count;
+            //std::cout << counts(i) << std::endl;
+            // Update all assignments for this node
+            const size_t begin = node->Begin();
+            const size_t end = node->End();
+            for (size_t j = begin; j < end; ++j)
+            {
+              if (assignments(j) != i)
+              {
+                ++changedAssignments;
+                //if(counts(assignments(j)))
+                  //--counts(assignments(j));
+                //++counts(i);
+                assignments(j) = i;
+              }
+            }
             mrkd.dominatingCentroid = i;
-            counts[i] += mrkd.count;
-            newCentroids.col(minIndex) += mrkd.centerOfMass;
             break;
           }
         }
 
         // If we did not find a dominating centroid then we fall through to the
         // default case, where we add the children of this node to the stack.
-        if(noDomination)
+        if (noDomination)
         {
+          //std::cout << "No Domination" << std::endl;
           stack.push(node->Left());
           stack.push(node->Right());
         }
@@ -226,32 +260,25 @@
 
     }
 
-    for(size_t i = 0; i < centroids.n_cols; ++i)
+    for (size_t i = 0; i < centroids.n_cols; ++i)
     {
-      if(counts(i))
+      if (counts(i)) {
         // Divide by the number of points assigned to this centroid so that we
-        // have the actual center of mass.
-        newCentroids.col(i) /= counts(i);
-
-      // TODO: switch to faster way of keeping track of changed assignments
-      if(changedAssignments != 0)
-      {
-        for(size_t j = 0; j < centroids.n_rows; ++j)
-        {
-          if(fabs(newCentroids(j,i) - centroids(j,i)) > 1e-5)
-          {
-            ++changedAssignments;
-            break;
-          }
-        }
+        // have the actual center of mass and update centroids' positions.
+        centroids.col(i) = newCentroids.col(i) / counts(i);
       }
     }
+    size_t count = 0;
+    for(size_t k = 0; k < counts.n_rows; ++k)
+    {
+      std::cout << counts(k) << '\t';
+      count += counts(k);
+    }
+    std::cout << '\n' << count <<'\t'<< data.n_cols<< std::endl;
+    assert(count <= data.n_cols);
+  } while(0);
+  //} while (changedAssignments > 0 && iteration != maxIterations);
 
-    // Update the centroids' positions.
-    centroids = newCentroids;
-  } while (changedAssignments > 0 && iteration != maxIterations);
-
-  std::cout << centroids << '\n' << counts << std::endl;
 }
 
 /**




More information about the mlpack-svn mailing list