[mlpack-svn] r14892 - mlpack/trunk/src/mlpack/methods/gmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Apr 11 18:37:15 EDT 2013
Author: rcurtin
Date: 2013-04-11 18:37:15 -0400 (Thu, 11 Apr 2013)
New Revision: 14892
Modified:
mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
Log:
Add parameters for the EM algorithm, and also add a parameter which adds random
Gaussian noise to the dataset.
Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp 2013-04-11 22:36:55 UTC (rev 14891)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp 2013-04-11 22:37:15 UTC (rev 14892)
@@ -10,7 +10,14 @@
"This program takes a parametric estimate of a Gaussian mixture model (GMM)"
" using the EM algorithm to find the maximum likelihood estimate. The "
"model is saved to an XML file, which contains information about each "
- "Gaussian.");
+ "Gaussian."
+ "\n\n"
+ "If GMM training fails with an error indicating that a covariance matrix "
+ "could not be inverted, this is probably remedied either via a larger "
+ "option to the perturbation parameter, or alternately, adding a small "
+ "amount of Gaussian noise to the entire dataset. This helps prevent "
+ "Gaussians with zero variance in a particular dimension, which is usually "
+ "the cause of non-invertible covariance matrices.");
PARAM_STRING_REQ("input_file", "File containing the data on which the model "
"will be fit.", "i");
@@ -20,6 +27,17 @@
PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
PARAM_INT("trials", "Number of trials to perform in training GMM.", "t", 10);
+// Parameters for EM algorithm.
+PARAM_DOUBLE("tolerance", "Tolerance for convergence of EM.", "T", 1e-10);
+PARAM_DOUBLE("perturbation", "Perturbation to add to zero-valued diagonal "
+ "covariance entries.", "p", 1e-30);
+PARAM_INT("max_iterations", "Maximum number of iterations of EM algorithm "
+ "(passing 0 will run until convergence).", "n", 250);
+
+// Parameters for dataset modification.
+PARAM_DOUBLE("noise", "Variance of zero-mean Gaussian noise to add to data.",
+ "N", 0);
+
using namespace mlpack;
using namespace mlpack::gmm;
using namespace mlpack::util;
@@ -38,15 +56,32 @@
data::Load(CLI::GetParam<std::string>("input_file").c_str(), dataPoints,
true);
- int gaussians = CLI::GetParam<int>("gaussians");
+ const int gaussians = CLI::GetParam<int>("gaussians");
if (gaussians <= 0)
{
Log::Fatal << "Invalid number of Gaussians (" << gaussians << "); must "
"be greater than or equal to 1." << std::endl;
}
+ // Do we need to add noise to the dataset?
+ if (CLI::HasParam("noise"))
+ {
+ Timer::Start("noise_addition");
+ const double noise = CLI::GetParam<double>("noise");
+ dataPoints += noise * arma::randn(dataPoints.n_rows, dataPoints.n_cols);
+ Log::Info << "Added zero-mean Gaussian noise with variance " << noise
+ << " to dataset." << std::endl;
+ Timer::Stop("noise_addition");
+ }
+
+ // Gather parameters for EMFit object.
+ const size_t maxIterations = (size_t) CLI::GetParam<int>("max_iterations");
+ const double tolerance = CLI::GetParam<double>("tolerance");
+ const double perturbation = CLI::GetParam<double>("perturbation");
+ EMFit<> em(maxIterations, tolerance, perturbation);
+
// Calculate mixture of Gaussians.
- GMM<> gmm(size_t(gaussians), dataPoints.n_rows);
+ GMM<> gmm(size_t(gaussians), dataPoints.n_rows, em);
// Compute the parameters of the model using the EM algorithm.
Timer::Start("em");
More information about the mlpack-svn mailing list