[mlpack-svn] r11685 - in mlpack/trunk/src/mlpack: core/tree methods/kmeans
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Mar 1 16:26:44 EST 2012
Author: jcline3
Date: 2012-03-01 16:26:44 -0500 (Thu, 01 Mar 2012)
New Revision: 11685
Modified:
mlpack/trunk/src/mlpack/core/tree/mrkd_statistic.hpp
mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp
mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
mlpack/trunk/src/mlpack/methods/kmeans/kmeans_main.cpp
Log:
Mostly (hopefully) working implementation of Pelleg and Moore's mrkd-tree k-means.
It does not yet actually modify the assignments vector, which holds the cluster labels.
Modified: mlpack/trunk/src/mlpack/core/tree/mrkd_statistic.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/mrkd_statistic.hpp 2012-03-01 16:58:44 UTC (rev 11684)
+++ mlpack/trunk/src/mlpack/core/tree/mrkd_statistic.hpp 2012-03-01 21:26:44 UTC (rev 11685)
@@ -43,14 +43,13 @@
begin(begin),
count(count)
{
- centerOfMass = dataset.row(begin);
+ centerOfMass = dataset.col(begin);
for(size_t i = begin+1; i < begin+count; ++i)
- centerOfMass += dataset.row(i);
- centerOfMass /= count;
+ centerOfMass += dataset.col(i);
sumOfSquaredNorms = 0.0;
for(size_t i = begin; i < begin+count; ++i)
- sumOfSquaredNorms += arma::norm(dataset.row(i), 2);
+ sumOfSquaredNorms += arma::norm(dataset.col(i), 2);
}
/**
@@ -78,9 +77,12 @@
{
sumOfSquaredNorms = leftStat.sumOfSquaredNorms + rightStat.sumOfSquaredNorms;
+ /*
centerOfMass = ((leftStat.centerOfMass * leftStat.count) +
(rightStat.centerOfMass * rightStat.count)) /
(leftStat.count + rightStat.count);
+ */
+ centerOfMass = leftStat.centerOfMass + rightStat.centerOfMass;
}
//! The data points this object contains
@@ -96,9 +98,13 @@
// Computed statistics
//! The center of mass for this dataset
- arma::rowvec centerOfMass;
+ arma::colvec centerOfMass;
//! The sum of the squared Euclidian norms for this dataset
double sumOfSquaredNorms;
+
+ // There may be a better place to store this -- HRectBound?
+ //! The index of the dominating centroid of the associated hyperrectangle
+ size_t dominatingCentroid;
};
}; // namespace tree
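A note on the statistic change above: centerOfMass now holds the unnormalized sum of the node's points rather than their mean (the divide-by-count was dropped), which is why CombineStats can form a parent's value by simply adding the two children's vectors. A minimal standalone sketch of the same bookkeeping in plain Armadillo (the helper name is illustrative, not part of the patch):

    #include <armadillo>
    #include <cstddef>

    // Column-wise sum of the points in [begin, begin + count), mirroring what
    // MRKDStatistic::centerOfMass now stores.
    arma::colvec SumOfPoints(const arma::mat& data, size_t begin, size_t count)
    {
      arma::colvec sum = data.col(begin);
      for (size_t i = begin + 1; i < begin + count; ++i)
        sum += data.col(i);
      return sum;
    }

    // The true center of mass of a node (or of two merged children) is recovered
    // on demand by dividing by the total point count, for example:
    //   arma::colvec mean = (leftSum + rightSum) / double(leftCount + rightCount);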
Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp 2012-03-01 16:58:44 UTC (rev 11684)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp 2012-03-01 21:26:44 UTC (rev 11685)
@@ -102,6 +102,10 @@
void Cluster(const MatType& data,
const size_t clusters,
arma::Col<size_t>& assignments) const;
+ template<typename MatType>
+ void FastCluster(MatType& data,
+ const size_t clusters,
+ arma::Col<size_t>& assignments) const;
/**
* Return the overclustering factor.
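The new FastCluster() mirrors Cluster() except that the data matrix is passed non-const, because building the mrkd-tree reorders the dataset's columns. A minimal call sketch, assuming the default template parameters and the two-argument constructor used in kmeans_main.cpp below (data sizes are illustrative, and in this revision FastCluster does not yet fill in assignments):

    #include <mlpack/core.hpp>
    #include <mlpack/methods/kmeans/kmeans.hpp>

    using namespace mlpack::kmeans;

    int main()
    {
      // Illustrative data: 3-dimensional points, one point per column.
      arma::mat data = arma::randu<arma::mat>(3, 1000);
      arma::Col<size_t> assignments;

      // maxIterations = 1000, overclusteringFactor = 1.0.
      KMeans<> k(1000, 1.0);

      // data is non-const: the tree build inside FastCluster may reorder columns.
      k.FastCluster(data, 5, assignments);

      return 0;
    }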
Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp 2012-03-01 16:58:44 UTC (rev 11684)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp 2012-03-01 21:26:44 UTC (rev 11685)
@@ -7,8 +7,13 @@
*/
#include "kmeans.hpp"
+#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/hrectbound.hpp>
+#include <mlpack/core/tree/mrkd_statistic.hpp>
#include <mlpack/core/metrics/lmetric.hpp>
+#include <stack>
+
namespace mlpack {
namespace kmeans {
@@ -45,6 +50,210 @@
}
}
+template<typename DistanceMetric,
+ typename InitialPartitionPolicy,
+ typename EmptyClusterPolicy>
+template<typename MatType>
+void KMeans<
+ DistanceMetric,
+ InitialPartitionPolicy,
+ EmptyClusterPolicy>::
+FastCluster(MatType& data,
+ const size_t clusters,
+ arma::Col<size_t>& assignments) const
+{
+ size_t actualClusters = size_t(overclusteringFactor * clusters);
+ if (actualClusters > data.n_cols)
+ {
+ Log::Warn << "KMeans::FastCluster(): overclustering factor is too large. No "
+ << "overclustering will be done." << std::endl;
+ actualClusters = clusters;
+ }
+
+ // TODO: remove
+ // Scale the data to [0,1]
+ if(0){
+ arma::rowvec min = arma::min(data, 0);
+ data = (data - arma::ones<arma::colvec>(data.n_rows) * min) / (arma::ones<arma::colvec>(data.n_rows) * (arma::max(data,0) - min));
+ for(size_t i = 0; i < data.n_cols; ++i)
+ for(size_t j = 0; j < data.n_rows; ++j)
+ assert(data(j,i) >= 0 && data(j,i) <= 1);
+ }
+
+ // Centroids of each cluster. Each column corresponds to a centroid.
+ MatType centroids(data.n_rows, actualClusters);
+
+ // Counts of points in each cluster.
+ arma::Col<size_t> counts(actualClusters);
+ counts.zeros();
+
+ // Build the mrkd-tree on this dataset
+ tree::BinarySpaceTree<typename bound::HRectBound<2>, tree::MRKDStatistic> tree(data, 1);
+ // A pointer for traversing the mrkd-tree
+ tree::BinarySpaceTree<typename bound::HRectBound<2>, tree::MRKDStatistic>* node;
+
+ // We use this to store the furthest point in a hyperrectangle from a given
+ // vector.
+ arma::colvec p(data.n_rows);
+
+ // Make random centroids and fit them into the root hyperrectangle.
+ {
+ centroids.randu();
+ bound::HRectBound<2>& bound = tree.Bound();
+ size_t dim = bound.Dim();
+ for(size_t i = 0; i < dim; ++i) {
+ double min = bound[i].Lo();
+ double max = bound[i].Hi();
+ for(size_t j = 0; j < centroids.n_cols; ++j)
+ {
+ if(centroids(i,j) < min)
+ centroids(i,j) = min;
+ else if(centroids(i,j) > max)
+ centroids(i,j) = max;
+ }
+ }
+ }
+
+ // Instead of retraversing the tree after an iteration, we will update centroid
+ // positions in this matrix, which also prevents clobbering our centroids from
+ // the previous iteration.
+ MatType newCentroids(centroids.n_rows, centroids.n_cols);
+
+ size_t iteration = 0;
+ size_t changedAssignments = 0;
+ do
+ {
+ // Keep track of what iteration we are on.
+ ++iteration;
+ changedAssignments = 0;
+ newCentroids.zeros();
+ counts.zeros();
+
+ // Create a stack for traversing the mrkd-tree
+ std::stack<typename tree::BinarySpaceTree<typename bound::HRectBound<2>,
+ tree::MRKDStatistic>* > stack;
+ // Add the root node of the tree to the stack
+ stack.push(&tree);
+
+ while (!stack.empty())
+ {
+ node = stack.top();
+ stack.pop();
+
+ tree::MRKDStatistic& mrkd = node->Stat();
+
+ // If this node is a leaf, then we calculate the distance from
+ // the centroids to every point the node contains.
+ if (node->IsLeaf())
+ {
+ for (size_t i = mrkd.begin; i < mrkd.count + mrkd.begin; ++i)
+ {
+ // Track the closest centroid seen so far for this point; start with
+ // centroid 0 and its distance.
+ size_t minIndex = 0;
+ double minDistance = metric::SquaredEuclideanDistance::Evaluate(
+ data.col(i), centroids.col(0));
+ // Find the minimal distance centroid for this point.
+ for (size_t j = 1; j < centroids.n_cols; ++j)
+ {
+ double distance = metric::SquaredEuclideanDistance::Evaluate(
+ data.col(i), centroids.col(j));
+ if ( minDistance > distance )
+ {
+ minIndex = j;
+ minDistance = distance;
+ }
+ }
+
+ ++counts[minIndex];
+ newCentroids.col(minIndex) += data.col(i);
+ }
+ }
+ // If this node is not a leaf, then we continue trying to find dominant
+ // centroids
+ else
+ {
+ bound::HRectBound<2>& bound = node->Bound();
+
+ bool noDomination = false;
+
+ // There was no centroid inside this hyperrectangle.
+ // We must determine if an external centroid dominates it.
+ for(size_t i = 0; i < centroids.n_cols; ++i)
+ {
+ noDomination = false;
+ for(size_t j = 0; j < centroids.n_cols; ++j)
+ {
+ if(j == i)
+ continue;
+
+ for(size_t k = 0; k < p.n_rows; ++k)
+ {
+ p(k) = (centroids(k,j) > centroids(k,i)) ?
+ bound[k].Hi() : bound[k].Lo();
+ }
+
+ double distancei = metric::SquaredEuclideanDistance::Evaluate(
+ p.col(0), centroids.col(i));
+ double distancej = metric::SquaredEuclideanDistance::Evaluate(
+ p.col(0), centroids.col(j));
+
+ if(distancei >= distancej)
+ {
+ noDomination = true;
+ break;
+ }
+
+ }
+
+ // We identified a centroid that dominates this hyperrectangle.
+ if(!noDomination)
+ {
+ mrkd.dominatingCentroid = i;
+ counts[i] += mrkd.count;
+ newCentroids.col(i) += mrkd.centerOfMass;
+ break;
+ }
+ }
+
+ // If we did not find a dominating centroid then we fall through to the
+ // default case, where we add the children of this node to the stack.
+ if(noDomination)
+ {
+ stack.push(node->Left());
+ stack.push(node->Right());
+ }
+ }
+
+ }
+
+ for(size_t i = 0; i < centroids.n_cols; ++i)
+ {
+ if(counts(i))
+ // Divide by the number of points assigned to this centroid so that we
+ // have the actual center of mass.
+ newCentroids.col(i) /= counts(i);
+
+ // TODO: switch to faster way of keeping track of changed assignments
+ if(changedAssignments == 0)
+ {
+ for(size_t j = 0; j < centroids.n_rows; ++j)
+ {
+ if(fabs(newCentroids(j,i) - centroids(j,i)) > 1e-5)
+ {
+ ++changedAssignments;
+ break;
+ }
+ }
+ }
+ }
+
+ // Update the centroids' positions.
+ centroids = newCentroids;
+ } while (changedAssignments > 0 && iteration != maxIterations);
+
+ std::cout << centroids << '\n' << counts << std::endl;
+}
+
/**
* Perform K-Means clustering on the data, returning a list of cluster
* assignments.
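The heart of the non-leaf branch above is Pelleg and Moore's domination (ownership) test: centroid i owns a node's hyperrectangle if, for every rival centroid j, the corner of the rectangle most favorable to j (per dimension, the upper bound when j lies above i, the lower bound otherwise) is still strictly closer to i than to j; in that case every point in the node is assigned to i at once using the stored point sum. A self-contained sketch of just that test in plain Armadillo (the function name and the lo/hi corner vectors are illustrative stand-ins for node->Bound()):

    #include <armadillo>
    #include <cstddef>

    // Returns true if centroid i dominates the axis-aligned box [lo, hi], i.e.
    // it is strictly closer than every rival j to the corner of the box most
    // favorable to j. Mirrors the loop over j in FastCluster above.
    bool Dominates(const arma::mat& centroids, size_t i,
                   const arma::colvec& lo, const arma::colvec& hi)
    {
      arma::colvec corner(lo.n_rows);
      for (size_t j = 0; j < centroids.n_cols; ++j)
      {
        if (j == i)
          continue;

        // Per dimension, pick the box corner on the rival j's side of centroid i.
        for (size_t k = 0; k < corner.n_rows; ++k)
          corner(k) = (centroids(k, j) > centroids(k, i)) ? hi(k) : lo(k);

        const double di = arma::accu(arma::square(corner - centroids.col(i)));
        const double dj = arma::accu(arma::square(corner - centroids.col(j)));
        if (di >= dj)  // some rival is at least as close, so i does not dominate
          return false;
      }
      return true;
    }

If no centroid dominates the node, FastCluster falls through to the default case and pushes the node's children onto the stack.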
Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans_main.cpp 2012-03-01 16:58:44 UTC (rev 11684)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans_main.cpp 2012-03-01 21:26:44 UTC (rev 11685)
@@ -37,6 +37,7 @@
PARAM_INT("max_iterations", "Maximum number of iterations before K-Means "
"terminates.", "m", 1000);
PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
+PARAM_FLAG("fast_kmeans", "Use the experimental fast k-means algorithm by Pelleg and Moore.", "f");
int main(int argc, char** argv)
{
@@ -90,13 +91,19 @@
KMeans<metric::SquaredEuclideanDistance, RandomPartition,
AllowEmptyClusters> k(maxIterations, overclustering);
- k.Cluster(dataset, clusters, assignments);
+ if(CLI::HasParam("fast_kmeans"))
+ k.FastCluster(dataset, clusters, assignments);
+ else
+ k.Cluster(dataset, clusters, assignments);
}
else
{
KMeans<> k(maxIterations, overclustering);
- k.Cluster(dataset, clusters, assignments);
+ if(CLI::HasParam("fast_kmeans"))
+ k.FastCluster(dataset, clusters, assignments);
+ else
+ k.Cluster(dataset, clusters, assignments);
}
// Now figure out what to do with our results.
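Once built, the experimental path is selected with the flag added above, for example (only --fast_kmeans/-f and --max_iterations/-m come from this patch; the input and cluster-count option names are illustrative):

    $ kmeans --inputFile=dataset.csv --clusters=5 --fast_kmeans --max_iterations=100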