[mlpack-svn] r16949 - mlpack/trunk/src/mlpack/methods/gmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Aug 5 08:59:00 EDT 2014


Author: michaelfox99
Date: Tue Aug  5 08:58:59 2014
New Revision: 16949

Log:
Hierarchical GMMs store params in GaussianDistributions. Makes code clearer and simplifies Save/Load.


Modified:
   mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp

Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp	Tue Aug  5 08:58:59 2014
@@ -1,5 +1,6 @@
 /**
  * @author Parikshit Ram (pram at cc.gatech.edu)
+ * @author Michael Fox
  * @file gmm.hpp
  *
  * Defines a Gaussian Mixture model and
@@ -28,14 +29,12 @@
  *
  * @code
  * void Estimate(const arma::mat& observations,
- *               std::vector<arma::vec>& means,
- *               std::vector<arma::mat>& covariances,
+ *               std::vector<distribution::GaussianDistribution>& dists,
  *               arma::vec& weights);
  *
  * void Estimate(const arma::mat& observations,
  *               const arma::vec& probabilities,
- *               std::vector<arma::vec>& means,
- *               std::vector<arma::mat>& covariances,
+ *               std::vector<distribution::GaussianDistribution>& dists,
  *               arma::vec& weights);
  * @endcode
  *
@@ -78,10 +77,14 @@
   size_t gaussians;
   //! The dimensionality of the model.
   size_t dimensionality;
-  //! Vector of means; one for each Gaussian.
+  
+  //! Vector of Gaussian component distributions; one for each Gaussian.
+  std::vector<distribution::GaussianDistribution> dists;
+  
+  //! Legacy member data; only returned by the unsupported Means()/Covariances().
   std::vector<arma::vec> means;
-  //! Vector of covariances; one for each Gaussian.
   std::vector<arma::mat> covariances;
+  
   //! Vector of a priori weights for each Gaussian.
   arma::vec weights;
 
@@ -126,19 +129,16 @@
       FittingType& fitter);
 
   /**
-   * Create a GMM with the given means, covariances, and weights.
+   * Create a GMM with the given dists and weights.
    *
-   * @param means Means of the model.
-   * @param covariances Covariances of the model.
+   * @param dists Distributions of the model.
    * @param weights Weights of the model.
    */
-  GMM(const std::vector<arma::vec>& means,
-      const std::vector<arma::mat>& covariances,
+  GMM(const std::vector<distribution::GaussianDistribution> & dists,
       const arma::vec& weights) :
-      gaussians(means.size()),
-      dimensionality((!means.empty()) ? means[0].n_elem : 0),
-      means(means),
-      covariances(covariances),
+      gaussians(dists.size()),
+      dimensionality((!dists.empty()) ? dists[0].Mean().n_elem : 0),
+      dists(dists),
       weights(weights),
       localFitter(FittingType()),
       fitter(localFitter) { /* Nothing to do. */ }
@@ -152,14 +152,12 @@
    * @param covariances Covariances of the model.
    * @param weights Weights of the model.
    */
-  GMM(const std::vector<arma::vec>& means,
-      const std::vector<arma::mat>& covariances,
+  GMM(const std::vector<distribution::GaussianDistribution> & dists,
       const arma::vec& weights,
       FittingType& fitter) :
-      gaussians(means.size()),
-      dimensionality((!means.empty()) ? means[0].n_elem : 0),
-      means(means),
-      covariances(covariances),
+      gaussians(dists.size()),
+      dimensionality((!dists.empty()) ? dists[0].Mean().n_elem : 0),
+      dists(dists),
       weights(weights),
       fitter(fitter) { /* Nothing to do. */ }
 
@@ -202,6 +200,21 @@
    */
   void Save(const std::string& filename) const;
 
+  /**
+   * Load a GMM from a SaveRestoreUtility.  The format should be the same
+   * as is generated by the Save() method.
+   *
+   * @param sr SaveRestoreUtility containing the model to be loaded.
+   */
+  void Load(const util::SaveRestoreUtility& sr);
+
+  /**
+   * Save a GMM to a SaveRestoreUtility.
+   *
+   * @param sr SaveRestoreUtility object to save to.
+   */
+  void Save(util::SaveRestoreUtility& sr) const;
+    
   //! Return the number of gaussians in the model.
   size_t Gaussians() const { return gaussians; }
   //! Modify the number of gaussians in the model.  Careful!  You will have to
@@ -214,15 +227,46 @@
   //! each mean and covariance matrix yourself.
   size_t& Dimensionality() { return dimensionality; }
 
-  //! Return a const reference to the vector of means (mu).
-  const std::vector<arma::vec>& Means() const { return means; }
-  //! Return a reference to the vector of means (mu).
-  std::vector<arma::vec>& Means() { return means; }
-
-  //! Return a const reference to the vector of covariance matrices (sigma).
-  const std::vector<arma::mat>& Covariances() const { return covariances; }
-  //! Return a reference to the vector of covariance matrices (sigma).
-  std::vector<arma::mat>& Covariances() { return covariances; }
+  /**
+   * Return a const reference to a component distribution.
+   *
+   * @param i index of component.
+   */  
+  const distribution::GaussianDistribution& Component(size_t i) const {
+      return dists[i]; }
+  /**
+   * Return a reference to a component distribution.
+   *
+   * @param i index of component.
+   */  
+  distribution::GaussianDistribution& Component(size_t i) { return dists[i]; }
+
+  //! Legacy accessors: each logs a fatal error.  Use Component() instead.
+  const std::vector<arma::vec>& Means() const
+  {
+    Log::Fatal << "GMM::Means() no longer supported."
+        << "See GMM::Components().";
+    return means;
+  }
+  std::vector<arma::vec>& Means()
+  {
+    Log::Fatal << "GMM::Means() no longer supported."
+        << "See GMM::Components().";
+    return means;
+  }
+  const std::vector<arma::mat>& Covariances() const
+  {
+    Log::Fatal << "GMM::Covariances() no longer supported."
+        << "See GMM::Components().";
+    return covariances;
+  }
+  std::vector<arma::mat>& Covariances()
+  {
+    Log::Fatal << "GMM::Covariances() no longer supported."
+        << "See GMM::Components().";
+    return covariances;
+  }
+
 
   //! Return a const reference to the a priori weights of each Gaussian.
   const arma::vec& Weights() const { return weights; }
@@ -338,6 +382,11 @@
    * Returns a string representation of this object.
    */
   std::string ToString() const;
+  
+  /**
+   * Returns a string indicating the type.
+   */
+  static std::string const Type() { return "GMM"; }
 
  private:
   /**
@@ -350,8 +399,7 @@
    * @param weights Weights of the given mixture model.
    */
   double LogLikelihood(const arma::mat& dataPoints,
-                       const std::vector<arma::vec>& means,
-                       const std::vector<arma::mat>& covars,
+                       const std::vector<distribution::GaussianDistribution>& distsL,
                        const arma::vec& weights) const;
 
   //! Locally-stored fitting object; in case the user did not pass one.



More information about the mlpack-svn mailing list