[mlpack-svn] r14891 - mlpack/trunk/src/mlpack/methods/gmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Apr 11 18:36:55 EDT 2013
Author: rcurtin
Date: 2013-04-11 18:36:55 -0400 (Thu, 11 Apr 2013)
New Revision: 14891
Modified:
mlpack/trunk/src/mlpack/methods/gmm/em_fit.hpp
mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp
Log:
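Parameterize perturbation, tolerance, and maximum number of iterations. Note that the default perturbation added to zero-valued diagonal covariance entries also changes from 1e-50 to 1e-30.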
Modified: mlpack/trunk/src/mlpack/methods/gmm/em_fit.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/em_fit.hpp 2013-04-10 21:30:24 UTC (rev 14890)
+++ mlpack/trunk/src/mlpack/methods/gmm/em_fit.hpp 2013-04-11 22:36:55 UTC (rev 14891)
@@ -35,10 +35,20 @@
public:
/**
* Construct the EMFit object, optionally passing an InitialClusteringType
- * object (just in case it needs to store state).
+ * object (just in case it needs to store state). Setting the maximum number
+ * of iterations to 0 means that the EM algorithm will iterate until
+ * convergence (with the given tolerance).
+ *
+ * @param maxIterations Maximum number of iterations for EM (0: no limit).
+ * @param tolerance EM stops once |change in log-likelihood| <= tolerance.
+ * @param perturbation Value to add to zero-valued diagonal covariance entries
+ *     to help keep the covariance matrices positive-definite.
+ * @param clusterer Object which will perform the initial clustering.
*/
- EMFit(InitialClusteringType clusterer = InitialClusteringType()) :
- clusterer(clusterer) { /* Nothing to do. */ }
+ EMFit(const size_t maxIterations = 300,
+ const double tolerance = 1e-10,
+ const double perturbation = 1e-30,
+ InitialClusteringType clusterer = InitialClusteringType());
/**
* Fit the observations to a Gaussian mixture model (GMM) using the EM
@@ -73,6 +83,26 @@
std::vector<arma::mat>& covariances,
arma::vec& weights);
+ //! Get the object that performs the initial clustering.
+ const InitialClusteringType& Clusterer() const { return clusterer; }
+ //! Modify the object that performs the initial clustering.
+ InitialClusteringType& Clusterer() { return clusterer; }
+
+ //! Get the maximum number of EM iterations (0 means no limit).
+ size_t MaxIterations() const { return maxIterations; }
+ //! Modify the maximum number of EM iterations (0 means no limit).
+ size_t& MaxIterations() { return maxIterations; }
+
+ //! Get the log-likelihood change at or below which EM stops iterating.
+ double Tolerance() const { return tolerance; }
+ //! Modify the log-likelihood change at or below which EM stops iterating.
+ double& Tolerance() { return tolerance; }
+
+ //! Get the perturbation added to zero-valued diagonal covariance elements.
+ double Perturbation() const { return perturbation; }
+ //! Modify the perturbation added to zero-valued diagonal covariance elements.
+ double& Perturbation() { return perturbation; }
+
private:
/**
* Run the clusterer, and then turn the cluster assignments into Gaussians.
@@ -104,6 +134,13 @@
const std::vector<arma::mat>& covariances,
const arma::vec& weights) const;
+ //! Maximum iterations of the EM algorithm (0 means no limit).
+ size_t maxIterations;
+ //! Log-likelihood change at or below which EM stops iterating.
+ double tolerance;
+ //! Perturbation to add to zero-valued diagonal covariance values.
+ double perturbation;
+ //! Object which will perform the initial clustering.
InitialClusteringType clusterer;
};
Modified: mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp 2013-04-10 21:30:24 UTC (rev 14890)
+++ mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp 2013-04-11 22:36:55 UTC (rev 14891)
@@ -16,7 +16,19 @@
namespace mlpack {
namespace gmm {
+//! Construct the EMFit object, storing the EM parameters and the clusterer.
template<typename InitialClusteringType>
+EMFit<InitialClusteringType>::EMFit(const size_t maxIterations,
+ const double tolerance,
+ const double perturbation,
+ InitialClusteringType clusterer) :
+ maxIterations(maxIterations),
+ tolerance(tolerance),
+ perturbation(perturbation),
+ clusterer(clusterer)
+{ /* Nothing to do. */ }
+
+template<typename InitialClusteringType>
void EMFit<InitialClusteringType>::Estimate(const arma::mat& observations,
std::vector<arma::vec>& means,
std::vector<arma::mat>& covariances,
@@ -33,9 +45,8 @@
arma::mat condProb(observations.n_cols, means.size());
// Iterate to update the model until no more improvement is found.
- size_t maxIterations = 300;
- size_t iteration = 0;
- while (std::abs(l - lOld) > 1e-10 && iteration < maxIterations)
+ // Allow at most maxIterations passes (the counter is 1-based, so compare
+ // with <=, not !=); maxIterations == 0 means iterate until convergence.
+ size_t iteration = 1;
+ while (std::abs(l - lOld) > tolerance &&
+     (maxIterations == 0 || iteration <= maxIterations))
{
// Calculate the conditional probabilities of choosing a particular
// Gaussian given the observations and the present theta value.
@@ -82,7 +93,7 @@
{
Log::Debug << "Covariance " << i << " has zero in diagonal element "
<< d << "! Adding perturbation." << std::endl;
- covariances[i](d, d) += 1e-50;
+ covariances[i](d, d) += perturbation;
}
}
}
@@ -117,9 +128,8 @@
arma::mat condProb(observations.n_cols, means.size());
// Iterate to update the model until no more improvement is found.
- size_t maxIterations = 300;
- size_t iteration = 0;
- while (std::abs(l - lOld) > 1e-10 && iteration < maxIterations)
+ // Allow at most maxIterations passes (the counter is 1-based, so compare
+ // with <=, not !=); maxIterations == 0 means iterate until convergence.
+ size_t iteration = 1;
+ while (std::abs(l - lOld) > tolerance &&
+     (maxIterations == 0 || iteration <= maxIterations))
{
// Calculate the conditional probabilities of choosing a particular
// Gaussian given the observations and the present theta value.
@@ -174,7 +184,7 @@
{
Log::Debug << "Covariance " << i << " has zero in diagonal element "
<< d << "! Adding perturbation." << std::endl;
- covariances[i](d, d) += 1e-50;
+ covariances[i](d, d) += perturbation;
}
}
}
@@ -253,7 +263,7 @@
{
Log::Debug << "Covariance " << i << " has zero in diagonal element "
<< d << "! Adding perturbation." << std::endl;
- covariances[i](d, d) += 1e-50;
+ covariances[i](d, d) += perturbation;
}
}
}
More information about the mlpack-svn mailing list