[mlpack-git] master: Add more useful GMM programs. (09f9c3e)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Dec 18 11:43:18 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/5ba11bc90223b55eecd5da4cfbe86c8fc40637a5...df229e45a5bd7842fe019e9d49ed32f13beb6aaa

>---------------------------------------------------------------

commit 09f9c3ea9e30e73b10918904c79c37df31a4fe2a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Dec 18 15:53:23 2015 +0000

    Add more useful GMM programs.


>---------------------------------------------------------------

09f9c3ea9e30e73b10918904c79c37df31a4fe2a
 CMakeLists.txt                                  |  10 +-
 src/mlpack/core.hpp                             |  10 +-
 src/mlpack/methods/gmm/CMakeLists.txt           |  23 ++++-
 src/mlpack/methods/gmm/gmm_generate_main.cpp    |  52 +++++++++++
 src/mlpack/methods/gmm/gmm_probability_main.cpp |  45 +++++++++
 src/mlpack/methods/gmm/gmm_train_main.cpp       | 119 ++++++++++++------------
 6 files changed, 184 insertions(+), 75 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0684b4b..3e4c0fc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -334,11 +334,11 @@ if (UNIX)
         WORKING_DIRECTORY
           ${CMAKE_BINARY_DIR}/bin
         DEPENDS
-          allkfn allknn allkrann cf decision_stump det emst fastmks gmm
-          hmm_generate hmm_loglik hmm_train hmm_viterbi kernel_pca kmeans lars
-          linear_regression local_coordinate_coding logistic_regression lsh nbc
-          nca nmf pca perceptron radical range_search softmax_regression
-          sparse_coding
+          allkfn allknn allkrann cf decision_stump det emst fastmks gmm_train
+          gmm_probability gmm_generate hmm_generate hmm_loglik hmm_train
+          hmm_viterbi kernel_pca kmeans lars linear_regression
+          local_coordinate_coding logistic_regression lsh nbc nca nmf pca
+          perceptron radical range_search softmax_regression sparse_coding
         COMMENT "Generating man pages from built executables."
     )
 
diff --git a/src/mlpack/core.hpp b/src/mlpack/core.hpp
index 5c355a6..76e002e 100644
--- a/src/mlpack/core.hpp
+++ b/src/mlpack/core.hpp
@@ -1,4 +1,4 @@
-/***
+/**
  * @file core.hpp
  *
  * Include all of the base components required to write MLPACK methods, and the
@@ -49,10 +49,10 @@
  *
  * A full list of executables is given below:
  *
- * allkfn, allknn, det, emst, gmm, hmm_train, hmm_loglik, hmm_viterbi,
- * hmm_generate, kernel_pca, kmeans, lars, linear_regression,
- * local_coordinate_coding, logistic_regression, lsh, mvu, nbc, nca, pca,
- * radical, range_search, softmax_regression, sparse_coding
+ * allkfn, allknn, det, emst, gmm_train, gmm_generate, gmm_probability,
+ * hmm_train, hmm_loglik, hmm_viterbi, hmm_generate, kernel_pca, kmeans, lars,
+ * linear_regression, local_coordinate_coding, logistic_regression, lsh, mvu,
+ * nbc, nca, pca, radical, range_search, softmax_regression, sparse_coding
  *
  * @section tutorial Tutorials
  *
diff --git a/src/mlpack/methods/gmm/CMakeLists.txt b/src/mlpack/methods/gmm/CMakeLists.txt
index ce697e4..24d027b 100644
--- a/src/mlpack/methods/gmm/CMakeLists.txt
+++ b/src/mlpack/methods/gmm/CMakeLists.txt
@@ -10,7 +10,6 @@ set(SOURCES
   positive_definite_constraint.hpp
   diagonal_constraint.hpp
   eigenvalue_ratio_constraint.hpp
-  gmm_util.hpp
 )
 
 # Add directory name to sources.
@@ -23,13 +22,27 @@ endforeach()
 set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
 
 # main executable, gmm
-add_executable(gmm
-  gmm_main.cpp
+add_executable(gmm_train
+  gmm_train_main.cpp
 )
 
 # link dependencies of gmm
-target_link_libraries(gmm
+target_link_libraries(gmm_train
   mlpack
 )
 
-install(TARGETS gmm RUNTIME DESTINATION bin)
+add_executable(gmm_generate
+  gmm_generate_main.cpp
+)
+target_link_libraries(gmm_generate
+  mlpack
+)
+
+add_executable(gmm_probability
+  gmm_probability_main.cpp
+)
+target_link_libraries(gmm_probability
+  mlpack
+)
+
+install(TARGETS gmm_train gmm_generate gmm_probability RUNTIME DESTINATION bin)
diff --git a/src/mlpack/methods/gmm/gmm_generate_main.cpp b/src/mlpack/methods/gmm/gmm_generate_main.cpp
new file mode 100644
index 0000000..7fcf2e1
--- /dev/null
+++ b/src/mlpack/methods/gmm/gmm_generate_main.cpp
@@ -0,0 +1,52 @@
+/**
+ * @file gmm_generate_main.cpp
+ * @author Ryan Curtin
+ *
+ * Load a GMM from file, then generate samples from it.
+ */
+#include <mlpack/core.hpp>
+#include "gmm.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::gmm;
+
+PROGRAM_INFO("GMM Sample Generator",
+    "This program is able to generate samples from a pre-trained GMM (use "
+    "gmm_train to train a GMM).  It loads a GMM from the file specified with "
+    "--input_model_file (-m), and generates a number of samples from that "
+    "model; the number of samples is specified by the --samples (-n) parameter. "
+    "The output samples are saved in the file specified by --output_file "
+    "(-o).");
+
+PARAM_STRING_REQ("input_model_file", "File containing input GMM model.", "m");
+PARAM_INT_REQ("samples", "Number of samples to generate.", "n");
+
+PARAM_STRING("output_file", "File to save output samples in.", "o",
+    "output.csv");
+
+PARAM_INT("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
+
+int main(int argc, char** argv)
+{
+  CLI::ParseCommandLine(argc, argv);
+
+  if (CLI::GetParam<int>("seed") == 0)
+    mlpack::math::RandomSeed(time(NULL));
+  else
+    mlpack::math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+
+  if (CLI::GetParam<int>("samples") < 0)
+    Log::Fatal << "Parameter to --samples must be greater than 0!" << endl;
+
+  GMM gmm;
+  data::Load(CLI::GetParam<string>("input_model_file"), "gmm", gmm, true);
+
+  size_t length = (size_t) CLI::GetParam<int>("samples");
+  Log::Info << "Generating " << length << " samples..." << endl;
+  arma::mat samples(gmm.Dimensionality(), length);
+  for (size_t i = 0; i < length; ++i)
+    samples.col(i) = gmm.Random();
+
+  data::Save(CLI::GetParam<string>("output_file"), samples);
+}
diff --git a/src/mlpack/methods/gmm/gmm_probability_main.cpp b/src/mlpack/methods/gmm/gmm_probability_main.cpp
new file mode 100644
index 0000000..b771aa2
--- /dev/null
+++ b/src/mlpack/methods/gmm/gmm_probability_main.cpp
@@ -0,0 +1,45 @@
+/**
+ * @file gmm_probability_main.cpp
+ * @author Ryan Curtin
+ *
+ * Given a GMM, calculate the probability of points coming from it.
+ */
+#include <mlpack/core.hpp>
+#include "gmm.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::gmm;
+
+PROGRAM_INFO("GMM Probability Calculator",
+    "This program calculates the probability that given points came from a "
+    "given GMM (that is, P(X | gmm)).  The GMM is specified with the "
+    "--input_model_file option, and the points are specified with the "
+    "--input_file option.  The output probabilities are stored in the file "
+    "specified by the --output_file option.");
+
+PARAM_STRING_REQ("input_model_file", "File containing input GMM.", "m");
+PARAM_STRING_REQ("input_file", "File containing points.", "i");
+
+PARAM_STRING("output_file", "File to save calculated probabilities to.", "o",
+    "output.csv");
+
+int main(int argc, char** argv)
+{
+  CLI::ParseCommandLine(argc, argv);
+
+  // Get the GMM and the points.
+  GMM gmm;
+  data::Load(CLI::GetParam<string>("input_model_file"), "gmm", gmm);
+
+  arma::mat dataset;
+  data::Load(CLI::GetParam<string>("input_file"), dataset);
+
+  // Now calculate the probabilities.
+  arma::rowvec probabilities(dataset.n_cols);
+  for (size_t i = 0; i < dataset.n_cols; ++i)
+    probabilities[i] = gmm.Probability(dataset.unsafe_col(i));
+
+  // And save the result.
+  data::Save(CLI::GetParam<string>("output_file"), probabilities);
+}
diff --git a/src/mlpack/methods/gmm/gmm_train_main.cpp b/src/mlpack/methods/gmm/gmm_train_main.cpp
index e1046f9..914fbe5 100644
--- a/src/mlpack/methods/gmm/gmm_train_main.cpp
+++ b/src/mlpack/methods/gmm/gmm_train_main.cpp
@@ -1,6 +1,6 @@
 /**
- * @author Parikshit Ram (pram at cc.gatech.edu)
- * @file gmm_main.cpp
+ * @author Parikshit Ram
+ * @file gmm_train_main.cpp
  *
  * This program trains a mixture of Gaussians on a given data matrix.
  */
@@ -8,7 +8,6 @@
 
 #include "gmm.hpp"
 #include "no_constraint.hpp"
-#include "gmm_util.hpp"
 
 #include <mlpack/methods/kmeans/refined_start.hpp>
 
@@ -21,29 +20,32 @@ using namespace std;
 PROGRAM_INFO("Gaussian Mixture Model (GMM) Training",
     "This program takes a parametric estimate of a Gaussian mixture model (GMM)"
     " using the EM algorithm to find the maximum likelihood estimate.  The "
-    "model is saved to an XML file, which contains information about each "
+    "model may be saved to file, which will contain information about each "
     "Gaussian."
     "\n\n"
     "If GMM training fails with an error indicating that a covariance matrix "
-    "could not be inverted, be sure that the 'no_force_positive' flag was not "
-    "specified.  Alternately, adding a small amount of Gaussian noise to the "
-    "entire dataset may help prevent Gaussians with zero variance in a "
-    "particular dimension, which is usually the cause of non-invertible "
-    "covariance matrices."
+    "could not be inverted, make sure that the --no_force_positive flag is not "
+    "specified.  Alternately, adding a small amount of Gaussian noise (using "
+    "the --noise parameter) to the entire dataset may help prevent Gaussians "
+    "with zero variance in a particular dimension, which is usually the cause "
+    "of non-invertible covariance matrices."
     "\n\n"
     "The 'no_force_positive' flag, if set, will avoid the checks after each "
     "iteration of the EM algorithm which ensure that the covariance matrices "
     "are positive definite.  Specifying the flag can cause faster runtime, "
     "but may also cause non-positive definite covariance matrices, which will "
-    "cause the program to crash.");
+    "cause the program to crash."
+    "\n\n"
+    "Optionally, multiple trials may be performed, by specifying the --trials "
+    "option.  The model with greatest log-likelihood will be taken.");
 
+// Parameters for training.
 PARAM_STRING_REQ("input_file", "File containing the data on which the model "
     "will be fit.", "i");
-PARAM_INT("gaussians", "Number of Gaussians in the GMM.", "g", 1);
-PARAM_STRING("output_file", "The file to write the trained GMM parameters "
-    "into.", "o", "gmm.xml");
+PARAM_INT_REQ("gaussians", "Number of Gaussians in the GMM.", "g");
+
 PARAM_INT("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
-PARAM_INT("trials", "Number of trials to perform in training GMM.", "t", 10);
+PARAM_INT("trials", "Number of trials to perform in training GMM.", "t", 1);
 
 // Parameters for EM algorithm.
 PARAM_DOUBLE("tolerance", "Tolerance for convergence of EM.", "T", 1e-10);
@@ -65,6 +67,12 @@ PARAM_DOUBLE("percentage", "If using --refined_start, specify the percentage of"
     " the dataset used for each sampling (should be between 0.0 and 1.0).",
     "p", 0.02);
 
+// Parameters for model saving/loading.
+PARAM_STRING("input_model_file", "File containing initial input GMM model.",
+    "m", "");
+PARAM_STRING("output_model_file", "File to save trained GMM model to.", "M",
+    "");
+
 int main(int argc, char* argv[])
 {
   CLI::ParseCommandLine(argc, argv);
@@ -75,10 +83,6 @@ int main(int argc, char* argv[])
   else
     math::RandomSeed((size_t) std::time(NULL));
 
-  arma::mat dataPoints;
-  data::Load(CLI::GetParam<string>("input_file"), dataPoints,
-      true);
-
   const int gaussians = CLI::GetParam<int>("gaussians");
   if (gaussians <= 0)
   {
@@ -86,6 +90,13 @@ int main(int argc, char* argv[])
         "be greater than or equal to 1." << std::endl;
   }
 
+  if (!CLI::HasParam("output_model_file"))
+    Log::Warn << "--output_model_file is not specified, so no model will be "
+        << "saved!" << endl;
+
+  arma::mat dataPoints;
+  data::Load(CLI::GetParam<string>("input_file"), dataPoints, true);
+
   // Do we need to add noise to the dataset?
   if (CLI::HasParam("noise"))
   {
@@ -97,6 +108,20 @@ int main(int argc, char* argv[])
     Timer::Stop("noise_addition");
   }
 
+  // Initialize GMM.
+  GMM gmm(size_t(gaussians), dataPoints.n_rows);
+
+  if (CLI::HasParam("input_model_file"))
+  {
+    data::Load(CLI::GetParam<string>("input_model_file"), "gmm", gmm, true);
+
+    if (gmm.Dimensionality() != dataPoints.n_rows)
+      Log::Fatal << "Given input data (with --input_file) has dimensionality "
+          << dataPoints.n_rows << ", but the initial model (given with "
+          << "--input_model_file) has dimensionality " << gmm.Dimensionality()
+          << "!" << endl;
+  }
+
   // Gather parameters for EMFit object.
   const size_t maxIterations = (size_t) CLI::GetParam<int>("max_iterations");
   const double tolerance = CLI::GetParam<double>("tolerance");
@@ -128,34 +153,21 @@ int main(int argc, char* argv[])
     // types.
     if (forcePositive)
     {
-      EMFit<KMeansType> em(maxIterations, tolerance, k);
-
-      GMM<EMFit<KMeansType> > gmm(size_t(gaussians), dataPoints.n_rows, em);
-
       // Compute the parameters of the model using the EM algorithm.
       Timer::Start("em");
-      likelihood = gmm.Estimate(dataPoints, CLI::GetParam<int>("trials"));
+      EMFit<KMeansType> em(maxIterations, tolerance, k);
+      likelihood = gmm.Train(dataPoints, CLI::GetParam<int>("trials"), false,
+          em);
       Timer::Stop("em");
-
-      // Save results.
-      const string outputFile = CLI::GetParam<string>("output_file");
-      SaveGMM(gmm, outputFile);
     }
     else
     {
-      EMFit<KMeansType, NoConstraint> em(maxIterations, tolerance, k);
-
-      GMM<EMFit<KMeansType, NoConstraint> > gmm(size_t(gaussians),
-          dataPoints.n_rows, em);
-
       // Compute the parameters of the model using the EM algorithm.
       Timer::Start("em");
-      likelihood = gmm.Estimate(dataPoints, CLI::GetParam<int>("trials"));
+      EMFit<KMeansType, NoConstraint> em(maxIterations, tolerance, k);
+      likelihood = gmm.Train(dataPoints, CLI::GetParam<int>("trials"), false,
+          em);
       Timer::Stop("em");
-
-      // Save results.
-      const string outputFile = CLI::GetParam<string>("output_file");
-      SaveGMM(gmm, outputFile);
     }
   }
   else
@@ -163,39 +175,26 @@ int main(int argc, char* argv[])
     // Depending on the value of forcePositive, we have to use different types.
     if (forcePositive)
     {
-      EMFit<> em(maxIterations, tolerance);
-
-      // Calculate mixture of Gaussians.
-      GMM<> gmm(size_t(gaussians), dataPoints.n_rows, em);
-
       // Compute the parameters of the model using the EM algorithm.
       Timer::Start("em");
-      likelihood = gmm.Estimate(dataPoints, CLI::GetParam<int>("trials"));
+      EMFit<> em(maxIterations, tolerance);
+      likelihood = gmm.Train(dataPoints, CLI::GetParam<int>("trials"), false,
+          em);
       Timer::Stop("em");
-
-      // Save results.
-      const string outputFile = CLI::GetParam<string>("output_file");
-      SaveGMM(gmm, outputFile);
     }
     else
     {
-      // Use no constraints on the covariance matrix.
-      EMFit<KMeans<>, NoConstraint> em(maxIterations, tolerance);
-
-      // Calculate mixture of Gaussians.
-      GMM<EMFit<KMeans<>, NoConstraint> > gmm(size_t(gaussians),
-          dataPoints.n_rows, em);
-
       // Compute the parameters of the model using the EM algorithm.
       Timer::Start("em");
-      likelihood = gmm.Estimate(dataPoints, CLI::GetParam<int>("trials"));
+      EMFit<KMeans<>, NoConstraint> em(maxIterations, tolerance);
+      likelihood = gmm.Train(dataPoints, CLI::GetParam<int>("trials"), false,
+          em);
       Timer::Stop("em");
-
-      // Save results.
-      const string outputFile = CLI::GetParam<string>("output_file");
-      SaveGMM(gmm, outputFile);
     }
   }
 
-  Log::Info << "Log-likelihood of estimate: " << likelihood << ".\n";
+  Log::Info << "Log-likelihood of estimate: " << likelihood << "." << endl;
+
+  if (CLI::HasParam("output_model_file"))
+    data::Save(CLI::GetParam<string>("output_model_file"), "gmm", gmm);
 }



More information about the mlpack-git mailing list