[mlpack-git] master: Refactor GMM: only use FittingType template parameter for Train(). (d924562)

gitdub at big.cc.gt.atl.ga.us
Fri Dec 18 11:43:16 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/5ba11bc90223b55eecd5da4cfbe86c8fc40637a5...df229e45a5bd7842fe019e9d49ed32f13beb6aaa

>---------------------------------------------------------------

commit d924562c0cd6846da122f1fc1a28318b660b8e47
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Dec 18 15:51:35 2015 +0000

    Refactor GMM: only use FittingType template parameter for Train().


>---------------------------------------------------------------

d924562c0cd6846da122f1fc1a28318b660b8e47
 src/mlpack/methods/gmm/CMakeLists.txt              |   1 +
 src/mlpack/methods/gmm/gmm.cpp                     | 152 ++++++++++++
 src/mlpack/methods/gmm/gmm.hpp                     | 110 ++-------
 src/mlpack/methods/gmm/gmm_impl.hpp                | 256 ++-------------------
 .../gmm/{gmm_main.cpp => gmm_train_main.cpp}       |   0
 src/mlpack/methods/gmm/gmm_util.hpp                |  45 ----
 src/mlpack/methods/hmm/hmm_train_main.cpp          |   2 +-
 src/mlpack/methods/hmm/hmm_util_impl.hpp           |   4 +-
 src/mlpack/tests/gmm_test.cpp                      |  26 +--
 src/mlpack/tests/hmm_test.cpp                      |  16 +-
 10 files changed, 213 insertions(+), 399 deletions(-)
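
To illustrate the API change: GMM is no longer a class template, and the
fitting mechanism is now a per-call Train() parameter rather than class
state.  A minimal sketch of the new call pattern (the data matrix and
trial count here are illustrative):

    #include <mlpack/methods/gmm/gmm.hpp>

    using namespace mlpack::gmm;

    int main()
    {
      arma::mat data(4, 1000, arma::fill::randu); // Illustrative observations.

      // Five Gaussians in a 4-dimensional space; no template parameter needed.
      GMM g(5, 4);

      // Train() defaults to the EMFit<> fitting type, as before.
      double ll = g.Train(data);

      // A configured fitter can be passed per call; FittingType is deduced.
      EMFit<> fitter;
      ll = g.Train(data, 3 /* trials */, false /* useExistingModel */, fitter);
    }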

diff --git a/src/mlpack/methods/gmm/CMakeLists.txt b/src/mlpack/methods/gmm/CMakeLists.txt
index 93728cc..ce697e4 100644
--- a/src/mlpack/methods/gmm/CMakeLists.txt
+++ b/src/mlpack/methods/gmm/CMakeLists.txt
@@ -2,6 +2,7 @@
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
   gmm.hpp
+  gmm.cpp
   gmm_impl.hpp
   em_fit.hpp
   em_fit_impl.hpp
diff --git a/src/mlpack/methods/gmm/gmm.cpp b/src/mlpack/methods/gmm/gmm.cpp
new file mode 100644
index 0000000..496dc12
--- /dev/null
+++ b/src/mlpack/methods/gmm/gmm.cpp
@@ -0,0 +1,152 @@
+/**
+ * @file gmm.cpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ * @author Ryan Curtin
+ * @author Michael Fox
+ *
+ * Implementation of template-based GMM methods.
+ */
+#include "gmm.hpp"
+
+namespace mlpack {
+namespace gmm {
+
+/**
+ * Create a GMM with the given number of Gaussians, each of which has the
+ * specified dimensionality.  The means and covariances will be set to 0.
+ *
+ * @param gaussians Number of Gaussians in this GMM.
+ * @param dimensionality Dimensionality of each Gaussian.
+ */
+GMM::GMM(const size_t gaussians, const size_t dimensionality) :
+    gaussians(gaussians),
+    dimensionality(dimensionality),
+    dists(gaussians, distribution::GaussianDistribution(dimensionality)),
+    weights(gaussians)
+{
+  // Set equal weights.  Technically this model is still valid, but only barely.
+  weights.fill(1.0 / gaussians);
+}
+
+// Copy constructor for when the other GMM uses the same fitting type.
+GMM::GMM(const GMM& other) :
+    gaussians(other.Gaussians()),
+    dimensionality(other.dimensionality),
+    dists(other.dists),
+    weights(other.weights) { /* Nothing to do. */ }
+
+GMM& GMM::operator=(const GMM& other)
+{
+  gaussians = other.gaussians;
+  dimensionality = other.dimensionality;
+  dists = other.dists;
+  weights = other.weights;
+
+  return *this;
+}
+
+/**
+ * Return the probability of the given observation being from this GMM.
+ */
+double GMM::Probability(const arma::vec& observation) const
+{
+  // Sum the probability for each Gaussian in our mixture (and we have to
+  // multiply by the prior for each Gaussian too).
+  double sum = 0;
+  for (size_t i = 0; i < gaussians; i++)
+    sum += weights[i] * dists[i].Probability(observation);
+
+  return sum;
+}
+
+/**
+ * Return the probability of the given observation being from the given
+ * component in the mixture.
+ */
+double GMM::Probability(const arma::vec& observation,
+                        const size_t component) const
+{
+  // We are only considering one Gaussian component -- so we only need to call
+  // Probability() once.  We do consider the prior probability!
+  return weights[component] * dists[component].Probability(observation);
+}
+
+/**
+ * Return a randomly generated observation according to the probability
+ * distribution defined by this object.
+ */
+arma::vec GMM::Random() const
+{
+  // Determine which Gaussian it will be coming from.
+  double gaussRand = math::Random();
+  size_t gaussian = 0;
+
+  double sumProb = 0;
+  for (size_t g = 0; g < gaussians; g++)
+  {
+    sumProb += weights(g);
+    if (gaussRand <= sumProb)
+    {
+      gaussian = g;
+      break;
+    }
+  }
+
+  return trans(chol(dists[gaussian].Covariance())) *
+      arma::randn<arma::vec>(dimensionality) + dists[gaussian].Mean();
+}
+
+/**
+ * Classify the given observations as being from an individual component in this
+ * GMM.
+ */
+void GMM::Classify(const arma::mat& observations,
+                   arma::Row<size_t>& labels) const
+{
+  // This is not the best way to do this!
+
+  // We should not have to fill this with values, because each one should be
+  // overwritten.
+  labels.set_size(observations.n_cols);
+  for (size_t i = 0; i < observations.n_cols; ++i)
+  {
+    // Find maximum probability component.
+    double probability = 0;
+    for (size_t j = 0; j < gaussians; ++j)
+    {
+      double newProb = Probability(observations.unsafe_col(i), j);
+      if (newProb >= probability)
+      {
+        probability = newProb;
+        labels[i] = j;
+      }
+    }
+  }
+}
+
+/**
+ * Get the log-likelihood of this data's fit to the model.
+ */
+double GMM::LogLikelihood(
+    const arma::mat& data,
+    const std::vector<distribution::GaussianDistribution>& distsL,
+    const arma::vec& weightsL) const
+{
+  double loglikelihood = 0;
+  arma::vec phis;
+  arma::mat likelihoods(gaussians, data.n_cols);
+
+  for (size_t i = 0; i < gaussians; i++)
+  {
+    distsL[i].Probability(data, phis);
+    likelihoods.row(i) = weightsL(i) * trans(phis);
+  }
+
+  // Now sum over every point.
+  for (size_t j = 0; j < data.n_cols; j++)
+    loglikelihood += log(accu(likelihoods.col(j)));
+  return loglikelihood;
+}
+
+} // namespace gmm
+} // namespace mlpack
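
The methods above are no longer templated, so they move out of the header
into gmm.cpp.  A short usage sketch (the component parameters mirror those
in gmm_test.cpp and are illustrative):

    #include <mlpack/methods/gmm/gmm.hpp>

    using namespace mlpack;

    int main()
    {
      gmm::GMM g(2, 2);
      g.Component(0) = distribution::GaussianDistribution("0 0", "1 0; 0 1");
      g.Component(1) = distribution::GaussianDistribution("3 3", "2 1; 1 2");
      g.Weights() = "0.3 0.7";

      // Mixture density of a point, and its weight-scaled density under
      // component 1 alone.
      const double p  = g.Probability(arma::vec("1.0 1.5"));
      const double p1 = g.Probability(arma::vec("1.0 1.5"), 1);

      // Draw one random sample, then hard-assign a batch of points.
      arma::vec sample = g.Random();
      arma::mat points(2, 100, arma::fill::randn);
      arma::Row<size_t> labels;
      g.Classify(points, labels);
    }
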
diff --git a/src/mlpack/methods/gmm/gmm.hpp b/src/mlpack/methods/gmm/gmm.hpp
index 99409b9..c1cdda7 100644
--- a/src/mlpack/methods/gmm/gmm.hpp
+++ b/src/mlpack/methods/gmm/gmm.hpp
@@ -24,8 +24,9 @@ namespace gmm /** Gaussian Mixture Models. */ {
  * GMM can be trained using normal data, or data with probabilities of being
  * from this GMM (see GMM::Train() for more information).
  *
- * The FittingType template class must provide a way for the GMM to train on
- * data.  It must provide the following two functions:
+ * The Train() method uses a template type 'FittingType'.  The FittingType
+ * template class must provide a way for the GMM to train on data.  It must
+ * provide the following two functions:
  *
  * @code
  * void Estimate(const arma::mat& observations,
@@ -45,7 +46,8 @@ namespace gmm /** Gaussian Mixture Models. */ {
  * the GMM as specified in the constructor.
  *
  * For a sample implementation, see the EMFit class; this class uses the EM
- * algorithm to train a GMM, and is the default fitting type.
+ * algorithm to train a GMM, and is the default fitting type for the Train()
+ * method.
  *
  * The GMM, once trained, can be used to generate random points from the
  * distribution and estimate the probability of points being from the
@@ -55,12 +57,12 @@ namespace gmm /** Gaussian Mixture Models. */ {
  * Example use:
  *
  * @code
- * // Set up a mixture of 5 gaussians in a 4-dimensional space (uses the default
- * // EM fitting mechanism).
- * GMM<> g(5, 4);
+ * // Set up a mixture of 5 gaussians in a 4-dimensional space.
+ * GMM g(5, 4);
  *
- * // Train the GMM given the data observations.
- * g.Estimate(data);
+ * // Train the GMM given the data observations, using the default EM fitting
+ * // mechanism.
+ * g.Train(data);
  *
  * // Get the probability of 'observation' being observed from this GMM.
  * double probability = g.Probability(observation);
@@ -69,7 +71,6 @@ namespace gmm /** Gaussian Mixture Models. */ {
  * arma::vec observation = g.Random();
  * @endcode
  */
-template<typename FittingType = EMFit<>>
 class GMM
 {
  private:
@@ -84,20 +85,13 @@ class GMM
   //! Vector of a priori weights for each Gaussian.
   arma::vec weights;
 
-  //! Locally-stored fitting object; in case the user did not pass one.
-  FittingType* fitter;
-  //! Whether or not we own the fitter.
-  bool ownsFitter;
-
  public:
   /**
    * Create an empty Gaussian Mixture Model, with zero gaussians.
    */
   GMM() :
       gaussians(0),
-      dimensionality(0),
-      fitter(new FittingType()),
-      ownsFitter(true)
+      dimensionality(0)
   {
     // Warn the user.  They probably don't want to do this.  If this constructor
     // is being used (because it is required by some template classes), the user
@@ -116,20 +110,6 @@ class GMM
   GMM(const size_t gaussians, const size_t dimensionality);
 
   /**
-   * Create a GMM with the given number of Gaussians, each of which have the
-   * specified dimensionality.  Also, pass in an initialized FittingType class;
-   * this is useful in cases where the FittingType class needs to store some
-   * state.
-   *
-   * @param gaussians Number of Gaussians in this GMM.
-   * @param dimensionality Dimensionality of each Gaussian.
-   * @param fitter Initialized fitting mechanism.
-   */
-  GMM(const size_t gaussians,
-      const size_t dimensionality,
-      FittingType& fitter);
-
-  /**
    * Create a GMM with the given dists and weights.
    *
    * @param dists Distributions of the model.
@@ -140,67 +120,18 @@ class GMM
       gaussians(dists.size()),
       dimensionality((!dists.empty()) ? dists[0].Mean().n_elem : 0),
       dists(dists),
-      weights(weights),
-      fitter(new FittingType()),
-      ownsFitter(true) { /* Nothing to do. */ }
-
-  /**
-   * Create a GMM with the given means, covariances, and weights, and use the
-   * given initialized FittingType class.  This is useful in cases where the
-   * FittingType class needs to store some state.
-   *
-   * @param means Means of the model.
-   * @param covariances Covariances of the model.
-   * @param weights Weights of the model.
-   */
-  GMM(const std::vector<distribution::GaussianDistribution> & dists,
-      const arma::vec& weights,
-      FittingType& fitter) :
-      gaussians(dists.size()),
-      dimensionality((!dists.empty()) ? dists[0].Mean().n_elem : 0),
-      dists(dists),
-      weights(weights),
-      fitter(&fitter),
-      ownsFitter(false) { /* Nothing to do. */ }
-
-  /**
-   * Copy constructor for GMMs which use different fitting types.
-   */
-  template<typename OtherFittingType>
-  GMM(const GMM<OtherFittingType>& other);
+      weights(weights) { /* Nothing to do. */ }
 
-  /**
-   * Copy constructor for GMMs using the same fitting type.  This also copies
-   * the fitter.
-   */
+  //! Copy constructor for GMMs.
   GMM(const GMM& other);
 
-  //! Destructor to clean up memory if necessary.
-  ~GMM();
-
-  /**
-   * Copy operator for GMMs which use different fitting types.
-   */
-  template<typename OtherFittingType>
-  GMM& operator=(const GMM<OtherFittingType>& other);
-
-  /**
-   * Copy operator for GMMs which use the same fitting type.  This also copies
-   * the fitter.
-   */
+  //! Copy operator for GMMs.
   GMM& operator=(const GMM& other);
 
   //! Return the number of gaussians in the model.
   size_t Gaussians() const { return gaussians; }
-  //! Modify the number of gaussians in the model.  Careful!  You will have to
-  //! resize the means, covariances, and weights yourself.
-  size_t& Gaussians() { return gaussians; }
-
   //! Return the dimensionality of the model.
   size_t Dimensionality() const { return dimensionality; }
-  //! Modify the dimensionality of the model.  Careful!  You will have to update
-  //! each mean and covariance matrix yourself.
-  size_t& Dimensionality() { return dimensionality; }
 
   /**
    * Return a const reference to a component distribution.
@@ -221,11 +152,6 @@ class GMM
   //! Return a reference to the a priori weights of each Gaussian.
   arma::vec& Weights() { return weights; }
 
-  //! Return a const reference to the fitting type.
-  const FittingType& Fitter() const { return fitter; }
-  //! Return a reference to the fitting type.
-  FittingType& Fitter() { return fitter; }
-
   /**
    * Return the probability that the given observation came from this
    * distribution.
@@ -274,9 +200,11 @@ class GMM
    *      model for the estimation.
    * @return The log-likelihood of the best fit.
    */
+  template<typename FittingType = EMFit<>>
   double Train(const arma::mat& observations,
                const size_t trials = 1,
-               const bool useExistingModel = false);
+               const bool useExistingModel = false,
+               FittingType fitter = FittingType());
 
   /**
    * Estimate the probability distribution directly from the given observations,
@@ -302,10 +230,12 @@ class GMM
    *     model for the estimation.
    * @return The log-likelihood of the best fit.
    */
+  template<typename FittingType = EMFit<>>
   double Train(const arma::mat& observations,
                const arma::vec& probabilities,
                const size_t trials = 1,
-               const bool useExistingModel = false);
+               const bool useExistingModel = false,
+               FittingType fitter = FittingType());
 
   /**
    * Classify the given observations as being from an individual component in
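
For reference, the FittingType interface that the class documentation above
requires can be satisfied by any class providing the two Estimate()
overloads that Train() invokes.  A skeletal example with placeholder bodies
(not a real estimator):

    #include <vector>
    #include <mlpack/core.hpp>

    class NullFit
    {
     public:
      // Fit to unweighted observations (placeholder body).
      void Estimate(const arma::mat& /* observations */,
                    std::vector<mlpack::distribution::GaussianDistribution>& dists,
                    arma::vec& weights,
                    const bool useExistingModel)
      {
        if (!useExistingModel)
          weights.fill(1.0 / dists.size());
      }

      // Fit to observations that each carry a probability of having been
      // drawn from this GMM (placeholder body).
      void Estimate(const arma::mat& observations,
                    const arma::vec& /* probabilities */,
                    std::vector<mlpack::distribution::GaussianDistribution>& dists,
                    arma::vec& weights,
                    const bool useExistingModel)
      {
        Estimate(observations, dists, weights, useExistingModel);
      }
    };

    // Usage: gmm.Train(data, 1, false, NullFit());
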
diff --git a/src/mlpack/methods/gmm/gmm_impl.hpp b/src/mlpack/methods/gmm/gmm_impl.hpp
index fd322c5..97a0c85 100644
--- a/src/mlpack/methods/gmm/gmm_impl.hpp
+++ b/src/mlpack/methods/gmm/gmm_impl.hpp
@@ -16,169 +16,13 @@ namespace mlpack {
 namespace gmm {
 
 /**
- * Create a GMM with the given number of Gaussians, each of which have the
- * specified dimensionality.  The means and covariances will be set to 0.
- *
- * @param gaussians Number of Gaussians in this GMM.
- * @param dimensionality Dimensionality of each Gaussian.
- */
-template<typename FittingType>
-GMM<FittingType>::GMM(const size_t gaussians, const size_t dimensionality) :
-    gaussians(gaussians),
-    dimensionality(dimensionality),
-    dists(gaussians, distribution::GaussianDistribution(dimensionality)),
-    weights(gaussians),
-    fitter(new FittingType()),
-    ownsFitter(true)
-{
-  // Set equal weights.  Technically this model is still valid, but only barely.
-  weights.fill(1.0 / gaussians);
-}
-
-/**
- * Create a GMM with the given number of Gaussians, each of which have the
- * specified dimensionality.  Also, pass in an initialized FittingType class;
- * this is useful in cases where the FittingType class needs to store some
- * state.
- *
- * @param gaussians Number of Gaussians in this GMM.
- * @param dimensionality Dimensionality of each Gaussian.
- * @param fitter Initialized fitting mechanism.
- */
-template<typename FittingType>
-GMM<FittingType>::GMM(const size_t gaussians,
-                      const size_t dimensionality,
-                      FittingType& fitter) :
-    gaussians(gaussians),
-    dimensionality(dimensionality),
-    dists(gaussians, distribution::GaussianDistribution(dimensionality)),
-    weights(gaussians),
-    fitter(&fitter),
-    ownsFitter(false)
-{
-  // Set equal weights.  Technically this model is still valid, but only barely.
-  weights.fill(1.0 / gaussians);
-}
-
-
-// Copy constructor.
-template<typename FittingType>
-template<typename OtherFittingType>
-GMM<FittingType>::GMM(const GMM<OtherFittingType>& other) :
-    gaussians(other.gaussians),
-    dimensionality(other.dimensionality),
-    dists(other.dists),
-    weights(other.weights),
-    fitter(new FittingType()),
-    ownsFitter(true) { /* Nothing to do. */ }
-
-// Copy constructor for when the other GMM uses the same fitting type.
-template<typename FittingType>
-GMM<FittingType>::GMM(const GMM<FittingType>& other) :
-    gaussians(other.Gaussians()),
-    dimensionality(other.dimensionality),
-    dists(other.dists),
-    weights(other.weights),
-    fitter(new FittingType(*other.fitter)),
-    ownsFitter(true) { /* Nothing to do. */ }
-
-template<typename FittingType>
-GMM<FittingType>::~GMM()
-{
-  if (ownsFitter)
-    delete fitter;
-}
-
-template<typename FittingType>
-template<typename OtherFittingType>
-GMM<FittingType>& GMM<FittingType>::operator=(
-    const GMM<OtherFittingType>& other)
-{
-  gaussians = other.gaussians;
-  dimensionality = other.dimensionality;
-  dists = other.dists;
-  weights = other.weights;
-
-  return *this;
-}
-
-template<typename FittingType>
-GMM<FittingType>& GMM<FittingType>::operator=(const GMM<FittingType>& other)
-{
-  gaussians = other.gaussians;
-  dimensionality = other.dimensionality;
-  dists = other.dists;
-  weights = other.weights;
-
-  if (fitter && ownsFitter)
-    delete fitter;
-  fitter = new FittingType(*other.fitter);
-  ownsFitter = true;
-
-  return *this;
-}
-
-/**
- * Return the probability of the given observation being from this GMM.
- */
-template<typename FittingType>
-double GMM<FittingType>::Probability(const arma::vec& observation) const
-{
-  // Sum the probability for each Gaussian in our mixture (and we have to
-  // multiply by the prior for each Gaussian too).
-  double sum = 0;
-  for (size_t i = 0; i < gaussians; i++)
-    sum += weights[i] * dists[i].Probability(observation);
-
-  return sum;
-}
-
-/**
- * Return the probability of the given observation being from the given
- * component in the mixture.
- */
-template<typename FittingType>
-double GMM<FittingType>::Probability(const arma::vec& observation,
-                                     const size_t component) const
-{
-  // We are only considering one Gaussian component -- so we only need to call
-  // Probability() once.  We do consider the prior probability!
-  return weights[component] * dists[component].Probability(observation);
-}
-
-/**
- * Return a randomly generated observation according to the probability
- * distribution defined by this object.
- */
-template<typename FittingType>
-arma::vec GMM<FittingType>::Random() const
-{
-  // Determine which Gaussian it will be coming from.
-  double gaussRand = math::Random();
-  size_t gaussian = 0;
-
-  double sumProb = 0;
-  for (size_t g = 0; g < gaussians; g++)
-  {
-    sumProb += weights(g);
-    if (gaussRand <= sumProb)
-    {
-      gaussian = g;
-      break;
-    }
-  }
-
-  return trans(chol(dists[gaussian].Covariance())) *
-      arma::randn<arma::vec>(dimensionality) + dists[gaussian].Mean();
-}
-
-/**
  * Fit the GMM to the given observations.
  */
 template<typename FittingType>
-double GMM<FittingType>::Train(const arma::mat& observations,
-                               const size_t trials,
-                               const bool useExistingModel)
+double GMM::Train(const arma::mat& observations,
+                  const size_t trials,
+                  const bool useExistingModel,
+                  FittingType fitter)
 {
   double bestLikelihood; // This will be reported later.
 
@@ -187,8 +31,7 @@ double GMM<FittingType>::Train(const arma::mat& observations,
   {
     // Train the model.  The user will have been warned earlier if the GMM was
     // initialized with no parameters (0 gaussians, dimensionality of 0).
-    fitter->Estimate(observations, dists, weights,
-        useExistingModel);
+    fitter.Estimate(observations, dists, weights, useExistingModel);
     bestLikelihood = LogLikelihood(observations, dists, weights);
   }
   else
@@ -207,8 +50,7 @@ double GMM<FittingType>::Train(const arma::mat& observations,
 
     // We need to keep temporary copies.  We'll do the first training into the
     // actual model position, so that if it's the best we don't need to copy it.
-    fitter->Estimate(observations, dists, weights,
-        useExistingModel);
+    fitter.Estimate(observations, dists, weights, useExistingModel);
 
     bestLikelihood = LogLikelihood(observations, dists, weights);
 
@@ -227,8 +69,7 @@ double GMM<FittingType>::Train(const arma::mat& observations,
         weightsTrial = weightsOrig;
       }
 
-      fitter->Estimate(observations, distsTrial, weightsTrial,
-          useExistingModel);
+      fitter.Estimate(observations, distsTrial, weightsTrial, useExistingModel);
 
       // Check to see if the log-likelihood of this one is better.
       double newLikelihood = LogLikelihood(observations, distsTrial,
@@ -259,10 +100,11 @@ double GMM<FittingType>::Train(const arma::mat& observations,
  * probability of being from this distribution.
  */
 template<typename FittingType>
-double GMM<FittingType>::Train(const arma::mat& observations,
-                               const arma::vec& probabilities,
-                               const size_t trials,
-                               const bool useExistingModel)
+double GMM::Train(const arma::mat& observations,
+                  const arma::vec& probabilities,
+                  const size_t trials,
+                  const bool useExistingModel,
+                  FittingType fitter)
 {
   double bestLikelihood; // This will be reported later.
 
@@ -271,7 +113,7 @@ double GMM<FittingType>::Train(const arma::mat& observations,
   {
     // Train the model.  The user will have been warned earlier if the GMM was
     // initialized with no parameters (0 gaussians, dimensionality of 0).
-    fitter->Estimate(observations, probabilities, dists, weights,
+    fitter.Estimate(observations, probabilities, dists, weights,
         useExistingModel);
     bestLikelihood = LogLikelihood(observations, dists, weights);
   }
@@ -291,7 +133,7 @@ double GMM<FittingType>::Train(const arma::mat& observations,
 
     // We need to keep temporary copies.  We'll do the first training into the
     // actual model position, so that if it's the best we don't need to copy it.
-    fitter->Estimate(observations, probabilities, dists, weights,
+    fitter.Estimate(observations, probabilities, dists, weights,
         useExistingModel);
 
     bestLikelihood = LogLikelihood(observations, dists, weights);
@@ -312,8 +154,7 @@ double GMM<FittingType>::Train(const arma::mat& observations,
         weightsTrial = weightsOrig;
       }
 
-      fitter->Estimate(observations, distsTrial, weightsTrial,
-          useExistingModel);
+      fitter.Estimate(observations, distsTrial, weightsTrial, useExistingModel);
 
       // Check to see if the log-likelihood of this one is better.
       double newLikelihood = LogLikelihood(observations, distsTrial,
@@ -340,65 +181,10 @@ double GMM<FittingType>::Train(const arma::mat& observations,
 }
 
 /**
- * Classify the given observations as being from an individual component in this
- * GMM.
- */
-template<typename FittingType>
-void GMM<FittingType>::Classify(const arma::mat& observations,
-                                arma::Row<size_t>& labels) const
-{
-  // This is not the best way to do this!
-
-  // We should not have to fill this with values, because each one should be
-  // overwritten.
-  labels.set_size(observations.n_cols);
-  for (size_t i = 0; i < observations.n_cols; ++i)
-  {
-    // Find maximum probability component.
-    double probability = 0;
-    for (size_t j = 0; j < gaussians; ++j)
-    {
-      double newProb = Probability(observations.unsafe_col(i), j);
-      if (newProb >= probability)
-      {
-        probability = newProb;
-        labels[i] = j;
-      }
-    }
-  }
-}
-
-/**
- * Get the log-likelihood of this data's fit to the model.
- */
-template<typename FittingType>
-double GMM<FittingType>::LogLikelihood(
-    const arma::mat& data,
-    const std::vector<distribution::GaussianDistribution>& distsL,
-    const arma::vec& weightsL) const
-{
-  double loglikelihood = 0;
-  arma::vec phis;
-  arma::mat likelihoods(gaussians, data.n_cols);
-
-  for (size_t i = 0; i < gaussians; i++)
-  {
-    distsL[i].Probability(data, phis);
-    likelihoods.row(i) = weightsL(i) * trans(phis);
-  }
-
-  // Now sum over every point.
-  for (size_t j = 0; j < data.n_cols; j++)
-    loglikelihood += log(accu(likelihoods.col(j)));
-  return loglikelihood;
-}
-
-/**
  * Serialize the object.
  */
-template<typename FittingType>
 template<typename Archive>
-void GMM<FittingType>::Serialize(Archive& ar, const unsigned int /* version */)
+void GMM::Serialize(Archive& ar, const unsigned int /* version */)
 {
   using data::CreateNVP;
 
@@ -419,16 +205,6 @@ void GMM<FittingType>::Serialize(Archive& ar, const unsigned int /* version */)
   }
 
   ar & CreateNVP(weights, "weights");
-
-  if (Archive::is_loading::value)
-  {
-    if (fitter && ownsFitter)
-      delete fitter;
-
-    ownsFitter = true;
-  }
-
-  ar & CreateNVP(fitter, "fitter");
 }
 
 } // namespace gmm
diff --git a/src/mlpack/methods/gmm/gmm_main.cpp b/src/mlpack/methods/gmm/gmm_train_main.cpp
similarity index 100%
rename from src/mlpack/methods/gmm/gmm_main.cpp
rename to src/mlpack/methods/gmm/gmm_train_main.cpp
diff --git a/src/mlpack/methods/gmm/gmm_util.hpp b/src/mlpack/methods/gmm/gmm_util.hpp
deleted file mode 100644
index 0390883..0000000
--- a/src/mlpack/methods/gmm/gmm_util.hpp
+++ /dev/null
@@ -1,45 +0,0 @@
-/**
- * @file gmm_util.hpp
- * @author Ryan Curtin
- *
- * Utility to save GMMs to files.
- */
-#ifndef __MLPACK_METHODS_GMM_GMM_UTIL_HPP
-#define __MLPACK_METHODS_GMM_GMM_UTIL_HPP
-
-namespace mlpack {
-namespace gmm {
-
-// Save a GMM to file using boost::serialization.
-// This does not save a type id, however.
-template<typename GMMType>
-void SaveGMM(GMMType& g, const std::string filename)
-{
-  using namespace boost::archive;
-
-  const std::string extension = data::Extension(filename);
-  std::ofstream ofs(filename);
-  if (extension == "xml")
-  {
-    xml_oarchive ar(ofs);
-    ar << data::CreateNVP(g, "gmm");
-  }
-  else if (extension == "bin")
-  {
-    binary_oarchive ar(ofs);
-    ar << data::CreateNVP(g, "gmm");
-  }
-  else if (extension == "txt")
-  {
-    text_oarchive ar(ofs);
-    ar << data::CreateNVP(g, "gmm");
-  }
-  else
-    Log::Fatal << "Unknown extension '" << extension << "' for GMM model file "
-        << "(known: 'xml', 'bin', 'txt')." << std::endl;
-}
-
-} // namespace gmm
-} // namespace mlpack
-
-#endif
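
With SaveGMM() removed, models are saved and loaded through
boost::serialization directly, as gmm_test.cpp does.  A minimal sketch of
the XML round trip (filename illustrative):

    #include <fstream>
    #include <boost/archive/xml_iarchive.hpp>
    #include <boost/archive/xml_oarchive.hpp>
    #include <mlpack/methods/gmm/gmm.hpp>

    using namespace mlpack;

    int main()
    {
      gmm::GMM g(10, 4);
      {
        std::ofstream ofs("gmm-model.xml");
        boost::archive::xml_oarchive ar(ofs);
        ar << data::CreateNVP(g, "gmm");
      }

      gmm::GMM g2;
      {
        std::ifstream ifs("gmm-model.xml");
        boost::archive::xml_iarchive ar(ifs);
        ar >> data::CreateNVP(g2, "gmm");
      }
    }
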
diff --git a/src/mlpack/methods/hmm/hmm_train_main.cpp b/src/mlpack/methods/hmm/hmm_train_main.cpp
index 291305c..e7820a6 100644
--- a/src/mlpack/methods/hmm/hmm_train_main.cpp
+++ b/src/mlpack/methods/hmm/hmm_train_main.cpp
@@ -333,7 +333,7 @@ int main(int argc, char** argv)
             << "be greater than or equal to 1." << endl;
 
       // Create HMM object.
-      HMM<GMM<>> hmm(size_t(states), GMM<>(size_t(gaussians), dimensionality),
+      HMM<GMM> hmm(size_t(states), GMM(size_t(gaussians), dimensionality),
           tolerance);
 
       // Issue a warning if the user didn't give labels.
diff --git a/src/mlpack/methods/hmm/hmm_util_impl.hpp b/src/mlpack/methods/hmm/hmm_util_impl.hpp
index 4f34334..4c49785 100644
--- a/src/mlpack/methods/hmm/hmm_util_impl.hpp
+++ b/src/mlpack/methods/hmm/hmm_util_impl.hpp
@@ -79,7 +79,7 @@ void LoadHMMAndPerformActionHelper(const std::string& modelFile,
 
     case HMMType::GaussianMixtureModelHMM:
       DeserializeHMMAndPerformAction<ActionType, ArchiveType,
-          HMM<gmm::GMM<>>>(ar, x);
+          HMM<gmm::GMM>>(ar, x);
       break;
 
     default:
@@ -159,7 +159,7 @@ char GetHMMType<HMM<distribution::GaussianDistribution>>()
 }
 
 template<>
-char GetHMMType<HMM<gmm::GMM<>>>()
+char GetHMMType<HMM<gmm::GMM>>()
 {
   return HMMType::GaussianMixtureModelHMM;
 }
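
On the HMM side the change is purely syntactic: GMM<> becomes GMM.  For
example (state and component counts illustrative):

    #include <mlpack/methods/hmm/hmm.hpp>
    #include <mlpack/methods/gmm/gmm.hpp>

    using namespace mlpack;

    int main()
    {
      // A 3-state HMM whose per-state emission distribution is a
      // 4-component GMM over 3-dimensional observations.
      hmm::HMM<gmm::GMM> hmm(3, gmm::GMM(4, 3));
    }
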
diff --git a/src/mlpack/tests/gmm_test.cpp b/src/mlpack/tests/gmm_test.cpp
index 603313c..628e03f 100644
--- a/src/mlpack/tests/gmm_test.cpp
+++ b/src/mlpack/tests/gmm_test.cpp
@@ -27,7 +27,7 @@ BOOST_AUTO_TEST_SUITE(GMMTest);
 BOOST_AUTO_TEST_CASE(GMMProbabilityTest)
 {
   // Create a GMM.
-  GMM<> gmm(2, 2);
+  GMM gmm(2, 2);
   gmm.Component(0) = distribution::GaussianDistribution("0 0", "1 0; 0 1");
   gmm.Component(1) = distribution::GaussianDistribution("3 3", "2 1; 1 2");
   gmm.Weights() = "0.3 0.7";
@@ -48,7 +48,7 @@ BOOST_AUTO_TEST_CASE(GMMProbabilityTest)
 BOOST_AUTO_TEST_CASE(GMMProbabilityComponentTest)
 {
   // Create a GMM (same as the last test).
-  GMM<> gmm(2, 2);
+  GMM gmm(2, 2);
   gmm.Component(0) = distribution::GaussianDistribution("0 0", "1 0; 0 1");
   gmm.Component(1) = distribution::GaussianDistribution("3 3", "2 1; 1 2");
   gmm.Weights() = "0.3 0.7";
@@ -99,7 +99,7 @@ BOOST_AUTO_TEST_CASE(GMMTrainEMOneGaussian)
     data.row(1) += mean(1);
 
     // Now, train the model.
-    GMM<> gmm(1, 2);
+    GMM gmm(1, 2);
     gmm.Train(data, 10);
 
     arma::vec actualMean = arma::mean(data, 1);
@@ -193,7 +193,7 @@ BOOST_AUTO_TEST_CASE(GMMTrainEMMultipleGaussians)
     weights[i] = (double) counts[i] / data.n_cols;
 
   // Now train the model.
-  GMM<> gmm(gaussians, dims);
+  GMM gmm(gaussians, dims);
   gmm.Train(data, 10);
 
   arma::uvec sortRef = sort_index(weights);
@@ -236,7 +236,7 @@ BOOST_AUTO_TEST_CASE(GMMTrainEMSingleGaussianWithProbability)
   probabilities.randu(20000); // Random probabilities.
 
   // Now train the model.
-  GMM<> g(1, 2);
+  GMM g(1, 2);
   g.Train(observations, probabilities, 10);
 
   // Check that it is trained correctly.  5% tolerance because of random error
@@ -306,7 +306,7 @@ BOOST_AUTO_TEST_CASE(GMMTrainEMMultipleGaussiansWithProbability)
   }
 
   // Now train the model.
-  GMM<> g(4, 3); // 3 dimensions, 4 components.
+  GMM g(4, 3); // 3 dimensions, 4 components.
 
   g.Train(points, probabilities, 8);
 
@@ -374,7 +374,7 @@ BOOST_AUTO_TEST_CASE(GMMTrainEMMultipleGaussiansWithProbability)
 BOOST_AUTO_TEST_CASE(GMMRandomTest)
 {
   // Simple GMM distribution.
-  GMM<> gmm(2, 2);
+  GMM gmm(2, 2);
   gmm.Weights() = arma::vec("0.40 0.60");
 
   // N([2.25 3.10], [1.00 0.20; 0.20 0.89])
@@ -392,7 +392,7 @@ BOOST_AUTO_TEST_CASE(GMMRandomTest)
     observations.col(i) = gmm.Random();
 
   // A new one which we'll train.
-  GMM<> gmm2(2, 2);
+  GMM gmm2(2, 2);
   gmm2.Train(observations, 10);
 
   // Now check the results.  We need to order by weights so that when we do the
@@ -439,7 +439,7 @@ BOOST_AUTO_TEST_CASE(GMMRandomTest)
 BOOST_AUTO_TEST_CASE(GMMClassifyTest)
 {
   // First create a Gaussian with a few components.
-  GMM<> gmm(3, 2);
+  GMM gmm(3, 2);
   gmm.Component(0) = distribution::GaussianDistribution("0 0", "1 0; 0 1");
   gmm.Component(1) = distribution::GaussianDistribution("1 3", "3 2; 2 3");
   gmm.Component(2) = distribution::GaussianDistribution("-2 -2",
@@ -484,7 +484,7 @@ BOOST_AUTO_TEST_CASE(GMMClassifyTest)
 BOOST_AUTO_TEST_CASE(GMMLoadSaveTest)
 {
   // Create a GMM, save it, and load it.
-  GMM<> gmm(10, 4);
+  GMM gmm(10, 4);
   gmm.Weights().randu();
 
   for (size_t i = 0; i < gmm.Gaussians(); ++i)
@@ -506,7 +506,7 @@ BOOST_AUTO_TEST_CASE(GMMLoadSaveTest)
   }
 
   // Load the GMM.
-  GMM<> gmm2;
+  GMM gmm2;
   {
     std::ifstream ifs("test-gmm-save.xml");
     boost::archive::xml_iarchive ar(ifs);
@@ -682,10 +682,10 @@ BOOST_AUTO_TEST_CASE(UseExistingModelTest)
     weights[i] = (double) counts[i] / data.n_cols;
 
   // Now train the model.
-  GMM<> gmm(gaussians, dims);
+  GMM gmm(gaussians, dims);
   gmm.Train(data, 10);
 
-  GMM<> oldgmm(gmm);
+  GMM oldgmm(gmm);
 
   // Retrain the model with the existing model as the starting point.
   gmm.Train(data, 1, true);
diff --git a/src/mlpack/tests/hmm_test.cpp b/src/mlpack/tests/hmm_test.cpp
index e4b862a..d4591df 100644
--- a/src/mlpack/tests/hmm_test.cpp
+++ b/src/mlpack/tests/hmm_test.cpp
@@ -772,8 +772,8 @@ BOOST_AUTO_TEST_CASE(GaussianHMMGenerateTest)
 BOOST_AUTO_TEST_CASE(GMMHMMPredictTest)
 {
   // We will use two GMMs; one with two components and one with three.
-  std::vector<GMM<> > gmms(2);
-  gmms[0] = GMM<>(2, 2);
+  std::vector<GMM> gmms(2);
+  gmms[0] = GMM(2, 2);
   gmms[0].Weights() = arma::vec("0.75 0.25");
 
   // N([2.25 3.10], [1.00 0.20; 0.20 0.89])
@@ -784,7 +784,7 @@ BOOST_AUTO_TEST_CASE(GMMHMMPredictTest)
   gmms[0].Component(1) = GaussianDistribution("7.10 5.01",
                                               "1.00 0.00; 0.00 1.01");
 
-  gmms[1] = GMM<>(3, 2);
+  gmms[1] = GMM(3, 2);
   gmms[1].Weights() = arma::vec("0.4 0.2 0.4");
 
   gmms[1].Component(0) = GaussianDistribution("-3.00 -6.12",
@@ -804,7 +804,7 @@ BOOST_AUTO_TEST_CASE(GMMHMMPredictTest)
                   "0.70 0.50");
 
   // Now build the model.
-  HMM<GMM<> > hmm(initial, trans, gmms);
+  HMM<GMM> hmm(initial, trans, gmms);
 
   // Make a sequence of observations.
   arma::mat observations(2, 1000);
@@ -842,7 +842,7 @@ BOOST_AUTO_TEST_CASE(GMMHMMLabeledTrainingTest)
   srand(time(NULL));
 
   // We will use two GMMs; one with two components and one with three.
-  std::vector<GMM<> > gmms(2, GMM<>(2, 2));
+  std::vector<GMM> gmms(2, GMM(2, 2));
   gmms[0].Weights() = arma::vec("0.3 0.7");
 
   // N([2.25 3.10], [1.00 0.20; 0.20 0.89])
@@ -887,7 +887,7 @@ BOOST_AUTO_TEST_CASE(GMMHMMLabeledTrainingTest)
   }
 
   // Set up the GMM for training.
-  HMM<GMM<> > hmm(2, GMM<>(2, 2));
+  HMM<GMM> hmm(2, GMM(2, 2));
 
   // Train the HMM.
   hmm.Train(observations, states);
@@ -984,7 +984,7 @@ BOOST_AUTO_TEST_CASE(GMMHMMLabeledTrainingTest)
 BOOST_AUTO_TEST_CASE(GMMHMMLoadSaveTest)
 {
   // Create a GMM HMM, save it, and load it.
-  HMM<GMM<> > hmm(3, GMM<>(4, 3));
+  HMM<GMM> hmm(3, GMM(4, 3));
 
   for(size_t j = 0; j < hmm.Emission().size(); ++j)
   {
@@ -1009,7 +1009,7 @@ BOOST_AUTO_TEST_CASE(GMMHMMLoadSaveTest)
   }
 
   // Load the HMM.
-  HMM<GMM<>> hmm2(3, GMM<>(4, 3));
+  HMM<GMM> hmm2(3, GMM(4, 3));
   {
     std::ifstream ifs("test-hmm-save.xml");
     boost::archive::xml_iarchive ar(ifs);



More information about the mlpack-git mailing list