[mlpack-svn] r10776 - mlpack/trunk/src/mlpack/methods/gmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Dec 14 08:06:45 EST 2011


Author: rcurtin
Date: 2011-12-14 08:06:45 -0500 (Wed, 14 Dec 2011)
New Revision: 10776

Modified:
   mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
Log:
Allow a GMM to be saved and clean up documentation.


Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp	2011-12-14 13:04:41 UTC (rev 10775)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp	2011-12-14 13:06:45 UTC (rev 10776)
@@ -8,13 +8,15 @@
 
 PROGRAM_INFO("GMM",
     "This program takes a parametric estimate of a Gaussian mixture model (GMM)"
-    " using the EM algorithm to find the maximum likelihood estimate.");
+    " using the EM algorithm to find the maximum likelihood estimate.  The "
+    "model is saved to an XML file, which contains information about each "
+    "Gaussian.");
 
-PARAM_STRING_REQ("data", "A file containing the data on which the model has to "
-    "be fit.", "D");
-PARAM_INT("gaussians", "g", "G", 1);
+PARAM_STRING_REQ("input_file", "File containing the data on which the model "
+    "will be fit.", "i");  // Loaded in main() via data::Load().
+PARAM_INT("gaussians", "Number of Gaussians in the GMM", "g", 1);  // Must be >= 1.
 PARAM_STRING("output_file", "The file to write the trained GMM parameters into "
-    "(as XML).", "gmm.xml");
+    "(as XML).", "o", "gmm.xml");
 
 using namespace mlpack;
 using namespace mlpack::gmm;
@@ -22,21 +24,43 @@
 int main(int argc, char* argv[]) {
   CLI::ParseCommandLine(argc, argv);
 
-  ////// READING PARAMETERS AND LOADING DATA //////
-  arma::mat data_points;
-  data::Load(CLI::GetParam<std::string>("data").c_str(), data_points, true);
+  // Load the data set, then validate the remaining parameters.
+  arma::mat dataPoints;
+  data::Load(CLI::GetParam<std::string>("input_file").c_str(), dataPoints,
+      true);
 
-  ////// MIXTURE OF GAUSSIANS USING EM //////
-  GMM gmm(CLI::GetParam<int>("gaussians"), data_points.n_rows);
+  const int gaussians = CLI::GetParam<int>("gaussians");
+  if (gaussians <= 0)
+  {
+    Log::Fatal << "Invalid number of Gaussians (" << gaussians << "); must "
+        "be greater than or equal to 1." << std::endl;
+  }
 
+  // Construct the GMM; it is fitted to the data below by Estimate().
+  GMM gmm(static_cast<size_t>(gaussians), dataPoints.n_rows);
+
   ////// Computing the parameters of the model using the EM algorithm //////
   Timer::Start("em");
-  gmm.Estimate(data_points);
+  gmm.Estimate(dataPoints);
   Timer::Stop("em");
 
   ////// OUTPUT RESULTS //////
-  
+  SaveRestoreUtility save;
+  save.SaveParameter(gmm.Gaussians(), "gaussians");
+  save.SaveParameter(gmm.Dimensionality(), "dimensionality");
+  save.SaveParameter(trans(gmm.Weights()), "weights");
+  for (size_t i = 0; i < gmm.Gaussians(); ++i)
+  {
+    // Per-component XML node names: "mean<i>" and "covariance<i>".
+    std::stringstream o;
+    o << i;
+    std::string meanName = "mean" + o.str();
+    std::string covName = "covariance" + o.str();
 
-  // We need a better solution for this.  So, currently, we do nothing.
-  // XML is probably the right tool for the job.
+    // Save the i'th Gaussian's mean (transposed) and covariance.
+    save.SaveParameter(trans(gmm.Means()[i]), meanName);
+    save.SaveParameter(gmm.Covariances()[i], covName);
+  }
+
+  save.WriteFile(CLI::GetParam<std::string>("output_file").c_str());
 }




More information about the mlpack-svn mailing list