[mlpack-svn] r10411 - mlpack/trunk/src/mlpack/methods/gmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Nov 25 20:41:07 EST 2011
Author: rcurtin
Date: 2011-11-25 20:41:07 -0500 (Fri, 25 Nov 2011)
New Revision: 10411
Modified:
mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp
Log:
Adapt to new KMeans API.
Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp 2011-11-26 01:13:17 UTC (rev 10410)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp 2011-11-26 01:41:07 UTC (rev 10411)
@@ -49,12 +49,52 @@
best_l = -DBL_MAX;
+ KMeans k; // Default KMeans parameters, for now.
+
// We will perform ten trials, and then save the trial with the best result
// as our trained model.
for (size_t iter = 0; iter < 10; iter++)
{
- KMeans(data, gaussians, means_trial, covariances_trial, weights_trial);
+ arma::Col<size_t> assignments;
+ k.Cluster(data, gaussians, assignments);
+
+ // Clear the weights, covariances, and means, before we recalculate them.
+ weights_trial.zeros();
+ for (size_t i = 0; i < gaussians; i++)
+ {
+ means_trial[i].zeros();
+ covariances_trial[i].zeros();
+ }
+
+ // From the assignments, generate our means, covariances, and weights.
+ for (size_t i = 0; i < data.n_cols; i++)
+ {
+ size_t cluster = assignments[i];
+
+ // Add this to the relevant mean.
+ means_trial[cluster] += data.col(i);
+
+ // And add it to the relative covariance matrix.
+ covariances_trial[cluster] += data.col(i) * trans(data.col(i));
+
+ // Now add one to the weights (we will normalize).
+ weights_trial[cluster]++;
+ }
+
+ // Now normalize the means, covariances, and weights.
+ for (size_t i = 0; i < gaussians; i++)
+ {
+ // Normalize mean.
+ means_trial[i] /= weights_trial[i];
+
+ // Normalize covariance (use unbiased estimator).
+ covariances_trial[i] /= (weights_trial[i] - 1);
+ }
+
+ // Finally, normalize weights.
+ weights_trial /= accu(weights_trial);
+
l = Loglikelihood(data, means_trial, covariances_trial, weights_trial);
Log::Info << "K-means log-likelihood: " << l << std::endl;
More information about the mlpack-svn mailing list