[mlpack-svn] r16949 - mlpack/trunk/src/mlpack/methods/gmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Aug 5 08:59:00 EDT 2014
Author: michaelfox99
Date: Tue Aug 5 08:58:59 2014
New Revision: 16949
Log:
Hierarchical GMMs store params in GaussianDistributions. Makes code clearer and simplifies Save/Load.
Modified:
mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp Tue Aug 5 08:58:59 2014
@@ -1,5 +1,6 @@
/**
* @author Parikshit Ram (pram at cc.gatech.edu)
+ * @author Michael Fox
* @file gmm.hpp
*
* Defines a Gaussian Mixture model and
@@ -28,14 +29,12 @@
*
* @code
* void Estimate(const arma::mat& observations,
- * std::vector<arma::vec>& means,
- * std::vector<arma::mat>& covariances,
+ * std::vector<distribution::GaussianDistribution>& dists,
* arma::vec& weights);
*
* void Estimate(const arma::mat& observations,
* const arma::vec& probabilities,
- * std::vector<arma::vec>& means,
- * std::vector<arma::mat>& covariances,
+ * std::vector<distribution::GaussianDistribution>& dists,
* arma::vec& weights);
* @endcode
*
@@ -78,10 +77,14 @@
size_t gaussians;
//! The dimensionality of the model.
size_t dimensionality;
- //! Vector of means; one for each Gaussian.
+
+ //! Vector of Gaussian distributions; one for each component of the mixture.
+ std::vector<distribution::GaussianDistribution> dists;
+
+ //! Legacy member data, not used.
std::vector<arma::vec> means;
- //! Vector of covariances; one for each Gaussian.
std::vector<arma::mat> covariances;
+
//! Vector of a priori weights for each Gaussian.
arma::vec weights;
@@ -126,19 +129,16 @@
FittingType& fitter);
/**
- * Create a GMM with the given means, covariances, and weights.
+ * Create a GMM with the given component distributions and weights.
*
- * @param means Means of the model.
- * @param covariances Covariances of the model.
+ * @param dists Distributions of the model.
* @param weights Weights of the model.
*/
- GMM(const std::vector<arma::vec>& means,
- const std::vector<arma::mat>& covariances,
+ GMM(const std::vector<distribution::GaussianDistribution> & dists,
const arma::vec& weights) :
- gaussians(means.size()),
- dimensionality((!means.empty()) ? means[0].n_elem : 0),
- means(means),
- covariances(covariances),
+ gaussians(dists.size()),
+ dimensionality((!dists.empty()) ? dists[0].Mean().n_elem : 0),
+ dists(dists),
weights(weights),
localFitter(FittingType()),
fitter(localFitter) { /* Nothing to do. */ }
@@ -152,14 +152,12 @@
* @param covariances Covariances of the model.
* @param weights Weights of the model.
*/
- GMM(const std::vector<arma::vec>& means,
- const std::vector<arma::mat>& covariances,
+ GMM(const std::vector<distribution::GaussianDistribution> & dists,
const arma::vec& weights,
FittingType& fitter) :
- gaussians(means.size()),
- dimensionality((!means.empty()) ? means[0].n_elem : 0),
- means(means),
- covariances(covariances),
+ gaussians(dists.size()),
+ dimensionality((!dists.empty()) ? dists[0].Mean().n_elem : 0),
+ dists(dists),
weights(weights),
fitter(fitter) { /* Nothing to do. */ }
@@ -202,6 +200,21 @@
*/
void Save(const std::string& filename) const;
+ /**
+ * Load a GMM from a SaveRestoreUtility. The format should be the same
+ * as is generated by the Save() method.
+ *
+ * @param sr SaveRestoreUtility containing the model to be loaded.
+ */
+ void Load(const util::SaveRestoreUtility& sr);
+
+ /**
+ * Save a GMM to a SaveRestoreUtility.
+ *
+ * @param sr SaveRestoreUtility object to save to.
+ */
+ void Save(util::SaveRestoreUtility& sr) const;
+
//! Return the number of gaussians in the model.
size_t Gaussians() const { return gaussians; }
//! Modify the number of gaussians in the model. Careful! You will have to
@@ -214,15 +227,46 @@
//! each mean and covariance matrix yourself.
size_t& Dimensionality() { return dimensionality; }
- //! Return a const reference to the vector of means (mu).
- const std::vector<arma::vec>& Means() const { return means; }
- //! Return a reference to the vector of means (mu).
- std::vector<arma::vec>& Means() { return means; }
-
- //! Return a const reference to the vector of covariance matrices (sigma).
- const std::vector<arma::mat>& Covariances() const { return covariances; }
- //! Return a reference to the vector of covariance matrices (sigma).
- std::vector<arma::mat>& Covariances() { return covariances; }
+ /**
+ * Return a const reference to a component distribution.
+ *
+ * @param i index of component.
+ */
+ const distribution::GaussianDistribution& Component(size_t i) const {
+ return dists[i]; }
+ /**
+ * Return a reference to a component distribution.
+ *
+ * @param i index of component.
+ */
+ distribution::GaussianDistribution& Component(size_t i) { return dists[i]; }
+
+ //! Accessors from earlier releases; calling them now logs a fatal error. Use Component() instead.
+ const std::vector<arma::vec>& Means() const
+ {
+ Log::Fatal << "GMM::Means() no longer supported."
+ << "See GMM::Components().";
+ return means;
+ }
+ std::vector<arma::vec>& Means()
+ {
+ Log::Fatal << "GMM::Means() no longer supported."
+ << "See GMM::Components().";
+ return means;
+ }
+ const std::vector<arma::mat>& Covariances() const
+ {
+ Log::Fatal << "GMM::Covariances() no longer supported."
+ << "See GMM::Components().";
+ return covariances;
+ }
+ std::vector<arma::mat>& Covariances()
+ {
+ Log::Fatal << "GMM::Covariances() no longer supported."
+ << "See GMM::Components().";
+ return covariances;
+ }
+
//! Return a const reference to the a priori weights of each Gaussian.
const arma::vec& Weights() const { return weights; }
@@ -338,6 +382,11 @@
* Returns a string representation of this object.
*/
std::string ToString() const;
+
+ /**
+ * Returns a string indicating the type.
+ */
+ static std::string const Type() { return "GMM"; }
private:
/**
@@ -350,8 +399,7 @@
* @param weights Weights of the given mixture model.
*/
double LogLikelihood(const arma::mat& dataPoints,
- const std::vector<arma::vec>& means,
- const std::vector<arma::mat>& covars,
+ const std::vector<distribution::GaussianDistribution>& distsL,
const arma::vec& weights) const;
//! Locally-stored fitting object; in case the user did not pass one.
More information about the mlpack-svn
mailing list