[mlpack-git] master: Hierarchical GMMs store params in GaussianDistributions. Makes code clearer and simplifies Save/Load. (42e57d0)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:56:03 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 42e57d0c3a545cc36aede400d45851c90109d738
Author: michaelfox99 <michaelfox99 at gmail.com>
Date:   Tue Aug 5 12:58:59 2014 +0000

    Hierarchical GMMs store params in GaussianDistributions. Makes code clearer and simplifies Save/Load.


>---------------------------------------------------------------

42e57d0c3a545cc36aede400d45851c90109d738
 src/mlpack/methods/gmm/gmm.hpp | 110 +++++++++++++++++++++++++++++------------
 1 file changed, 79 insertions(+), 31 deletions(-)

diff --git a/src/mlpack/methods/gmm/gmm.hpp b/src/mlpack/methods/gmm/gmm.hpp
index 15341d1..3de507b 100644
--- a/src/mlpack/methods/gmm/gmm.hpp
+++ b/src/mlpack/methods/gmm/gmm.hpp
@@ -1,5 +1,6 @@
 /**
  * @author Parikshit Ram (pram at cc.gatech.edu)
+ * @author Michael Fox
  * @file gmm.hpp
  *
  * Defines a Gaussian Mixture model and
@@ -28,14 +29,12 @@ namespace gmm /** Gaussian Mixture Models. */ {
  *
  * @code
  * void Estimate(const arma::mat& observations,
- *               std::vector<arma::vec>& means,
- *               std::vector<arma::mat>& covariances,
+ *               std::vector<distribution::GaussianDistribution>& dists,
  *               arma::vec& weights);
  *
  * void Estimate(const arma::mat& observations,
  *               const arma::vec& probabilities,
- *               std::vector<arma::vec>& means,
- *               std::vector<arma::mat>& covariances,
+ *               std::vector<distribution::GaussianDistribution>& dists,
  *               arma::vec& weights);
  * @endcode
  *
@@ -78,10 +77,14 @@ class GMM
   size_t gaussians;
   //! The dimensionality of the model.
   size_t dimensionality;
-  //! Vector of means; one for each Gaussian.
+  
+  //! Vector of Gaussians
+  std::vector<distribution::GaussianDistribution> dists;
+  
+  //! Legacy member data, not used.
   std::vector<arma::vec> means;
-  //! Vector of covariances; one for each Gaussian.
   std::vector<arma::mat> covariances;
+  
   //! Vector of a priori weights for each Gaussian.
   arma::vec weights;
 
@@ -126,19 +129,16 @@ class GMM
       FittingType& fitter);
 
   /**
-   * Create a GMM with the given means, covariances, and weights.
+   * Create a GMM with the given dists and weights.
    *
-   * @param means Means of the model.
-   * @param covariances Covariances of the model.
+   * @param dists Distributions of the model.
    * @param weights Weights of the model.
    */
-  GMM(const std::vector<arma::vec>& means,
-      const std::vector<arma::mat>& covariances,
+  GMM(const std::vector<distribution::GaussianDistribution> & dists,
       const arma::vec& weights) :
-      gaussians(means.size()),
-      dimensionality((!means.empty()) ? means[0].n_elem : 0),
-      means(means),
-      covariances(covariances),
+      gaussians(dists.size()),
+      dimensionality((!dists.empty()) ? dists[0].Mean().n_elem : 0),
+      dists(dists),
       weights(weights),
       localFitter(FittingType()),
       fitter(localFitter) { /* Nothing to do. */ }
@@ -152,14 +152,12 @@ class GMM
    * @param covariances Covariances of the model.
    * @param weights Weights of the model.
    */
-  GMM(const std::vector<arma::vec>& means,
-      const std::vector<arma::mat>& covariances,
+  GMM(const std::vector<distribution::GaussianDistribution> & dists,
       const arma::vec& weights,
       FittingType& fitter) :
-      gaussians(means.size()),
-      dimensionality((!means.empty()) ? means[0].n_elem : 0),
-      means(means),
-      covariances(covariances),
+      gaussians(dists.size()),
+      dimensionality((!dists.empty()) ? dists[0].Mean().n_elem : 0),
+      dists(dists),
       weights(weights),
       fitter(fitter) { /* Nothing to do. */ }
 
@@ -202,6 +200,21 @@ class GMM
    */
   void Save(const std::string& filename) const;
 
+  /**
+   * Load a GMM from a SaveRestoreUtility.  The format should be the same
+   * as is generated by the Save() method.
+   *
+   * @param filename Name of SaveRestoreUtility containing model to be loaded.
+   */
+  void Load(const util::SaveRestoreUtility& sr);
+
+  /**
+   * Save a GMM to a SaveRestoreUtility.
+   *
+   * @param SaveRestoreUtility object to save to.
+   */
+  void Save(util::SaveRestoreUtility& sr) const;
+    
   //! Return the number of gaussians in the model.
   size_t Gaussians() const { return gaussians; }
   //! Modify the number of gaussians in the model.  Careful!  You will have to
@@ -214,15 +227,46 @@ class GMM
   //! each mean and covariance matrix yourself.
   size_t& Dimensionality() { return dimensionality; }
 
-  //! Return a const reference to the vector of means (mu).
-  const std::vector<arma::vec>& Means() const { return means; }
-  //! Return a reference to the vector of means (mu).
-  std::vector<arma::vec>& Means() { return means; }
+  /**
+   * Return a const reference to a component distribution.
+   *
+   * @param i index of component.
+   */  
+  const distribution::GaussianDistribution& Component(size_t i) const {
+      return dists[i]; }
+  /**
+   * Return a reference to a component distribution.
+   *
+   * @param i index of component.
+   */  
+  distribution::GaussianDistribution& Component(size_t i) { return dists[i]; }
+
+  //! Functions from earlier releases give errors
+  const std::vector<arma::vec>& Means() const
+  {
+    Log::Fatal << "GMM::Means() no longer supported."
+        << "See GMM::Components().";
+    return means;
+  }
+  std::vector<arma::vec>& Means()
+  {
+    Log::Fatal << "GMM::Means() no longer supported."
+        << "See GMM::Components().";
+    return means;
+  }
+  const std::vector<arma::mat>& Covariances() const
+  {
+    Log::Fatal << "GMM::Covariances() no longer supported."
+        << "See GMM::Components().";
+    return covariances;
+  }
+  std::vector<arma::mat>& Covariances()
+  {
+    Log::Fatal << "GMM::Covariances() no longer supported."
+        << "See GMM::Components().";
+    return covariances;
+  }
 
-  //! Return a const reference to the vector of covariance matrices (sigma).
-  const std::vector<arma::mat>& Covariances() const { return covariances; }
-  //! Return a reference to the vector of covariance matrices (sigma).
-  std::vector<arma::mat>& Covariances() { return covariances; }
 
   //! Return a const reference to the a priori weights of each Gaussian.
   const arma::vec& Weights() const { return weights; }
@@ -339,6 +383,11 @@ class GMM
    */
   std::string ToString() const;
   
+  /**
+   * Returns a string indicating the type.
+   */
+  static std::string const Type() { return "GMM"; }
+
  private:
   /**
    * This function computes the loglikelihood of the given model.  This function
@@ -350,8 +399,7 @@ class GMM
    * @param weights Weights of the given mixture model.
    */
   double LogLikelihood(const arma::mat& dataPoints,
-                       const std::vector<arma::vec>& means,
-                       const std::vector<arma::mat>& covars,
+                       const std::vector<distribution::GaussianDistribution>& distsL,
                        const arma::vec& weights) const;
 
   //! Locally-stored fitting object; in case the user did not pass one.



More information about the mlpack-git mailing list