[mlpack-svn] r14891 - mlpack/trunk/src/mlpack/methods/gmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Apr 11 18:36:55 EDT 2013


Author: rcurtin
Date: 2013-04-11 18:36:55 -0400 (Thu, 11 Apr 2013)
New Revision: 14891

Modified:
   mlpack/trunk/src/mlpack/methods/gmm/em_fit.hpp
   mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp
Log:
Parameterize perturbation, tolerance, and maximum number of iterations.


Modified: mlpack/trunk/src/mlpack/methods/gmm/em_fit.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/em_fit.hpp	2013-04-10 21:30:24 UTC (rev 14890)
+++ mlpack/trunk/src/mlpack/methods/gmm/em_fit.hpp	2013-04-11 22:36:55 UTC (rev 14891)
@@ -35,10 +35,20 @@
  public:
   /**
    * Construct the EMFit object, optionally passing an InitialClusteringType
-   * object (just in case it needs to store state).
+   * object (just in case it needs to store state).  Setting the maximum number
+   * of iterations to 0 means that the EM algorithm will iterate until
+   * convergence (with the given tolerance).
+   *
+   * @param clusterer Object which will perform the initial clustering.
+   * @param maxIterations Maximum number of iterations for EM.
+   * @param tolerance Log-likelihood tolerance required for convergence.
+   * @param perturbation Value to add to zero-valued diagonal covariance entries
+   *     to ensure positive-definiteness.
    */
-  EMFit(InitialClusteringType clusterer = InitialClusteringType()) :
-      clusterer(clusterer) { /* Nothing to do. */ }
+  EMFit(const size_t maxIterations = 300,
+        const double tolerance = 1e-10,
+        const double perturbation = 1e-30,
+        InitialClusteringType clusterer = InitialClusteringType());
 
   /**
    * Fit the observations to a Gaussian mixture model (GMM) using the EM
@@ -73,6 +83,26 @@
                 std::vector<arma::mat>& covariances,
                 arma::vec& weights);
 
+  //! Get the clusterer.
+  const InitialClusteringType& Clusterer() const { return clusterer; }
+  //! Modify the clusterer.
+  InitialClusteringType& Clusterer() { return clusterer; }
+
+  //! Get the maximum number of iterations of the EM algorithm.
+  size_t MaxIterations() const { return maxIterations; }
+  //! Modify the maximum number of iterations of the EM algorithm.
+  size_t& MaxIterations() { return maxIterations; }
+
+  //! Get the tolerance for the convergence of the EM algorithm.
+  double Tolerance() const { return tolerance; }
+  //! Modify the tolerance for the convergence of the EM algorithm.
+  double& Tolerance() { return tolerance; }
+
+  //! Get the perturbation added to zero diagonal covariance elements.
+  double Perturbation() const { return perturbation; }
+  //! Modify the perturbation added to zero diagonal covariance elements.
+  double& Perturbation() { return perturbation; }
+
  private:
   /**
    * Run the clusterer, and then turn the cluster assignments into Gaussians.
@@ -104,6 +134,13 @@
                        const std::vector<arma::mat>& covariances,
                        const arma::vec& weights) const;
 
+  //! Maximum iterations of EM algorithm.
+  size_t maxIterations;
+  //! Tolerance for convergence of EM.
+  double tolerance;
+  //! Perturbation to add to zero-valued diagonal covariance values.
+  double perturbation;
+  //! Object which will perform the clustering.
   InitialClusteringType clusterer;
 };
 

Modified: mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp	2013-04-10 21:30:24 UTC (rev 14890)
+++ mlpack/trunk/src/mlpack/methods/gmm/em_fit_impl.hpp	2013-04-11 22:36:55 UTC (rev 14891)
@@ -16,7 +16,19 @@
 namespace mlpack {
 namespace gmm {
 
+//! Constructor.
 template<typename InitialClusteringType>
+EMFit<InitialClusteringType>::EMFit(const size_t maxIterations,
+                                    const double tolerance,
+                                    const double perturbation,
+                                    InitialClusteringType clusterer) :
+    maxIterations(maxIterations),
+    tolerance(tolerance),
+    perturbation(perturbation),
+    clusterer(clusterer)
+{ /* Nothing to do. */ }
+
+template<typename InitialClusteringType>
 void EMFit<InitialClusteringType>::Estimate(const arma::mat& observations,
                                             std::vector<arma::vec>& means,
                                             std::vector<arma::mat>& covariances,
@@ -33,9 +45,8 @@
   arma::mat condProb(observations.n_cols, means.size());
 
   // Iterate to update the model until no more improvement is found.
-  size_t maxIterations = 300;
-  size_t iteration = 0;
-  while (std::abs(l - lOld) > 1e-10 && iteration < maxIterations)
+  size_t iteration = 1;
+  while (std::abs(l - lOld) > tolerance && iteration != maxIterations)
   {
     // Calculate the conditional probabilities of choosing a particular
     // Gaussian given the observations and the present theta value.
@@ -82,7 +93,7 @@
         {
           Log::Debug << "Covariance " << i << " has zero in diagonal element "
               << d << "!  Adding perturbation." << std::endl;
-          covariances[i](d, d) += 1e-50;
+          covariances[i](d, d) += perturbation;
         }
       }
     }
@@ -117,9 +128,8 @@
   arma::mat condProb(observations.n_cols, means.size());
 
   // Iterate to update the model until no more improvement is found.
-  size_t maxIterations = 300;
-  size_t iteration = 0;
-  while (std::abs(l - lOld) > 1e-10 && iteration < maxIterations)
+  size_t iteration = 1;
+  while (std::abs(l - lOld) > tolerance && iteration != maxIterations)
   {
     // Calculate the conditional probabilities of choosing a particular
     // Gaussian given the observations and the present theta value.
@@ -174,7 +184,7 @@
         {
           Log::Debug << "Covariance " << i << " has zero in diagonal element "
             << d << "!  Adding perturbation." << std::endl;
-          covariances[i](d, d) += 1e-50;
+          covariances[i](d, d) += perturbation;
         }
       }
     }
@@ -253,7 +263,7 @@
       {
         Log::Debug << "Covariance " << i << " has zero in diagonal element "
           << d << "!  Adding perturbation." << std::endl;
-        covariances[i](d, d) += 1e-50;
+        covariances[i](d, d) += perturbation;
       }
     }
   }




More information about the mlpack-svn mailing list