[mlpack-svn] r14905 - mlpack/trunk/src/mlpack/core/dists

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Apr 15 17:06:58 EDT 2013


Author: rcurtin
Date: 2013-04-15 17:06:58 -0400 (Mon, 15 Apr 2013)
New Revision: 14905

Modified:
   mlpack/trunk/src/mlpack/core/dists/gaussian_distribution.cpp
Log:
Use the same fix as in the GMM code to ensure that covariance matrices are
positive definite.


Modified: mlpack/trunk/src/mlpack/core/dists/gaussian_distribution.cpp
===================================================================
--- mlpack/trunk/src/mlpack/core/dists/gaussian_distribution.cpp	2013-04-15 09:02:02 UTC (rev 14904)
+++ mlpack/trunk/src/mlpack/core/dists/gaussian_distribution.cpp	2013-04-15 21:06:58 UTC (rev 14905)
@@ -52,14 +52,17 @@
   // that it is the unbiased estimator.
   covariance /= (observations.n_cols - 1);
 
-  // Ensure that there are no zeros on the diagonal.
-  for (size_t d = 0; d < covariance.n_rows; ++d)
+  // Ensure that the covariance is positive definite.
+  if (det(covariance) <= 1e-50)
   {
-    if (covariance(d, d) == 0.0)
+    Log::Debug << "GaussianDistribution::Estimate(): Covariance matrix is not "
+        << "positive definite. Adding perturbation." << std::endl;
+
+    double perturbation = 1e-30;
+    while (det(covariance) <= 1e-50)
     {
-      Log::Debug << "GaussianDistribution::Estimate(): covariance diagonal "
-          << "element " << d << " is 0; adding perturbation." << std::endl;
-      covariance(d, d) = 1e-50;
+      covariance.diag() += perturbation;
+      perturbation *= 10; // Slow, but we don't want to add too much.
     }
   }
 }
@@ -115,14 +118,17 @@
   // This is probably biased, but I don't know how to unbias it.
   covariance /= sumProb;
 
-  // Ensure that there are no zeros on the diagonal.
-  for (size_t d = 0; d < covariance.n_rows; ++d)
+  // Ensure that the covariance is positive definite.
+  if (det(covariance) <= 1e-50)
   {
-    if (covariance(d, d) == 0.0)
+    Log::Debug << "GaussianDistribution::Estimate(): Covariance matrix is not "
+        << "positive definite. Adding perturbation." << std::endl;
+
+    double perturbation = 1e-30;
+    while (det(covariance) <= 1e-50)
     {
-      Log::Debug << "GaussianDistribution::Estimate(): covariance diagonal "
-          << "element " << d << " is 0; adding perturbation." << std::endl;
-      covariance(d, d) = 1e-50;
+      covariance.diag() += perturbation;
+      perturbation *= 10; // Slow, but we don't want to add too much.
     }
   }
 }




More information about the mlpack-svn mailing list