[mlpack-svn] r14892 - mlpack/trunk/src/mlpack/methods/gmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Apr 11 18:37:15 EDT 2013


Author: rcurtin
Date: 2013-04-11 18:37:15 -0400 (Thu, 11 Apr 2013)
New Revision: 14892

Modified:
   mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
Log:
Add parameters for the EM algorithm, and also add a parameter which adds random
Gaussian noise to the dataset.


Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp	2013-04-11 22:36:55 UTC (rev 14891)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp	2013-04-11 22:37:15 UTC (rev 14892)
@@ -10,7 +10,14 @@
     "This program takes a parametric estimate of a Gaussian mixture model (GMM)"
     " using the EM algorithm to find the maximum likelihood estimate.  The "
     "model is saved to an XML file, which contains information about each "
-    "Gaussian.");
+    "Gaussian."
+    "\n\n"
+    "If GMM training fails with an error indicating that a covariance matrix "
+    "could not be inverted, this is probably remedied either via a larger "
+    "option to the perturbation parameter, or alternately, adding a small "
+    "amount of Gaussian noise to the entire dataset.  This helps prevent "
+    "Gaussians with zero variance in a particular dimension, which is usually "
+    "the cause of non-invertible covariance matrices.");
 
 PARAM_STRING_REQ("input_file", "File containing the data on which the model "
     "will be fit.", "i");
@@ -20,6 +27,17 @@
 PARAM_INT("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
 PARAM_INT("trials", "Number of trials to perform in training GMM.", "t", 10);
 
+// Parameters for EM algorithm.
+PARAM_DOUBLE("tolerance", "Tolerance for convergence of EM.", "T", 1e-10);
+PARAM_DOUBLE("perturbation", "Perturbation to add to zero-valued diagonal "
+    "covariance entries.", "p", 1e-30);
+PARAM_INT("max_iterations", "Maximum number of iterations of EM algorithm "
+    "(passing 0 will run until convergence).", "n", 250);
+
+// Parameters for dataset modification.
+PARAM_DOUBLE("noise", "Variance of zero-mean Gaussian noise to add to data.",
+    "N", 0);
+
 using namespace mlpack;
 using namespace mlpack::gmm;
 using namespace mlpack::util;
@@ -38,15 +56,32 @@
   data::Load(CLI::GetParam<std::string>("input_file").c_str(), dataPoints,
       true);
 
-  int gaussians = CLI::GetParam<int>("gaussians");
+  const int gaussians = CLI::GetParam<int>("gaussians");
   if (gaussians <= 0)
   {
     Log::Fatal << "Invalid number of Gaussians (" << gaussians << "); must "
         "be greater than or equal to 1." << std::endl;
   }
 
+  // Do we need to add noise to the dataset?
+  if (CLI::HasParam("noise"))
+  {
+    Timer::Start("noise_addition");
+    const double noise = CLI::GetParam<double>("noise");
+    dataPoints += noise * arma::randn(dataPoints.n_rows, dataPoints.n_cols);
+    Log::Info << "Added zero-mean Gaussian noise with variance " << noise
+        << " to dataset." << std::endl;
+    Timer::Stop("noise_addition");
+  }
+
+  // Gather parameters for EMFit object.
+  const size_t maxIterations = (size_t) CLI::GetParam<int>("max_iterations");
+  const double tolerance = CLI::GetParam<double>("tolerance");
+  const double perturbation = CLI::GetParam<double>("perturbation");
+  EMFit<> em(maxIterations, tolerance, perturbation);
+
   // Calculate mixture of Gaussians.
-  GMM<> gmm(size_t(gaussians), dataPoints.n_rows);
+  GMM<> gmm(size_t(gaussians), dataPoints.n_rows, em);
 
   // Compute the parameters of the model using the EM algorithm.
   Timer::Start("em");




More information about the mlpack-svn mailing list