[mlpack-svn] r12283 - mlpack/trunk/src/mlpack/methods/gmm

fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Apr 10 14:47:01 EDT 2012


Author: rcurtin
Date: 2012-04-10 14:47:01 -0400 (Tue, 10 Apr 2012)
New Revision: 12283

Added:
   mlpack/trunk/src/mlpack/methods/gmm/em_fit.hpp
   mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp
Removed:
   mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp
Modified:
   mlpack/trunk/src/mlpack/methods/gmm/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
   mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp
   mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
Log:
Refactor GMMs so that you can use any clustering method you like.  The default
is EMFit, which implements the EM algorithm (behaving exactly as it did
before).
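
For example (a minimal sketch; 'data' stands in for an arma::mat of
observations), the default behavior is unchanged, and the clustering
mechanism used by the fit is now a template parameter:

GMM<> g(5, 4);    // Uses EMFit<>, which defaults to kmeans::KMeans<>.
g.Estimate(data);

// Equivalent, with the defaults spelled out:
GMM<EMFit<kmeans::KMeans<> > > g2(5, 4);
g2.Estimate(data);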


Modified: mlpack/trunk/src/mlpack/methods/gmm/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/CMakeLists.txt	2012-04-10 18:07:37 UTC (rev 12282)
+++ mlpack/trunk/src/mlpack/methods/gmm/CMakeLists.txt	2012-04-10 18:47:01 UTC (rev 12283)
@@ -5,8 +5,9 @@
 set(SOURCES
   gmm.hpp
   gmm_impl.hpp
-  gmm.cpp
   phi.hpp
+  em_fit.hpp
+  em_fit_impl.hpp
 )
 
 # Add directory name to sources.

Added: mlpack/trunk/src/mlpack/methods/gmm/em_fit.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/em_fit.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/gmm/em_fit.hpp	2012-04-10 18:47:01 UTC (rev 12283)
@@ -0,0 +1,116 @@
+/**
+ * @file em_fit.hpp
+ * @author Ryan Curtin
+ *
+ * Utility class to fit a GMM using the EM algorithm.  Used by
+ * GMM<>::Estimate().
+ */
+#ifndef __MLPACK_METHODS_GMM_EM_FIT_HPP
+#define __MLPACK_METHODS_GMM_EM_FIT_HPP
+
+#include <mlpack/core.hpp>
+
+// Default clustering mechanism.
+#include <mlpack/methods/kmeans/kmeans.hpp>
+
+namespace mlpack {
+namespace gmm {
+
+/**
+ * This class contains methods which can fit a GMM to observations using the EM
+ * algorithm.  It requires an initial clustering mechanism, which is by default
+ * the KMeans algorithm.  The clustering mechanism must implement the following
+ * method:
+ *
+ *  - void Cluster(const arma::mat& observations,
+ *                 const size_t clusters,
+ *                 arma::Col<size_t>& assignments);
+ *
+ * This method should create 'clusters' clusters, and return the assignment of
+ * each point to a cluster.
+ */
+template<typename InitialClusteringType = kmeans::KMeans<> >
+class EMFit
+{
+ public:
+  /**
+   * Construct the EMFit object, optionally passing an InitialClusteringType
+   * object (just in case it needs to store state).
+   */
+  EMFit(InitialClusteringType clusterer = InitialClusteringType()) :
+      clusterer(clusterer) { /* Nothing to do. */ }
+
+  /**
+   * Fit the observations to a Gaussian mixture model (GMM) using the EM
+   * algorithm.  The size of the vectors (indicating the number of components)
+   * must already be set.
+   *
+   * @param observations List of observations to train on.
+   * @param means Vector to store trained means in.
+   * @param covariances Vector to store trained covariances in.
+   * @param weights Vector to store a priori weights in.
+   */
+  void Estimate(const arma::mat& observations,
+                std::vector<arma::vec>& means,
+                std::vector<arma::mat>& covariances,
+                arma::vec& weights);
+
+  /**
+   * Fit the observations to a Gaussian mixture model (GMM) using the EM
+   * algorithm, taking into account the probabilities of each point being from
+   * this mixture.  The size of the vectors (indicating the number of
+   * components) must already be set.
+   *
+   * @param observations List of observations to train on.
+   * @param probabilities Probability of each point being from this model.
+   * @param means Vector to store trained means in.
+   * @param covariances Vector to store trained covariances in.
+   * @param weights Vector to store a priori weights in.
+   */
+  void Estimate(const arma::mat& observations,
+                const arma::vec& probabilities,
+                std::vector<arma::vec>& means,
+                std::vector<arma::mat>& covariances,
+                arma::vec& weights);
+
+ private:
+  /**
+   * Run the clusterer, and then turn the cluster assignments into Gaussians.
+   * This is a helper function for both overloads of Estimate().  The vectors
+   * must already be sized to the number of clusters.
+   *
+   * @param observations List of observations.
+   * @param means Vector to store means in.
+   * @param covariances Vector to store covariances in.
+   * @param weights Vector to store a priori weights in.
+   */
+  void InitialClustering(const arma::mat& observations,
+                         std::vector<arma::vec>& means,
+                         std::vector<arma::mat>& covariances,
+                         arma::vec& weights);
+
+  /**
+   * Calculate the log-likelihood of a model.  Yes, this is reimplemented in the
+   * GMM code.  Intuition suggests that the log-likelihood is not the best way
+   * to determine if the EM algorithm has converged.
+   *
+   * @param data Data matrix.
+   * @param means Vector of means.
+   * @param covariances Vector of covariance matrices.
+   * @param weights Vector of a priori weights.
+   */
+  double LogLikelihood(const arma::mat& data,
+                       const std::vector<arma::vec>& means,
+                       const std::vector<arma::mat>& covariances,
+                       const arma::vec& weights) const;
+
+  InitialClusteringType clusterer;
+};
+
+}; // namespace gmm
+}; // namespace mlpack
+
+// Include implementation.
+#include "em_fit_impl.hpp"
+
+#endif
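
Below is a sketch of a clusterer satisfying the interface EMFit documents
above; RoundRobinClustering is hypothetical (not part of mlpack) and assigns
points round-robin purely for illustration:

#include <mlpack/core.hpp>
#include <mlpack/methods/gmm/em_fit.hpp>

// A minimal InitialClusteringType: implements the one method EMFit calls.
class RoundRobinClustering
{
 public:
  void Cluster(const arma::mat& observations,
               const size_t clusters,
               arma::Col<size_t>& assignments) const
  {
    assignments.set_size(observations.n_cols);
    for (size_t i = 0; i < observations.n_cols; ++i)
      assignments[i] = i % clusters; // Deterministic, evenly-sized clusters.
  }
};

// Plug it in as the initial clustering mechanism for EM.
mlpack::gmm::EMFit<RoundRobinClustering> fitter;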

Added: mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp	2012-04-10 18:47:01 UTC (rev 12283)
@@ -0,0 +1,238 @@
+/**
+ * @file em_fit_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of EM algorithm for fitting GMMs.
+ */
+#ifndef __MLPACK_METHODS_GMM_EM_FIT_IMPL_HPP
+#define __MLPACK_METHODS_GMM_EM_FIT_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "em_fit.hpp"
+
+// Definition of phi().
+#include "phi.hpp"
+
+namespace mlpack {
+namespace gmm {
+
+template<typename InitialClusteringType>
+void EMFit<InitialClusteringType>::Estimate(const arma::mat& observations,
+                                            std::vector<arma::vec>& means,
+                                            std::vector<arma::mat>& covariances,
+                                            arma::vec& weights)
+{
+  InitialClustering(observations, means, covariances, weights);
+
+  double l = LogLikelihood(observations, means, covariances, weights);
+
+  Log::Debug << "EMFit::Estimate(): initial clustering log-likelihood: "
+      << l << std::endl;
+
+  double lOld = -DBL_MAX;
+  arma::mat condProb(observations.n_cols, means.size());
+
+  // Iterate to update the model until no more improvement is found.
+  size_t maxIterations = 300;
+  size_t iteration = 0;
+  while (std::abs(l - lOld) > 1e-10 && iteration < maxIterations)
+  {
+    // Calculate the conditional probabilities of choosing a particular
+    // Gaussian given the observations and the present theta value.
+    for (size_t i = 0; i < means.size(); i++)
+    {
+      // Store conditional probabilities into the condProb matrix, one column
+      // per Gaussian.  First we make an alias of the relevant column.
+      arma::vec condProbAlias = condProb.unsafe_col(i);
+      phi(observations, means[i], covariances[i], condProbAlias);
+      condProbAlias *= weights[i];
+    }
+
+    // Normalize row-wise.
+    for (size_t i = 0; i < condProb.n_rows; i++)
+      condProb.row(i) /= accu(condProb.row(i));
+
+    // Store the sum of the probability of each state over all the observations.
+    arma::vec probRowSums = trans(arma::sum(condProb, 0 /* columnwise */));
+
+    // Calculate the new value of the means using the updated conditional
+    // probabilities.
+    for (size_t i = 0; i < means.size(); i++)
+    {
+      means[i] = (observations * condProb.col(i)) / probRowSums[i];
+
+      // Calculate the new value of the covariances using the updated
+      // conditional probabilities and the updated means.
+      arma::mat tmp = observations - (means[i] *
+          arma::ones<arma::rowvec>(observations.n_cols));
+      arma::mat tmp_b = tmp % (arma::ones<arma::vec>(observations.n_rows) *
+          trans(condProb.col(i)));
+
+      covariances[i] = (tmp * trans(tmp_b)) / probRowSums[i];
+    }
+
+    // Calculate the new values for omega using the updated conditional
+    // probabilities.
+    weights = probRowSums / observations.n_cols;
+
+    // Update values of l; calculate new log-likelihood.
+    lOld = l;
+    l = LogLikelihood(observations, means, covariances, weights);
+
+    iteration++;
+  }
+}
+
+template<typename InitialClusteringType>
+void EMFit<InitialClusteringType>::Estimate(const arma::mat& observations,
+                                            const arma::vec& probabilities,
+                                            std::vector<arma::vec>& means,
+                                            std::vector<arma::mat>& covariances,
+                                            arma::vec& weights)
+{
+  InitialClustering(observations, means, covariances, weights);
+
+  double l = LogLikelihood(observations, means, covariances, weights);
+
+  Log::Debug << "EMFit::Estimate(): initial clustering log-likelihood: "
+      << l << std::endl;
+
+  double lOld = -DBL_MAX;
+  arma::mat condProb(observations.n_cols, means.size());
+
+  // Iterate to update the model until no more improvement is found.
+  size_t maxIterations = 300;
+  size_t iteration = 0;
+  while (std::abs(l - lOld) > 1e-10 && iteration < maxIterations)
+  {
+    // Calculate the conditional probabilities of choosing a particular
+    // Gaussian given the observations and the present theta value.
+    for (size_t i = 0; i < means.size(); i++)
+    {
+      // Store conditional probabilities into the condProb matrix, one column
+      // per Gaussian.  First we make an alias of the relevant column.
+      arma::vec condProbAlias = condProb.unsafe_col(i);
+      phi(observations, means[i], covariances[i], condProbAlias);
+      condProbAlias *= weights[i];
+    }
+
+    // Normalize row-wise.
+    for (size_t i = 0; i < condProb.n_rows; i++)
+      condProb.row(i) /= accu(condProb.row(i));
+
+    // This will store the sum of probabilities of each state over all the
+    // observations.
+    arma::vec probRowSums(means.size());
+
+    // Calculate the new value of the means using the updated conditional
+    // probabilities.
+    for (size_t i = 0; i < means.size(); i++)
+    {
+      // Calculate the sum of probabilities of points, which is the
+      // conditional probability of each point being from Gaussian i
+      // multiplied by the probability of the point being from this mixture
+      // model.
+      probRowSums[i] = accu(condProb.col(i) % probabilities);
+
+      means[i] = (observations * (condProb.col(i) % probabilities)) /
+        probRowSums[i];
+
+      // Calculate the new value of the covariances using the updated
+      // conditional probabilities and the updated means.
+      arma::mat tmp = observations - (means[i] *
+          arma::ones<arma::rowvec>(observations.n_cols));
+      arma::mat tmp_b = tmp % (arma::ones<arma::vec>(observations.n_rows) *
+          trans(condProb.col(i) % probabilities));
+
+      covariances[i] = (tmp * trans(tmp_b)) / probRowSums[i];
+    }
+
+    // Calculate the new values for omega using the updated conditional
+    // probabilities.
+    weights = probRowSums / accu(probabilities);
+
+    // Update values of l; calculate new log-likelihood.
+    lOld = l;
+    l = LogLikelihood(observations, means, covariances, weights);
+
+    iteration++;
+  }
+}
+
+template<typename InitialClusteringType>
+void EMFit<InitialClusteringType>::InitialClustering(
+    const arma::mat& observations,
+    std::vector<arma::vec>& means,
+    std::vector<arma::mat>& covariances,
+    arma::vec& weights)
+{
+  // Assignments from clustering.
+  arma::Col<size_t> assignments;
+
+  // Run clustering algorithm.
+  clusterer.Cluster(observations, means.size(), assignments);
+
+  // Now calculate the means, covariances, and weights.
+  weights.zeros();
+  for (size_t i = 0; i < means.size(); ++i)
+  {
+    means[i].zeros();
+    covariances[i].zeros();
+  }
+
+  // From the assignments, generate our means, covariances, and weights.
+  for (size_t i = 0; i < observations.n_cols; ++i)
+  {
+    size_t cluster = assignments[i];
+
+    // Add this to the relevant mean.
+    means[cluster] += observations.col(i);
+
+    // Add this to the relevant covariance.
+    covariances[cluster] += observations.col(i) * trans(observations.col(i));
+
+    // Now add one to the weights (we will normalize).
+    weights[cluster]++;
+  }
+
+  // Now normalize the mean and covariance.
+  for (size_t i = 0; i < means.size(); ++i)
+  {
+    covariances[i] -= means[i] * trans(means[i]) / weights[i];
+
+    means[i] /= weights[i];
+    covariances[i] /= (weights[i] > 1) ? weights[i] : 1;
+  }
+
+  // Finally, normalize weights.
+  weights /= accu(weights);
+}
+
+template<typename InitialClusteringType>
+double EMFit<InitialClusteringType>::LogLikelihood(
+    const arma::mat& observations,
+    const std::vector<arma::vec>& means,
+    const std::vector<arma::mat>& covariances,
+    const arma::vec& weights) const
+{
+  double logLikelihood = 0;
+
+  arma::vec phis;
+  arma::mat likelihoods(means.size(), observations.n_cols);
+  for (size_t i = 0; i < means.size(); ++i)
+  {
+    phi(observations, means[i], covariances[i], phis);
+    likelihoods.row(i) = weights(i) * trans(phis);
+  }
+
+  // Now sum over every point.
+  for (size_t j = 0; j < observations.n_cols; ++j)
+    logLikelihood += log(accu(likelihoods.col(j)));
+
+  return logLikelihood;
+}
+
+}; // namespace gmm
+}; // namespace mlpack
+
+#endif
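
When calling EMFit directly (rather than through GMM, whose constructor
pre-sizes the model), the means, covariances, and weights must be sized
beforehand.  A minimal sketch under assumed sizes (3 components,
2-dimensional data); FitDirectly is an illustrative name:

#include <mlpack/methods/gmm/em_fit.hpp>

void FitDirectly(const arma::mat& observations) // 2 x N matrix of points.
{
  // Pre-size for 3 Gaussians in 2 dimensions; Estimate() expects this.
  std::vector<arma::vec> means(3, arma::vec(2));
  std::vector<arma::mat> covariances(3, arma::mat(2, 2));
  arma::vec weights(3);

  mlpack::gmm::EMFit<> fitter;
  fitter.Estimate(observations, means, covariances, weights);
}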

Deleted: mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp	2012-04-10 18:07:37 UTC (rev 12282)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp	2012-04-10 18:47:01 UTC (rev 12283)
@@ -1,327 +0,0 @@
-/**
- * @file gmm.cpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- * @author Ryan Curtin
- *
- * Implementation for the loglikelihood function, the EM algorithm
- * and also computes the K-means for getting an initial point
- */
-#include "gmm.hpp"
-#include "phi.hpp"
-
-#include <mlpack/methods/kmeans/kmeans.hpp>
-
-using namespace mlpack;
-using namespace mlpack::gmm;
-using namespace mlpack::kmeans;
-
-/**
- * Return the probability of the given observation being from this GMM.
- */
-double GMM::Probability(const arma::vec& observation) const
-{
-  // Sum the probability for each Gaussian in our mixture (and we have to
-  // multiply by the prior for each Gaussian too).
-  double sum = 0;
-  for (size_t i = 0; i < gaussians; i++)
-    sum += weights[i] * phi(observation, means[i], covariances[i]);
-
-  return sum;
-}
-
-/**
- * Return the probability of the given observation being from the given
- * component in the mixture.
- */
-double GMM::Probability(const arma::vec& observation,
-                        const size_t component) const
-{
-  // We are only considering one Gaussian component -- so we only need to call
-  // phi() once.  We do consider the prior probability!
-  return weights[component] *
-      phi(observation, means[component], covariances[component]);
-}
-
-/**
- * Return a randomly generated observation according to the probability
- * distribution defined by this object.
- */
-arma::vec GMM::Random() const
-{
-  // Determine which Gaussian it will be coming from.
-  double gaussRand = math::Random();
-  size_t gaussian;
-
-  double sumProb = 0;
-  for (size_t g = 0; g < gaussians; g++)
-  {
-    sumProb += weights(g);
-    if (gaussRand <= sumProb)
-    {
-      gaussian = g;
-      break;
-    }
-  }
-
-  return trans(chol(covariances[gaussian])) *
-      arma::randn<arma::vec>(dimensionality) + means[gaussian];
-}
-
-void GMM::Estimate(const arma::mat& data)
-{
-  // Create temporary models and set to the right size.
-  std::vector<arma::vec> meansTrial;
-  std::vector<arma::mat> covariancesTrial;
-  arma::vec weightsTrial;
-
-  arma::mat condProb(data.n_cols, gaussians);
-
-  double l, lOld, bestL, TINY = 1.0e-4;
-
-  bestL = -DBL_MAX;
-
-  KMeans<> k; // Default KMeans parameters, for now.
-
-  // We will perform ten trials, and then save the trial with the best result
-  // as our trained model.
-  for (size_t iter = 0; iter < 10; iter++)
-  {
-    InitialClustering(k, data, meansTrial, covariancesTrial, weightsTrial);
-
-    l = Loglikelihood(data, meansTrial, covariancesTrial, weightsTrial);
-
-    Log::Info << "K-means log-likelihood: " << l << std::endl;
-
-    lOld = -DBL_MAX;
-
-    // Iterate to update the model until no more improvement is found.
-    size_t maxIterations = 300;
-    size_t iteration = 0;
-    while (std::abs(l - lOld) > TINY && iteration < maxIterations)
-    {
-      // Calculate the conditional probabilities of choosing a particular
-      // Gaussian given the data and the present theta value.
-      for (size_t i = 0; i < gaussians; i++)
-      {
-        // Store conditional probabilities into condProb vector for each
-        // Gaussian.  First we make an alias of the condProb vector.
-        arma::vec condProbAlias = condProb.unsafe_col(i);
-        phi(data, meansTrial[i], covariancesTrial[i], condProbAlias);
-        condProbAlias *= weightsTrial[i];
-      }
-
-      // Normalize row-wise.
-      for (size_t i = 0; i < condProb.n_rows; i++)
-        condProb.row(i) /= accu(condProb.row(i));
-
-      // Store the sum of the probability of each state over all the data.
-      arma::vec probRowSums = trans(arma::sum(condProb, 0 /* columnwise */));
-
-      // Calculate the new value of the means using the updated conditional
-      // probabilities.
-      for (size_t i = 0; i < gaussians; i++)
-      {
-        meansTrial[i] = (data * condProb.col(i)) / probRowSums[i];
-
-        // Calculate the new value of the covariances using the updated
-        // conditional probabilities and the updated means.
-        arma::mat tmp = data - (meansTrial[i] *
-            arma::ones<arma::rowvec>(data.n_cols));
-        arma::mat tmp_b = tmp % (arma::ones<arma::vec>(data.n_rows) *
-            trans(condProb.col(i)));
-
-        covariancesTrial[i] = (tmp * trans(tmp_b)) / probRowSums[i];
-      }
-
-      // Calculate the new values for omega using the updated conditional
-      // probabilities.
-      weightsTrial = probRowSums / data.n_cols;
-
-      // Update values of l; calculate new log-likelihood.
-      lOld = l;
-      l = Loglikelihood(data, meansTrial, covariancesTrial, weightsTrial);
-
-      iteration++;
-    }
-
-    Log::Info << "Likelihood of iteration " << iter << " (total " << iteration
-        << " iterations): " << l << std::endl;
-
-    // The trial model is trained.  Is it better than our existing model?
-    if (l > bestL)
-    {
-      bestL = l;
-
-      means = meansTrial;
-      covariances = covariancesTrial;
-      weights = weightsTrial;
-    }
-  }
-
-  Log::Info << "Log likelihood value of the estimated model: " << bestL << "."
-      << std::endl;
-  return;
-}
-
-/**
- * Estimate the probability distribution directly from the given observations,
- * taking into account the probability of each observation actually being from
- * this distribution.
- */
-void GMM::Estimate(const arma::mat& observations,
-                   const arma::vec& probabilities)
-{
-  // This will be very similar to Estimate(const arma::mat&), but there will be
-  // minor differences in how we calculate the means, covariances, and weights.
-  std::vector<arma::vec> meansTrial;
-  std::vector<arma::mat> covariancesTrial;
-  arma::vec weightsTrial;
-
-  arma::mat condProb(observations.n_cols, gaussians);
-
-  double l, lOld, bestL, TINY = 1.0e-4;
-
-  bestL = -DBL_MAX;
-
-  KMeans<> k; // Default KMeans parameters, for now.
-
-  // We will perform ten trials, and then save the trial with the best result
-  // as our trained model.
-  for (size_t iter = 0; iter < 10; iter++)
-  {
-    InitialClustering(k, observations, meansTrial, covariancesTrial,
-        weightsTrial);
-
-    l = Loglikelihood(observations, meansTrial, covariancesTrial, weightsTrial);
-
-    Log::Info << "K-means log-likelihood: " << l << std::endl;
-
-    lOld = -DBL_MAX;
-
-    // Iterate to update the model until no more improvement is found.
-    size_t maxIterations = 300;
-    size_t iteration = 0;
-    while (std::abs(l - lOld) > TINY && iteration < maxIterations)
-    {
-      // Calculate the conditional probabilities of choosing a particular
-      // Gaussian given the observations and the present theta value.
-      for (size_t i = 0; i < gaussians; i++)
-      {
-        // Store conditional probabilities into condProb vector for each
-        // Gaussian.  First we make an alias of the condProb vector.
-        arma::vec condProbAlias = condProb.unsafe_col(i);
-        phi(observations, meansTrial[i], covariancesTrial[i], condProbAlias);
-        condProbAlias *= weightsTrial[i];
-      }
-
-      // Normalize row-wise.
-      for (size_t i = 0; i < condProb.n_rows; i++)
-        condProb.row(i) /= accu(condProb.row(i));
-
-      // This will store the sum of probabilities of each state over all the
-      // observations.
-      arma::vec probRowSums(gaussians);
-
-      // Calculate the new value of the means using the updated conditional
-      // probabilities.
-      for (size_t i = 0; i < gaussians; i++)
-      {
-        // Calculate the sum of probabilities of points, which is the
-        // conditional probability of each point being from Gaussian i
-        // multiplied by the probability of the point being from this mixture
-        // model.
-        probRowSums[i] = accu(condProb.col(i) % probabilities);
-
-        meansTrial[i] = (observations * (condProb.col(i) % probabilities)) /
-            probRowSums[i];
-
-        // Calculate the new value of the covariances using the updated
-        // conditional probabilities and the updated means.
-        arma::mat tmp = observations - (meansTrial[i] *
-            arma::ones<arma::rowvec>(observations.n_cols));
-        arma::mat tmp_b = tmp % (arma::ones<arma::vec>(observations.n_rows) *
-            trans(condProb.col(i) % probabilities));
-
-        covariancesTrial[i] = (tmp * trans(tmp_b)) / probRowSums[i];
-      }
-
-      // Calculate the new values for omega using the updated conditional
-      // probabilities.
-      weightsTrial = probRowSums / accu(probabilities);
-
-      // Update values of l; calculate new log-likelihood.
-      lOld = l;
-      l = Loglikelihood(observations, meansTrial,
-                        covariancesTrial, weightsTrial);
-
-      iteration++;
-    }
-
-    Log::Info << "Likelihood of iteration " << iter << " (total " << iteration
-        << " iterations): " << l << std::endl;
-
-    // The trial model is trained.  Is it better than our existing model?
-    if (l > bestL)
-    {
-      bestL = l;
-
-      means = meansTrial;
-      covariances = covariancesTrial;
-      weights = weightsTrial;
-    }
-  }
-
-  Log::Info << "Log likelihood value of the estimated model: " << bestL << "."
-      << std::endl;
-  return;
-}
-
-/**
- * Classify the given observations as being from an individual component in this
- * GMM.
- */
-void GMM::Classify(const arma::mat& observations,
-                   arma::Col<size_t>& labels) const
-{
-  // This is not the best way to do this!
-
-  // We should not have to fill this with values, because each one should be
-  // overwritten.
-  labels.set_size(observations.n_cols);
-  for (size_t i = 0; i < observations.n_cols; ++i)
-  {
-    // Find maximum probability component.
-    double probability = 0;
-    for (size_t j = 0; j < gaussians; ++j)
-    {
-      double newProb = Probability(observations.unsafe_col(i), j);
-      if (newProb >= probability)
-      {
-        probability = newProb;
-        labels[i] = j;
-      }
-    }
-  }
-}
-
-double GMM::Loglikelihood(const arma::mat& data,
-                          const std::vector<arma::vec>& meansL,
-                          const std::vector<arma::mat>& covariancesL,
-                          const arma::vec& weightsL) const
-{
-  double loglikelihood = 0;
-
-  arma::vec phis;
-  arma::mat likelihoods(gaussians, data.n_cols);
-  for (size_t i = 0; i < gaussians; i++)
-  {
-    phi(data, meansL[i], covariancesL[i], phis);
-    likelihoods.row(i) = weightsL(i) * trans(phis);
-  }
-
-  // Now sum over every point.
-  for (size_t j = 0; j < data.n_cols; j++)
-    loglikelihood += log(accu(likelihoods.col(j)));
-
-  return loglikelihood;
-}

Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp	2012-04-10 18:07:37 UTC (rev 12282)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp	2012-04-10 18:47:01 UTC (rev 12283)
@@ -10,14 +10,44 @@
 
 #include <mlpack/core.hpp>
 
+// This is the default fitting method class.
+#include "em_fit.hpp"
+
 namespace mlpack {
 namespace gmm /** Gaussian Mixture Models. */ {
 
 /**
  * A Gaussian Mixture Model (GMM). This class uses maximum likelihood loss
- * functions to estimate the parameters of the GMM on a given dataset via the EM
- * algorithm.  The GMM can be trained either with labeled or unlabeled data.
+ * functions to estimate the parameters of the GMM on a given dataset via the
+ * given fitting mechanism, defined by the FittingType template parameter.  The
+ * GMM can be trained using normal data, or data with probabilities of being
+ * from this GMM (see GMM::Estimate() for more information).
  *
+ * The FittingType template class must provide a way for the GMM to train on
+ * data.  It must provide the following two functions:
+ *
+ * @code
+ * void Estimate(const arma::mat& observations,
+ *               std::vector<arma::vec>& means,
+ *               std::vector<arma::mat>& covariances,
+ *               arma::vec& weights);
+ *
+ * void Estimate(const arma::mat& observations,
+ *               const arma::vec& probabilities,
+ *               std::vector<arma::vec>& means,
+ *               std::vector<arma::mat>& covariances,
+ *               arma::vec& weights);
+ * @endcode
+ *
+ * These functions should produce a trained GMM from the given observations and
+ * probabilities.  These may modify the size of the model (by increasing the
+ * size of the mean and covariance vectors as well as the weight vectors), but
+ * the method should expect that these vectors are already set to the size of
+ * the GMM as specified in the constructor.
+ *
+ * For a sample implementation, see the EMFit class; this class uses the EM
+ * algorithm to train a GMM, and is the default fitting type.
+ *
  * The GMM, once trained, can be used to generate random points from the
  * distribution and estimate the probability of points being from the
  * distribution.  The parameters of the GMM can be obtained through the
@@ -26,8 +56,9 @@
  * Example use:
  *
  * @code
- * // Set up a mixture of 5 gaussians in a 4-dimensional space.
- * GMM g(5, 4);
+ * // Set up a mixture of 5 gaussians in a 4-dimensional space (uses the default
+ * // EM fitting mechanism).
+ * GMM<> g(5, 4);
  *
  * // Train the GMM given the data observations.
  * g.Estimate(data);
@@ -39,6 +70,7 @@
  * arma::vec observation = g.Random();
  * @endcode
  */
+template<typename FittingType = EMFit<> >
 class GMM
 {
  private:
@@ -57,12 +89,16 @@
   /**
    * Create an empty Gaussian Mixture Model, with zero gaussians.
    */
-  GMM() : gaussians(0), dimensionality(0)
+  GMM() :
+      gaussians(0),
+      dimensionality(0),
+      localFitter(FittingType()),
+      fitter(localFitter)
   {
     // Warn the user.  They probably don't want to do this.  If this constructor
     // is being used (because it is required by some template classes), the user
     // should know that it is potentially dangerous.
-    Log::Debug << "GMM::GMM(): no parameters given; Estimate() will fail "
+    Log::Debug << "GMM::GMM(): no parameters given; Estimate() may fail "
         << "unless parameters are set." << std::endl;
   }
 
@@ -73,14 +109,36 @@
    * @param gaussians Number of Gaussians in this GMM.
    * @param dimensionality Dimensionality of each Gaussian.
    */
-  GMM(size_t gaussians, size_t dimensionality) :
+  GMM(const size_t gaussians, const size_t dimensionality) :
       gaussians(gaussians),
       dimensionality(dimensionality),
       means(gaussians, arma::vec(dimensionality)),
       covariances(gaussians, arma::mat(dimensionality, dimensionality)),
-      weights(gaussians) { /* nothing to do */ }
+      weights(gaussians),
+      localFitter(FittingType()),
+      fitter(localFitter) { /* Nothing to do. */ }
 
   /**
+   * Create a GMM with the given number of Gaussians, each of which have the
+   * specified dimensionality.  Also, pass in an initialized FittingType class;
+   * this is useful in cases where the FittingType class needs to store some
+   * state.
+   *
+   * @param gaussians Number of Gaussians in this GMM.
+   * @param dimensionality Dimensionality of each Gaussian.
+   * @param fitter Initialized fitting mechanism.
+   */
+  GMM(const size_t gaussians,
+      const size_t dimensionality,
+      FittingType& fitter) :
+      gaussians(gaussians),
+      dimensionality(dimensionality),
+      means(gaussians, arma::vec(dimensionality)),
+      covariances(gaussians, arma::mat(dimensionality, dimensionality)),
+      weights(gaussians),
+      fitter(fitter) { /* Nothing to do. */ }
+
+  /**
    * Create a GMM with the given means, covariances, and weights.
    *
    * @param means Means of the model.
@@ -96,11 +154,41 @@
       covariances(covariances),
       weights(weights) { /* nothing to do */ }
 
+  /**
+   * Copy constructor for GMMs which use different fitting types.
+   */
+  template<typename OtherFittingType>
+  GMM(const GMM<OtherFittingType>& other);
+
+  /**
+   * Copy constructor for GMMs using the same fitting type.  This also copies
+   * the fitter.
+   */
+  GMM(const GMM& other);
+
+  /**
+   * Copy operator for GMMs which use different fitting types.
+   */
+  template<typename OtherFittingType>
+  GMM& operator=(const GMM<OtherFittingType>& other);
+
+  /**
+   * Copy operator for GMMs which use the same fitting type.  This also copies
+   * the fitter.
+   */
+  GMM& operator=(const GMM& other);
+
   //! Return the number of gaussians in the model.
   size_t Gaussians() const { return gaussians; }
+  //! Modify the number of gaussians in the model.  Careful!  You will have to
+  //! resize the means, covariances, and weights yourself.
+  size_t& Gaussians() { return gaussians; }
 
   //! Return the dimensionality of the model.
   size_t Dimensionality() const { return dimensionality; }
+  //! Modify the dimensionality of the model.  Careful!  You will have to update
+  //! each mean and covariance matrix yourself.
+  size_t& Dimensionality() { return dimensionality; }
 
   //! Return a const reference to the vector of means (mu).
   const std::vector<arma::vec>& Means() const { return means; }
@@ -117,6 +205,11 @@
   //! Return a reference to the a priori weights of each Gaussian.
   arma::vec& Weights() { return weights; }
 
+  //! Return a const reference to the fitting type.
+  const FittingType& Fitter() const { return fitter; }
+  //! Return a reference to the fitting type.
+  FittingType& Fitter() { return fitter; }
+
   /**
    * Return the probability that the given observation came from this
    * distribution.
@@ -145,22 +238,40 @@
 
   /**
    * Estimate the probability distribution directly from the given observations,
-   * using the EM algorithm to obtain the Maximum Likelihood parameter.
+   * using the given algorithm in the FittingType class to fit the data.
    *
+   * The fitting will be performed 'trials' times; from these trials, the model
+   * with the greatest log-likelihood will be selected.  By default, only one
+   * trial is performed.
+   *
+   * The FittingType class template parameter of the GMM determines how the
+   * fitting is performed; EMFit<> is the default.
    * @param observations Observations of the model.
+   * @param trials Number of trials to perform; the model in these trials with
+   *      the greatest log-likelihood will be selected.
    */
-  void Estimate(const arma::mat& observations);
+  void Estimate(const arma::mat& observations,
+                const size_t trials = 1);
 
   /**
    * Estimate the probability distribution directly from the given observations,
    * taking into account the probability of each observation actually being from
-   * this distribution.
+   * this distribution, and using the given algorithm in the FittingType class
+   * to fit the data.
    *
+   * The fitting will be performed 'trials' times; from these trials, the model
+   * with the greatest log-likelihood will be selected.  By default, only one
+   * trial is performed.
+   *
    * @param observations Observations of the model.
-   * @param probability Probability of each observation.
+   * @param probabilities Probability of each observation being from this
+   *     distribution.
+   * @param trials Number of trials to perform; the model in these trials with
+   *     the greatest log-likelihood will be selected.
    */
   void Estimate(const arma::mat& observations,
-                const arma::vec& probabilities);
+                const arma::vec& probabilities,
+                const size_t trials = 1);
 
   /**
    * Classify the given observations as being from an individual component in
@@ -191,29 +302,16 @@
    * @param covars Covariances of the given mixture model.
    * @param weights Weights of the given mixture model.
    */
-  double Loglikelihood(const arma::mat& dataPoints,
+  double LogLikelihood(const arma::mat& dataPoints,
                        const std::vector<arma::vec>& means,
                        const std::vector<arma::mat>& covars,
                        const arma::vec& weights) const;
 
-  /**
-   * This function uses the given clustering class and initializes means,
-   * covariances, and weights into the passed objects based on the assignments
-   * of the clustering class.
-   *
-   * @param clusterer Initialized clustering class (must implement void
-   *      Cluster(const arma::mat&, arma::Col<size_t>&)
-   * @param data Dataset to perform clustering on.
-   * @param means Vector to store means in.
-   * @param covars Vector to store covariances in.
-   * @param weights Vector to store weights in.
-   */
-  template<typename ClusteringType>
-  void InitialClustering(const ClusteringType& clusterer,
-                         const arma::mat& data,
-                         std::vector<arma::vec>& means,
-                         std::vector<arma::mat>& covars,
-                         arma::vec& weights) const;
+  //! Locally-stored fitting object; in case the user did not pass one.
+  FittingType localFitter;
+
+  //! Reference to the fitting object we should use.
+  FittingType& fitter;
 };
 
 }; // namespace gmm
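
A sketch of the new three-argument constructor and the 'trials' parameter
(TrainBestOfTen is an illustrative name; 'data' is assumed to be the
observation matrix):

#include <mlpack/methods/gmm/gmm.hpp>

using namespace mlpack::gmm;

void TrainBestOfTen(const arma::mat& data)
{
  // Pass in an initialized fitter; useful if the FittingType holds state.
  EMFit<> fitter;
  GMM<EMFit<> > gmm(5, 4, fitter); // 5 gaussians, 4 dimensions.

  // Ten independent fits; the trial with the greatest log-likelihood is
  // kept as the trained model.
  gmm.Estimate(data, 10);
}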

Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp	2012-04-10 18:07:37 UTC (rev 12282)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp	2012-04-10 18:47:01 UTC (rev 12283)
@@ -1,5 +1,6 @@
 /**
  * @file gmm_impl.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
  * @author Ryan Curtin
  *
  * Implementation of template-based GMM methods.
@@ -13,61 +14,301 @@
 namespace mlpack {
 namespace gmm {
 
-template<typename ClusteringType>
-void GMM::InitialClustering(const ClusteringType& clusterer,
-                            const arma::mat& data,
-                            std::vector<arma::vec>& meansOut,
-                            std::vector<arma::mat>& covarsOut,
-                            arma::vec& weightsOut) const
+// Copy constructor.
+template<typename FittingType>
+template<typename OtherFittingType>
+GMM<FittingType>::GMM(const GMM<OtherFittingType>& other) :
+    gaussians(other.Gaussians()),
+    dimensionality(other.Dimensionality()),
+    means(other.Means()),
+    covariances(other.Covariances()),
+    weights(other.Weights()),
+    localFitter(FittingType()),
+    fitter(localFitter) { /* Nothing to do. */ }
+
+// Copy constructor for when the other GMM uses the same fitting type.
+template<typename FittingType>
+GMM<FittingType>::GMM(const GMM<FittingType>& other) :
+    gaussians(other.Gaussians()),
+    dimensionality(other.Dimensionality()),
+    means(other.Means()),
+    covariances(other.Covariances()),
+    weights(other.Weights()),
+    localFitter(other.Fitter()),
+    fitter(localFitter) { /* Nothing to do. */ }
+
+template<typename FittingType>
+template<typename OtherFittingType>
+GMM<FittingType>& GMM<FittingType>::operator=(
+    const GMM<OtherFittingType>& other)
 {
-  meansOut.resize(gaussians, arma::vec(dimensionality));
-  covarsOut.resize(gaussians, arma::mat(dimensionality, dimensionality));
-  weightsOut.set_size(gaussians);
+  gaussians = other.Gaussians();
+  dimensionality = other.Dimensionality();
+  means = other.Means();
+  covariances = other.Covariances();
+  weights = other.Weights();
 
-  // Assignments from clustering.
-  arma::Col<size_t> assignments;
+  return *this;
+}
 
-  // Run clustering algorithm.
-  clusterer.Cluster(data, gaussians, assignments);
+template<typename FittingType>
+GMM<FittingType>& GMM<FittingType>::operator=(const GMM<FittingType>& other)
+{
+  gaussians = other.Gaussians();
+  dimensionality = other.Dimensionality();
+  means = other.Means();
+  covariances = other.Covariances();
+  weights = other.Weights();
+  localFitter = other.Fitter();
 
-  // Now calculate the means, covariances, and weights.
-  weightsOut.zeros();
+  return *this;
+}
+
+/**
+ * Return the probability of the given observation being from this GMM.
+ */
+template<typename FittingType>
+double GMM<FittingType>::Probability(const arma::vec& observation) const
+{
+  // Sum the probability for each Gaussian in our mixture (and we have to
+  // multiply by the prior for each Gaussian too).
+  double sum = 0;
   for (size_t i = 0; i < gaussians; i++)
+    sum += weights[i] * phi(observation, means[i], covariances[i]);
+
+  return sum;
+}
+
+/**
+ * Return the probability of the given observation being from the given
+ * component in the mixture.
+ */
+template<typename FittingType>
+double GMM<FittingType>::Probability(const arma::vec& observation,
+                                     const size_t component) const
+{
+  // We are only considering one Gaussian component -- so we only need to call
+  // phi() once.  We do consider the prior probability!
+  return weights[component] *
+      phi(observation, means[component], covariances[component]);
+}
+
+/**
+ * Return a randomly generated observation according to the probability
+ * distribution defined by this object.
+ */
+template<typename FittingType>
+arma::vec GMM<FittingType>::Random() const
+{
+  // Determine which Gaussian it will be coming from.
+  double gaussRand = math::Random();
+  size_t gaussian;
+
+  double sumProb = 0;
+  for (size_t g = 0; g < gaussians; g++)
   {
-    meansOut[i].zeros();
-    covarsOut[i].zeros();
+    sumProb += weights(g);
+    if (gaussRand <= sumProb)
+    {
+      gaussian = g;
+      break;
+    }
   }
 
-  // From the assignments, generate our means, covariances, and weights.
-  for (size_t i = 0; i < data.n_cols; i++)
+  return trans(chol(covariances[gaussian])) *
+      arma::randn<arma::vec>(dimensionality) + means[gaussian];
+}
+
+/**
+ * Fit the GMM to the given observations.
+ */
+template<typename FittingType>
+void GMM<FittingType>::Estimate(const arma::mat& observations,
+                                const size_t trials)
+{
+  double bestLikelihood; // This will be reported later.
+
+  // We don't need to store temporary models if we are only doing one trial.
+  if (trials == 1)
   {
-    size_t cluster = assignments[i];
+    // Train the model.  The user will have been warned earlier if the GMM was
+    // initialized with no parameters (0 gaussians, dimensionality of 0).
+    fitter.Estimate(observations, means, covariances, weights);
 
-    // Add this to the relevant mean.
-    meansOut[cluster] += data.col(i);
+    bestLikelihood = LogLikelihood(observations, means, covariances, weights);
+  }
+  else
+  {
+    if (trials == 0)
+      return; // It's what they asked for...
 
-    // Add this to the relevant covariance.
-    covarsOut[cluster] += data.col(i) * trans(data.col(i));
+    // We need to keep temporary copies.  We'll do the first training into the
+    // actual model position, so that if it's the best we don't need to copy it.
+    fitter.Estimate(observations, means, covariances, weights);
 
-    // Now add one to the weights (we will normalize).
-    weightsOut[cluster]++;
+    bestLikelihood = LogLikelihood(observations, means, covariances, weights);
+
+    Log::Debug << "GMM::Estimate(): Log-likelihood of trial 0 is "
+        << bestLikelihood << "." << std::endl;
+
+    // Now the temporary model.
+    std::vector<arma::vec> meansTrial(gaussians, arma::vec(dimensionality));
+    std::vector<arma::mat> covariancesTrial(gaussians,
+        arma::mat(dimensionality, dimensionality));
+    arma::vec weightsTrial(gaussians);
+
+    for (size_t trial = 1; trial < trials; ++trial)
+    {
+      fitter.Estimate(observations, meansTrial, covariancesTrial, weightsTrial);
+
+      // Check to see if the log-likelihood of this one is better.
+      double newLikelihood = LogLikelihood(observations, meansTrial,
+          covariancesTrial, weightsTrial);
+
+      Log::Debug << "GMM::Estimate(): Log-likelihood of trial " << trial
+          << " is " << newLikelihood << "." << std::endl;
+
+      if (newLikelihood > bestLikelihood)
+      {
+        // Save new likelihood and copy new model.
+        bestLikelihood = newLikelihood;
+
+        means = meansTrial;
+        covariances = covariancesTrial;
+        weights = weightsTrial;
+      }
+    }
   }
 
-  // Now normalize the mean and covariance.
-  for (size_t i = 0; i < gaussians; i++)
+  // Report final log-likelihood.
+  Log::Info << "GMM::Estimate(): log-likelihood of trained GMM is "
+      << bestLikelihood << "." << std::endl;
+}
+
+/**
+ * Fit the GMM to the given observations, each of which has a certain
+ * probability of being from this distribution.
+ */
+template<typename FittingType>
+void GMM<FittingType>::Estimate(const arma::mat& observations,
+                                const arma::vec& probabilities,
+                                const size_t trials)
+{
+  double bestLikelihood; // This will be reported later.
+
+  // We don't need to store temporary models if we are only doing one trial.
+  if (trials == 1)
   {
-    covarsOut[i] -= meansOut[i] * trans(meansOut[i]) /
-        weightsOut[i];
+    // Train the model.  The user will have been warned earlier if the GMM was
+    // initialized with no parameters (0 gaussians, dimensionality of 0).
+    fitter.Estimate(observations, probabilities, means, covariances, weights);
 
-    meansOut[i] /= weightsOut[i];
+    bestLikelihood = LogLikelihood(observations, means, covariances, weights);
+  }
+  else
+  {
+    if (trials == 0)
+      return; // It's what they asked for...
 
-    covarsOut[i] /= (weightsOut[i] > 1) ? weightsOut[i] : 1;
+    // We need to keep temporary copies.  We'll do the first training into the
+    // actual model position, so that if it's the best we don't need to copy it.
+    fitter.Estimate(observations, probabilities, means, covariances, weights);
+
+    bestLikelihood = LogLikelihood(observations, means, covariances, weights);
+
+    Log::Debug << "GMM::Estimate(): Log-likelihood of trial 0 is "
+        << bestLikelihood << "." << std::endl;
+
+    // Now the temporary model.
+    std::vector<arma::vec> meansTrial(gaussians, arma::vec(dimensionality));
+    std::vector<arma::mat> covariancesTrial(gaussians,
+        arma::mat(dimensionality, dimensionality));
+    arma::vec weightsTrial(gaussians);
+
+    for (size_t trial = 1; trial < trials; ++trial)
+    {
+      fitter.Estimate(observations, probabilities, meansTrial, covariancesTrial, weightsTrial);
+
+      // Check to see if the log-likelihood of this one is better.
+      double newLikelihood = LogLikelihood(observations, meansTrial,
+          covariancesTrial, weightsTrial);
+
+      Log::Debug << "GMM::Estimate(): Log-likelihood of trial " << trial
+          << " is " << newLikelihood << "." << std::endl;
+
+      if (newLikelihood > bestLikelihood)
+      {
+        // Save new likelihood and copy new model.
+        bestLikelihood = newLikelihood;
+
+        means = meansTrial;
+        covariances = covariancesTrial;
+        weights = weightsTrial;
+      }
+    }
   }
 
-  // Finally, normalize weights.
-  weightsOut /= accu(weightsOut);
+  // Report final log-likelihood.
+  Log::Info << "GMM::Estimate(): log-likelihood of trained GMM is "
+      << bestLikelihood << "." << std::endl;
 }
 
+/**
+ * Classify the given observations as being from an individual component in this
+ * GMM.
+ */
+template<typename FittingType>
+void GMM<FittingType>::Classify(const arma::mat& observations,
+                                arma::Col<size_t>& labels) const
+{
+  // This is not the best way to do this!
+
+  // We should not have to fill this with values, because each one should be
+  // overwritten.
+  labels.set_size(observations.n_cols);
+  for (size_t i = 0; i < observations.n_cols; ++i)
+  {
+    // Find maximum probability component.
+    double probability = 0;
+    for (size_t j = 0; j < gaussians; ++j)
+    {
+      double newProb = Probability(observations.unsafe_col(i), j);
+      if (newProb >= probability)
+      {
+        probability = newProb;
+        labels[i] = j;
+      }
+    }
+  }
+}
+
+/**
+ * Get the log-likelihood of this data's fit to the model.
+ */
+template<typename FittingType>
+double GMM<FittingType>::LogLikelihood(
+    const arma::mat& data,
+    const std::vector<arma::vec>& meansL,
+    const std::vector<arma::mat>& covariancesL,
+    const arma::vec& weightsL) const
+{
+  double loglikelihood = 0;
+
+  arma::vec phis;
+  arma::mat likelihoods(gaussians, data.n_cols);
+  for (size_t i = 0; i < gaussians; i++)
+  {
+    phi(data, meansL[i], covariancesL[i], phis);
+    likelihoods.row(i) = weightsL(i) * trans(phis);
+  }
+
+  // Now sum over every point.
+  for (size_t j = 0; j < data.n_cols; j++)
+    loglikelihood += log(accu(likelihoods.col(j)));
+
+  return loglikelihood;
+}
+
 }; // namespace gmm
 }; // namespace mlpack
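
The cross-FittingType copy constructor lets a model trained one way seed a
GMM that uses a different fitting mechanism; a sketch (RoundRobinClustering
is the hypothetical clusterer from the em_fit.hpp sketch above):

#include <mlpack/methods/gmm/gmm.hpp>

using namespace mlpack::gmm;

void ConvertExample(const GMM<EMFit<> >& trained)
{
  // Copies gaussians, dimensionality, means, covariances, and weights; the
  // new GMM gets a default-constructed fitter of its own type.
  GMM<EMFit<RoundRobinClustering> > other(trained);
}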
 

Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp	2012-04-10 18:07:37 UTC (rev 12282)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp	2012-04-10 18:47:01 UTC (rev 12283)
@@ -45,7 +45,7 @@
   }
 
   // Calculate mixture of Gaussians.
-  GMM gmm(size_t(gaussians), dataPoints.n_rows);
+  GMM<> gmm(size_t(gaussians), dataPoints.n_rows);
 
   ////// Computing the parameters of the model using the EM algorithm //////
   Timer::Start("em");



