[mlpack-svn] r14883 - mlpack/trunk/src/mlpack/methods/gmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Apr 9 14:05:52 EDT 2013


Author: rcurtin
Date: 2013-04-09 14:05:52 -0400 (Tue, 09 Apr 2013)
New Revision: 14883

Modified:
   mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp
Log:
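Add a small perturbation (1e-50) to the diagonal of any covariance matrix whose
entries sum to zero, to prevent zero-valued covariance matrices.  Also raise the
initial diagonal fill value from 1e-200 to 1e-50.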


Modified: mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp	2013-04-09 18:04:14 UTC (rev 14882)
+++ mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp	2013-04-09 18:05:52 UTC (rev 14883)
@@ -69,6 +69,13 @@
           trans(condProb.col(i)));
 
       covariances[i] = (tmp * trans(tmpB)) / probRowSums[i];
+
+      if (accu(covariances[i]) == 0)
+      {
+        Log::Debug << "Covariance " << i << " sums to zero!  Adding "
+            << " perturbation." << std::endl;
+        covariances[i].diag() += 1e-50;
+      }
     }
 
     // Calculate the new values for omega using the updated conditional
@@ -145,6 +152,13 @@
           trans(condProb.col(i) % probabilities));
 
       covariances[i] = (tmp * trans(tmpB)) / probRowSums[i];
+
+      if (accu(covariances[i]) == 0)
+      {
+        Log::Debug << "Covariance " << i << " sums to zero!  Adding "
+            << " perturbation." << std::endl;
+        covariances[i].diag() += 1e-50;
+      }
     }
 
     // Calculate the new values for omega using the updated conditional
@@ -178,7 +192,7 @@
   {
     means[i].zeros();
     covariances[i].zeros();
-    covariances[i].diag().fill(1e-200);
+    covariances[i].diag().fill(1e-50);
   }
 
   // From the assignments, generate our means, covariances, and weights.
@@ -203,6 +217,14 @@
 
     means[i] /= weights[i];
     covariances[i] /= (weights[i] > 1) ? weights[i] : 1;
+
+    if (accu(covariances[i]) == 0)
+    {
+      Log::Debug << "Covariance " << i << " sums to zero!  Adding perturbation."
+          << std::endl;
+      covariances[i].diag() += 1e-50;
+      Log::Debug << covariances[i];
+    }
   }
 
   // Finally, normalize weights.




More information about the mlpack-svn mailing list