[mlpack-svn] r10532 - mlpack/trunk/src/mlpack/methods/gmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Dec 3 19:08:47 EST 2011


Author: rcurtin
Date: 2011-12-03 19:08:47 -0500 (Sat, 03 Dec 2011)
New Revision: 10532

Added:
   mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/gmm/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp
   mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
   mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
Log:
Add Estimate(const arma::mat&, const arma::vec&) so that this class can work as
a distribution for HMMs.


Modified: mlpack/trunk/src/mlpack/methods/gmm/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/CMakeLists.txt	2011-12-03 22:20:22 UTC (rev 10531)
+++ mlpack/trunk/src/mlpack/methods/gmm/CMakeLists.txt	2011-12-04 00:08:47 UTC (rev 10532)
@@ -4,6 +4,7 @@
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
   gmm.hpp
+  gmm_impl.hpp
   gmm.cpp
   phi.hpp
 )

Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp	2011-12-03 22:20:22 UTC (rev 10531)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp	2011-12-04 00:08:47 UTC (rev 10532)
@@ -35,17 +35,16 @@
   return "0 0";
 }
 
-void GMM::ExpectationMaximization(const arma::mat& data)
+void GMM::Estimate(const arma::mat& data)
 {
   // Create temporary models and set to the right size.
-  std::vector<arma::vec> means_trial(gaussians, arma::vec(dimension));
-  std::vector<arma::mat> covariances_trial(gaussians,
-      arma::mat(dimension, dimension));
-  arma::vec weights_trial(gaussians);
+  std::vector<arma::vec> means_trial;
+  std::vector<arma::mat> covariances_trial;
+  arma::vec weights_trial;
 
   arma::mat cond_prob(data.n_cols, gaussians);
 
-  long double l, l_old, best_l, TINY = 1.0e-4;
+  double l, l_old, best_l, TINY = 1.0e-4;
 
   best_l = -DBL_MAX;
 
@@ -55,49 +54,113 @@
   // as our trained model.
   for (size_t iter = 0; iter < 10; iter++)
   {
-    arma::Col<size_t> assignments;
+    InitialClustering(k, data, means_trial, covariances_trial, weights_trial);
 
-    k.Cluster(data, gaussians, assignments);
+    l = Loglikelihood(data, means_trial, covariances_trial, weights_trial);
 
-    // Clear the weights, covariances, and means, before we recalculate them.
-    weights_trial.zeros();
-    for (size_t i = 0; i < gaussians; i++)
-    {
-      means_trial[i].zeros();
-      covariances_trial[i].zeros();
-    }
+    Log::Info << "K-means log-likelihood: " << l << std::endl;
 
-    // From the assignments, generate our means, covariances, and weights.
-    for (size_t i = 0; i < data.n_cols; i++)
+    l_old = -DBL_MAX;
+
+    // Iterate to update the model until no more improvement is found.
+    size_t max_iterations = 300;
+    size_t iteration = 0;
+    while (std::abs(l - l_old) > TINY && iteration < max_iterations)
     {
-      size_t cluster = assignments[i];
+      // Calculate the conditional probabilities of choosing a particular
+      // Gaussian given the data and the present theta value.
+      for (size_t i = 0; i < gaussians; i++)
+      {
+        // Store conditional probabilities into cond_prob vector for each
+        // Gaussian.  First we make an alias of the cond_prob vector.
+        arma::vec cond_prob_alias = cond_prob.unsafe_col(i);
+        phi(data, means_trial[i], covariances_trial[i], cond_prob_alias);
+        cond_prob_alias *= weights_trial[i];
+      }
 
-      // Add this to the relevant mean.
-      means_trial[cluster] += data.col(i);
+      // Normalize row-wise.
+      for (size_t i = 0; i < cond_prob.n_rows; i++)
+        cond_prob.row(i) /= accu(cond_prob.row(i));
 
-      // Add this to the relevant covariance.
-      covariances_trial[cluster] += data.col(i) * trans(data.col(i));
+      // Store the sum of the probability of each state over all the data.
+      arma::vec prob_row_sums = trans(arma::sum(cond_prob, 0 /* columnwise */));
 
-      // Now add one to the weights (we will normalize).
-      weights_trial[cluster]++;
+      // Calculate the new value of the means using the updated conditional
+      // probabilities.
+      for (size_t i = 0; i < gaussians; i++)
+      {
+        means_trial[i] = (data * cond_prob.col(i)) / prob_row_sums[i];
+
+        // Calculate the new value of the covariances using the updated
+        // conditional probabilities and the updated means.
+        arma::mat tmp = data - (means_trial[i] *
+            arma::ones<arma::rowvec>(data.n_cols));
+        arma::mat tmp_b = tmp % (arma::ones<arma::vec>(data.n_rows) *
+            trans(cond_prob.col(i)));
+
+        covariances_trial[i] = (tmp * trans(tmp_b)) / prob_row_sums[i];
+      }
+
+      // Calculate the new values for omega using the updated conditional
+      // probabilities.
+      weights_trial = prob_row_sums / data.n_cols;
+
+      // Update values of l; calculate new log-likelihood.
+      l_old = l;
+      l = Loglikelihood(data, means_trial, covariances_trial, weights_trial);
+
+      iteration++;
     }
 
-    // Now normalize the mean and covariance.
-    for (size_t i = 0; i < gaussians; i++)
+    Log::Info << "Likelihood of iteration " << iter << " (total " << iteration
+        << " iterations): " << l << std::endl;
+
+    // The trial model is trained.  Is it better than our existing model?
+    if (l > best_l)
     {
-      covariances_trial[i] -= means_trial[i] * trans(means_trial[i]) /
-          weights_trial[i];
+      best_l = l;
 
-      means_trial[i] /= weights_trial[i];
-
-      covariances_trial[i] /= (weights_trial[i] > 1) ? weights_trial[i] : 1;
+      means = means_trial;
+      covariances = covariances_trial;
+      weights = weights_trial;
     }
+  }
 
-    // Finally, normalize weights.
-    weights_trial /= accu(weights_trial);
+  Log::Info << "Log likelihood value of the estimated model: " << best_l << "."
+      << std::endl;
+  return;
+}
 
-    l = Loglikelihood(data, means_trial, covariances_trial, weights_trial);
+/**
+ * Estimate the probability distribution directly from the given observations,
+ * taking into account the probability of each observation actually being from
+ * this distribution.
+ */
+void GMM::Estimate(const arma::mat& observations,
+                   const arma::vec& probabilities)
+{
+  // This will be very similar to Estimate(const arma::mat&), but there will be
+  // minor differences in how we calculate the means, covariances, and weights.
+  std::vector<arma::vec> means_trial;
+  std::vector<arma::mat> covariances_trial;
+  arma::vec weights_trial;
 
+  arma::mat cond_prob(observations.n_cols, gaussians);
+
+  double l, l_old, best_l, TINY = 1.0e-4;
+
+  best_l = -DBL_MAX;
+
+  KMeans<> k; // Default KMeans parameters, for now.
+
+  // We will perform ten trials, and then save the trial with the best result
+  // as our trained model.
+  for (size_t iter = 0; iter < 10; iter++)
+  {
+    InitialClustering(k, observations, means_trial, covariances_trial, weights_trial);
+
+    l = Loglikelihood(observations, means_trial, covariances_trial, weights_trial);
+
     Log::Info << "K-means log-likelihood: " << l << std::endl;
 
     l_old = -DBL_MAX;
@@ -108,13 +171,13 @@
     while (std::abs(l - l_old) > TINY && iteration < max_iterations)
     {
       // Calculate the conditional probabilities of choosing a particular
-      // Gaussian given the data and the present theta value.
+      // Gaussian given the observations and the present theta value.
       for (size_t i = 0; i < gaussians; i++)
       {
         // Store conditional probabilities into cond_prob vector for each
         // Gaussian.  First we make an alias of the cond_prob vector.
         arma::vec cond_prob_alias = cond_prob.unsafe_col(i);
-        phi(data, means_trial[i], covariances_trial[i], cond_prob_alias);
+        phi(observations, means_trial[i], covariances_trial[i], cond_prob_alias);
         cond_prob_alias *= weights_trial[i];
       }
 
@@ -122,32 +185,40 @@
       for (size_t i = 0; i < cond_prob.n_rows; i++)
         cond_prob.row(i) /= accu(cond_prob.row(i));
 
-      // Store the sum of the probability of each state over all the data.
-      arma::vec prob_row_sums = trans(arma::sum(cond_prob, 0 /* columnwise */));
+      // This will store the sum of probabilities of each state over all the
+      // observations.
+      arma::vec prob_row_sums(gaussians);
 
       // Calculate the new value of the means using the updated conditional
       // probabilities.
       for (size_t i = 0; i < gaussians; i++)
       {
-        means_trial[i] = (data * cond_prob.col(i)) / prob_row_sums[i];
+        // Calculate the sum of probabilities of points, which is the
+        // conditional probability of each point being from Gaussian i
+        // multiplied by the probability of the point being from this mixture
+        // model.
+        prob_row_sums[i] = accu(cond_prob.col(i) % probabilities);
 
+        means_trial[i] = (observations * (cond_prob.col(i) % probabilities)) /
+            prob_row_sums[i];
+
         // Calculate the new value of the covariances using the updated
         // conditional probabilities and the updated means.
-        arma::mat tmp = data - (means_trial[i] *
-            arma::ones<arma::rowvec>(data.n_cols));
-        arma::mat tmp_b = tmp % (arma::ones<arma::vec>(data.n_rows) *
-            trans(cond_prob.col(i)));
+        arma::mat tmp = observations - (means_trial[i] *
+            arma::ones<arma::rowvec>(observations.n_cols));
+        arma::mat tmp_b = tmp % (arma::ones<arma::vec>(observations.n_rows) *
+            trans(cond_prob.col(i) % probabilities));
 
         covariances_trial[i] = (tmp * trans(tmp_b)) / prob_row_sums[i];
       }
 
       // Calculate the new values for omega using the updated conditional
       // probabilities.
-      weights_trial = prob_row_sums / data.n_cols;
+      weights_trial = prob_row_sums / accu(probabilities);
 
       // Update values of l; calculate new log-likelihood.
       l_old = l;
-      l = Loglikelihood(data, means_trial, covariances_trial, weights_trial);
+      l = Loglikelihood(observations, means_trial, covariances_trial, weights_trial);
 
       iteration++;
     }
@@ -171,12 +242,12 @@
   return;
 }
 
-long double GMM::Loglikelihood(const arma::mat& data,
-                               const std::vector<arma::vec>& means_l,
-                               const std::vector<arma::mat>& covariances_l,
-                               const arma::vec& weights_l) const
+double GMM::Loglikelihood(const arma::mat& data,
+                          const std::vector<arma::vec>& means_l,
+                          const std::vector<arma::mat>& covariances_l,
+                          const arma::vec& weights_l) const
 {
-  long double loglikelihood = 0;
+  double loglikelihood = 0;
 
   arma::vec phis;
   arma::mat likelihoods(gaussians, data.n_cols);

Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp	2011-12-03 22:20:22 UTC (rev 10531)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp	2011-12-04 00:08:47 UTC (rev 10532)
@@ -27,7 +27,7 @@
  * GMM mog;
  * ArrayList<double> results;
  *
- * mog.Init(number_of_gaussians, dimension);
+ * mog.Init(number_of_gaussians, dimensionality);
  * mog.ExpectationMaximization(data, &results, optim_flag);
  * @endcode
  */
@@ -36,7 +36,7 @@
   //! The number of Gaussians in the model.
   size_t gaussians;
   //! The dimensionality of the model.
-  size_t dimension;
+  size_t dimensionality;
   //! Vector of means; one for each Gaussian.
   std::vector<arma::vec> means;
   //! Vector of covariances; one for each Gaussian.
@@ -46,23 +46,47 @@
 
  public:
   /**
+   * Create an empty Gaussian Mixture Model, with zero gaussians.
+   */
+  GMM() :
+    gaussians(0),
+    dimensionality(0) { /* nothing to do */ }
+
+  /**
    * Create a GMM with the given number of Gaussians, each of which have the
    * specified dimensionality.
    *
    * @param gaussians Number of Gaussians in this GMM.
-   * @param dimension Dimensionality of each Gaussian.
+   * @param dimensionality Dimensionality of each Gaussian.
    */
-  GMM(size_t gaussians, size_t dimension) :
+  GMM(size_t gaussians, size_t dimensionality) :
       gaussians(gaussians),
-      dimension(dimension),
-      means(gaussians),
-      covariances(gaussians) { /* nothing to do */ }
+      dimensionality(dimensionality),
+      means(gaussians, arma::vec(dimensionality)),
+      covariances(gaussians, arma::mat(dimensionality, dimensionality)),
+      weights(gaussians) { /* nothing to do */ }
 
+  /**
+   * Create a GMM with the given means, covariances, and weights.
+   *
+   * @param means Means of the model.
+   * @param covariances Covariances of the model.
+   * @param weights Weights of the model.
+   */
+  GMM(const std::vector<arma::vec>& means,
+      const std::vector<arma::mat>& covariances,
+      const arma::vec& weights) :
+      gaussians(means.size()),
+      dimensionality((means.size() > 0) ? means[0].n_elem : 0),
+      means(means),
+      covariances(covariances),
+      weights(weights) { /* nothing to do */ }
+
   //! Return the number of gaussians in the model.
-  const size_t Gaussians() { return gaussians; }
+  size_t Gaussians() const { return gaussians; }
 
   //! Return the dimensionality of the model.
-  const size_t Dimension() { return dimension; }
+  size_t Dimensionality() const { return dimensionality; }
 
   //! Return a const reference to the vector of means (mu).
   const std::vector<arma::vec>& Means() const { return means; }
@@ -79,7 +103,6 @@
   //! Return a reference to the a priori weights of each Gaussian.
   arma::vec& Weights() { return weights; }
 
-
   /**
    * Return the probability that the given observation came from this
    * distribution.
@@ -97,25 +120,63 @@
   arma::vec Random() const;
 
   /**
-   * This function estimates the parameters of the Gaussian Mixture Model using
-   * the Maximum Likelihood estimator, via the EM (Expectation Maximization)
-   * algorithm.
+   * Estimate the probability distribution directly from the given observations,
+   * using the EM algorithm to obtain the Maximum Likelihood parameter.
+   *
+   * @param observations Observations of the model.
    */
-  void ExpectationMaximization(const arma::mat& data_points);
+  void Estimate(const arma::mat& observations);
 
   /**
-   * This function computes the loglikelihood of model.
-   * This function is used by the 'ExpectationMaximization'
-   * function.
+   * Estimate the probability distribution directly from the given observations,
+   * taking into account the probability of each observation actually being from
+   * this distribution.
    *
+   * @param observations Observations of the model.
+   * @param probabilities Probability of each observation.
    */
-  long double Loglikelihood(const arma::mat& data_points,
-                            const std::vector<arma::vec>& means,
-                            const std::vector<arma::mat>& covars,
-                            const arma::vec& weights) const;
+  void Estimate(const arma::mat& observations,
+                const arma::vec& probabilities);
+
+ private:
+  /**
+   * This function computes the loglikelihood of the given model.  This function
+   * is used by GMM::Estimate().
+   *
+   * @param dataPoints Observations to calculate the likelihood for.
+   * @param means Means of the given mixture model.
+   * @param covars Covariances of the given mixture model.
+   * @param weights Weights of the given mixture model.
+   */
+  double Loglikelihood(const arma::mat& dataPoints,
+                       const std::vector<arma::vec>& means,
+                       const std::vector<arma::mat>& covars,
+                       const arma::vec& weights) const;
+
+  /**
+   * This function uses the given clustering class and initializes means,
+   * covariances, and weights into the passed objects based on the assignments
+   * of the clustering class.
+   *
+   * @param clusterer Initialized clustering class (must implement void
+   *      Cluster(const arma::mat&, size_t, arma::Col<size_t>&)).
+   * @param data Dataset to perform clustering on.
+   * @param means Vector to store means in.
+   * @param covars Vector to store covariances in.
+   * @param weights Vector to store weights in.
+   */
+  template<typename ClusteringType>
+  void InitialClustering(const ClusteringType& clusterer,
+                         const arma::mat& data,
+                         std::vector<arma::vec>& means,
+                         std::vector<arma::mat>& covars,
+                         arma::vec& weights) const;
 };
 
 }; // namespace gmm
 }; // namespace mlpack
 
+// Include implementation.
+#include "gmm_impl.hpp"
+
 #endif

Added: mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp	2011-12-04 00:08:47 UTC (rev 10532)
@@ -0,0 +1,74 @@
+/**
+ * @file gmm_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of template-based GMM methods.
+ */
+#ifndef __MLPACK_METHODS_GMM_GMM_IMPL_HPP
+#define __MLPACK_METHODS_GMM_GMM_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "gmm.hpp"
+
+namespace mlpack {
+namespace gmm {
+
+template<typename ClusteringType>
+void GMM::InitialClustering(const ClusteringType& clusterer,
+                            const arma::mat& data,
+                            std::vector<arma::vec>& meansOut,
+                            std::vector<arma::mat>& covarsOut,
+                            arma::vec& weightsOut) const
+{
+  meansOut.resize(gaussians, arma::vec(dimensionality));
+  covarsOut.resize(gaussians, arma::mat(dimensionality, dimensionality));
+  weightsOut.set_size(gaussians);
+
+  // Cluster the data, then build one mean, covariance, and weight per cluster.
+  arma::Col<size_t> assignments;
+
+  // Run the clustering algorithm, passing the number of Gaussians to it.
+  clusterer.Cluster(data, gaussians, assignments);
+
+  // Clear the accumulators before summing over the assignments.
+  weightsOut.zeros();
+  for (size_t i = 0; i < gaussians; i++)
+  {
+    meansOut[i].zeros();
+    covarsOut[i].zeros();
+  }
+
+  // Accumulate per-cluster sums; assignments[i] is the cluster of data.col(i).
+  for (size_t i = 0; i < data.n_cols; i++)
+  {
+    size_t cluster = assignments[i];
+
+    // Running sum of the points assigned to this cluster.
+    meansOut[cluster] += data.col(i);
+
+    // Running sum of the outer products x * x^T for this cluster.
+    covarsOut[cluster] += data.col(i) * trans(data.col(i));
+
+    // Point count for this cluster (normalized into a weight at the end).
+    weightsOut[cluster]++;
+  }
+
+  // Convert the accumulated sums into per-cluster means and covariances.
+  for (size_t i = 0; i < gaussians; i++)
+  {
+    covarsOut[i] -= meansOut[i] * trans(meansOut[i]) /
+        weightsOut[i];
+    // NOTE(review): an empty cluster has weightsOut[i] == 0 (division by 0).
+    meansOut[i] /= weightsOut[i];
+    // Scale by the point count, unless the cluster has at most one point.
+    covarsOut[i] /= (weightsOut[i] > 1) ? weightsOut[i] : 1;
+  }
+
+  // Finally, rescale the counts so that the weights sum to one.
+  weightsOut /= accu(weightsOut);
+}
+
+}; // namespace gmm
+}; // namespace mlpack
+
+#endif

Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp	2011-12-03 22:20:22 UTC (rev 10531)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp	2011-12-04 00:08:47 UTC (rev 10532)
@@ -29,7 +29,7 @@
 
   ////// Computing the parameters of the model using the EM algorithm //////
   Timers::StartTimer("gmm/em");
-  gmm.ExpectationMaximization(data_points);
+  gmm.Estimate(data_points);
   Timers::StopTimer("gmm/em");
 
   ////// OUTPUT RESULTS //////




More information about the mlpack-svn mailing list