[mlpack-svn] r10776 - mlpack/trunk/src/mlpack/methods/gmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Dec 14 08:06:45 EST 2011
Author: rcurtin
Date: 2011-12-14 08:06:45 -0500 (Wed, 14 Dec 2011)
New Revision: 10776
Modified:
mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
Log:
Allow a GMM to be saved and clean up documentation.
Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp 2011-12-14 13:04:41 UTC (rev 10775)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp 2011-12-14 13:06:45 UTC (rev 10776)
@@ -8,13 +8,15 @@
PROGRAM_INFO("GMM",
"This program takes a parametric estimate of a Gaussian mixture model (GMM)"
- " using the EM algorithm to find the maximum likelihood estimate.");
+ " using the EM algorithm to find the maximum likelihood estimate. The "
+ "model is saved to an XML file, which contains information about each "
+ "Gaussian.");
-PARAM_STRING_REQ("data", "A file containing the data on which the model has to "
- "be fit.", "D");
-PARAM_INT("gaussians", "g", "G", 1);
+PARAM_STRING_REQ("input_file", "File containing the data on which the model "
+ "will be fit.", "i");
+PARAM_INT("gaussians", "Number of Gaussians in the GMM.", "g", 1);
PARAM_STRING("output_file", "The file to write the trained GMM parameters into "
- "(as XML).", "gmm.xml");
+ "(as XML).", "o", "gmm.xml");
using namespace mlpack;
using namespace mlpack::gmm;
@@ -22,21 +24,43 @@
int main(int argc, char* argv[]) {
CLI::ParseCommandLine(argc, argv);
- ////// READING PARAMETERS AND LOADING DATA //////
- arma::mat data_points;
- data::Load(CLI::GetParam<std::string>("data").c_str(), data_points, true);
+ // Check parameters and load data.
+ arma::mat dataPoints;
+ data::Load(CLI::GetParam<std::string>("input_file").c_str(), dataPoints,
+ true);
- ////// MIXTURE OF GAUSSIANS USING EM //////
- GMM gmm(CLI::GetParam<int>("gaussians"), data_points.n_rows);
+ int gaussians = CLI::GetParam<int>("gaussians");
+ if (gaussians <= 0)
+ {
+ Log::Fatal << "Invalid number of Gaussians (" << gaussians << "); must "
+ "be greater than or equal to 1." << std::endl;
+ }
+ // Calculate mixture of Gaussians.
+ GMM gmm(size_t(gaussians), dataPoints.n_rows);
+
////// Computing the parameters of the model using the EM algorithm //////
Timer::Start("em");
- gmm.Estimate(data_points);
+ gmm.Estimate(dataPoints);
Timer::Stop("em");
////// OUTPUT RESULTS //////
-
+ SaveRestoreUtility save;
+ save.SaveParameter(gmm.Gaussians(), "gaussians");
+ save.SaveParameter(gmm.Dimensionality(), "dimensionality");
+ save.SaveParameter(trans(gmm.Weights()), "weights");
+ for (size_t i = 0; i < gmm.Gaussians(); ++i)
+ {
+ // Generate names for the XML nodes.
+ std::stringstream o;
+ o << i;
+ std::string meanName = "mean" + o.str();
+ std::string covName = "covariance" + o.str();
- // We need a better solution for this. So, currently, we do nothing.
- // XML is probably the right tool for the job.
+ // Save the i-th Gaussian's own mean and covariance (index i, not 0).
+ save.SaveParameter(trans(gmm.Means()[i]), meanName);
+ save.SaveParameter(gmm.Covariances()[i], covName);
+ }
+
+ save.WriteFile(CLI::GetParam<std::string>("output_file").c_str());
}
More information about the mlpack-svn
mailing list