[mlpack-svn] r10411 - mlpack/trunk/src/mlpack/methods/gmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Nov 25 20:41:07 EST 2011


Author: rcurtin
Date: 2011-11-25 20:41:07 -0500 (Fri, 25 Nov 2011)
New Revision: 10411

Modified:
   mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp
Log:
Adapt to new KMeans API.


Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp	2011-11-26 01:13:17 UTC (rev 10410)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp	2011-11-26 01:41:07 UTC (rev 10411)
@@ -49,12 +49,52 @@
 
   best_l = -DBL_MAX;
 
+  KMeans k; // Default KMeans parameters, for now.
+
   // We will perform ten trials, and then save the trial with the best result
   // as our trained model.
   for (size_t iter = 0; iter < 10; iter++)
   {
-    KMeans(data, gaussians, means_trial, covariances_trial, weights_trial);
+    arma::Col<size_t> assignments;
 
+    k.Cluster(data, gaussians, assignments);
+
+    // Clear the weights, covariances, and means, before we recalculate them.
+    weights_trial.zeros();
+    for (size_t i = 0; i < gaussians; i++)
+    {
+      means_trial[i].zeros();
+      covariances_trial[i].zeros();
+    }
+
+    // From the assignments, generate our means, covariances, and weights.
+    for (size_t i = 0; i < data.n_cols; i++)
+    {
+      size_t cluster = assignments[i];
+
+      // Add this to the relevant mean.
+      means_trial[cluster] += data.col(i);
+
+      // And add it to the relative covariance matrix.
+      covariances_trial[cluster] += data.col(i) * trans(data.col(i));
+
+      // Now add one to the weights (we will normalize).
+      weights_trial[cluster]++;
+    }
+
+    // Now normalize the means, covariances, and weights.
+    for (size_t i = 0; i < gaussians; i++)
+    {
+      // Normalize mean.
+      means_trial[i] /= weights_trial[i];
+
+      // Normalize covariance (use unbiased estimator).
+      covariances_trial[i] /= (weights_trial[i] - 1);
+    }
+
+    // Finally, normalize weights.
+    weights_trial /= accu(weights_trial);
+
     l = Loglikelihood(data, means_trial, covariances_trial, weights_trial);
 
     Log::Info << "K-means log-likelihood: " << l << std::endl;




More information about the mlpack-svn mailing list