[mlpack-svn] r11793 - mlpack/trunk/src/mlpack/methods/kmeans
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Mar 8 00:20:59 EST 2012
Author: jcline3
Date: 2012-03-08 00:20:58 -0500 (Thu, 08 Mar 2012)
New Revision: 11793
Modified:
mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
Log:
Something's wrong with this
Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp 2012-03-08 04:56:41 UTC (rev 11792)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp 2012-03-08 05:20:58 UTC (rev 11793)
@@ -80,6 +80,9 @@
assert(data(j,i) >= 0 && data(j,i) <= 1);
}
+ if (assignments.n_rows != data.n_cols)
+ assignments.resize(data.n_cols);
+
// Centroids of each cluster. Each column corresponds to a centroid.
MatType centroids(data.n_rows, actualClusters);
@@ -101,14 +104,14 @@
centroids.randu();
bound::HRectBound<2>& bound = tree.Bound();
size_t dim = bound.Dim();
- for(size_t i = 0; i < dim; ++i) {
+ for (size_t i = 0; i < dim; ++i) {
double min = bound[i].Lo();
double max = bound[i].Hi();
- for(size_t j = 0; j < centroids.n_cols; ++j)
+ for (size_t j = 0; j < centroids.n_cols; ++j)
{
- if(centroids(i,j) < min)
+ if (centroids(i,j) < min)
centroids(i,j) = min;
- else if(centroids(i,j) > max)
+ else if (centroids(i,j) > max)
centroids(i,j) = max;
}
}
@@ -119,6 +122,7 @@
// the previous iteration.
MatType newCentroids(centroids.n_rows, centroids.n_cols);
+ std::cout << data.n_cols << std::endl;
size_t iteration = 0;
size_t changedAssignments = 0;
do
@@ -127,6 +131,7 @@
++iteration;
changedAssignments = 0;
newCentroids.zeros();
+ counts.zeros();
// Create a stack for traversing the mrkd-tree
std::stack<typename tree::BinarySpaceTree<typename bound::HRectBound<2>,
@@ -147,6 +152,7 @@
// the centroids to every point the node contains.
if (node->IsLeaf())
{
+ //std::cout << "Leaf\t";
for (size_t i = mrkd.begin; i < mrkd.count + mrkd.begin; ++i)
{
// Initialize minDistance to be nonzero.
@@ -164,29 +170,40 @@
}
}
- ++counts[minIndex];
newCentroids.col(minIndex) += data.col(i);
+ ++counts(minIndex);
+ //std::cout << counts(minIndex) << "\t";
+ if (assignments(i) != minIndex)
+ {
+ ++changedAssignments;
+ // TODO: this if should be removed
+ //if(counts(assignments(i)))
+ //--counts(assignments(i));
+ assignments(i) = minIndex;
+ }
}
+ //std::cout << std::endl;
}
// If this node is not a leaf, then we continue trying to find dominant
// centroids
else
{
+ //std::cout << "Parent\t";
bound::HRectBound<2>& bound = node->Bound();
bool noDomination = false;
// There was no centroid inside this hyperrectangle.
// We must determine if an external centroid dominates it.
- for(size_t i = 0; i < centroids.n_cols; ++i)
+ for (size_t i = 0; i < centroids.n_cols; ++i)
{
noDomination = false;
- for(size_t j = 0; j < centroids.n_cols; ++j)
+ for (size_t j = 0; j < centroids.n_cols; ++j)
{
- if(j == i)
+ if (j == i)
continue;
- for(size_t k = 0; k < p.n_rows; ++k)
+ for (size_t k = 0; k < p.n_rows; ++k)
{
p(k) = (centroids(k,j) > centroids(k,i)) ?
bound[k].Hi() : bound[k].Lo();
@@ -197,7 +214,7 @@
double distancej = metric::SquaredEuclideanDistance::Evaluate(
p.col(0), centroids.col(j));
- if(distancei >= distancej)
+ if (distancei >= distancej)
{
noDomination = true;
break;
@@ -206,19 +223,36 @@
}
// We identified a centroid that dominates this hyperrectangle.
- if(!noDomination)
+ if (!noDomination)
{
+ //std::cout << "Domination\t";
+ newCentroids.col(minIndex) += mrkd.centerOfMass;
+ counts(i) += mrkd.count;
+ //std::cout << counts(i) << std::endl;
+ // Update all assignments for this node
+ const size_t begin = node->Begin();
+ const size_t end = node->End();
+ for (size_t j = begin; j < end; ++j)
+ {
+ if (assignments(j) != i)
+ {
+ ++changedAssignments;
+ //if(counts(assignments(j)))
+ //--counts(assignments(j));
+ //++counts(i);
+ assignments(j) = i;
+ }
+ }
mrkd.dominatingCentroid = i;
- counts[i] += mrkd.count;
- newCentroids.col(minIndex) += mrkd.centerOfMass;
break;
}
}
// If we did not find a dominating centroid then we fall through to the
// default case, where we add the children of this node to the stack.
- if(noDomination)
+ if (noDomination)
{
+ //std::cout << "No Domination" << std::endl;
stack.push(node->Left());
stack.push(node->Right());
}
@@ -226,32 +260,25 @@
}
- for(size_t i = 0; i < centroids.n_cols; ++i)
+ for (size_t i = 0; i < centroids.n_cols; ++i)
{
- if(counts(i))
+ if (counts(i)) {
// Divide by the number of points assigned to this centroid so that we
- // have the actual center of mass.
- newCentroids.col(i) /= counts(i);
-
- // TODO: switch to faster way of keeping track of changed assignments
- if(changedAssignments != 0)
- {
- for(size_t j = 0; j < centroids.n_rows; ++j)
- {
- if(fabs(newCentroids(j,i) - centroids(j,i)) > 1e-5)
- {
- ++changedAssignments;
- break;
- }
- }
+ // have the actual center of mass and update centroids' positions.
+ centroids.col(i) = newCentroids.col(i) / counts(i);
}
}
+ size_t count = 0;
+ for(size_t k = 0; k < counts.n_rows; ++k)
+ {
+ std::cout << counts(k) << '\t';
+ count += counts(k);
+ }
+ std::cout << '\n' << count <<'\t'<< data.n_cols<< std::endl;
+ assert(count <= data.n_cols);
+ } while(0);
+ //} while (changedAssignments > 0 && iteration != maxIterations);
- // Update the centroids' positions.
- centroids = newCentroids;
- } while (changedAssignments > 0 && iteration != maxIterations);
-
- std::cout << centroids << '\n' << counts << std::endl;
}
/**
More information about the mlpack-svn mailing list