[mlpack-git] master: Add Serialize() to GMM and refactor to not use references. This is because references make serialization really hard. As a second thought, it may actually be better to not hold the FittingType internally but only pass it when the model is being trained. Needs more thought (and someone to make the change). (076afb4)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Jul 13 04:04:48 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/8b2ca720828224607c70d2b539c43aecf8f4ec32...b4659b668021db631b3c8a48e3d735b513706fdc

>---------------------------------------------------------------

commit 076afb446f892a481f41db091d1473f9328300f8
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Jul 12 13:29:40 2015 +0000

    Add Serialize() to GMM and refactor to not use references.
    This is because references make serialization really hard.
    As a second thought, it may actually be better to not hold the FittingType internally but only pass it when the model is being trained.  Needs more thought (and someone to make the change).


>---------------------------------------------------------------

076afb446f892a481f41db091d1473f9328300f8
 src/mlpack/methods/gmm/gmm.hpp      |  73 +++++------------
 src/mlpack/methods/gmm/gmm_impl.hpp | 155 +++++++++++++++---------------------
 2 files changed, 87 insertions(+), 141 deletions(-)

diff --git a/src/mlpack/methods/gmm/gmm.hpp b/src/mlpack/methods/gmm/gmm.hpp
index fab8974..d34b232 100644
--- a/src/mlpack/methods/gmm/gmm.hpp
+++ b/src/mlpack/methods/gmm/gmm.hpp
@@ -69,7 +69,7 @@ namespace gmm /** Gaussian Mixture Models. */ {
  * arma::vec observation = g.Random();
  * @endcode
  */
-template<typename FittingType = EMFit<> >
+template<typename FittingType = EMFit<>>
 class GMM
 {
  private:
@@ -81,13 +81,14 @@ class GMM
   //! Vector of Gaussians
   std::vector<distribution::GaussianDistribution> dists;
 
-  //! Legacy member data, not used.
-  std::vector<arma::vec> means;
-  std::vector<arma::mat> covariances;
-
   //! Vector of a priori weights for each Gaussian.
   arma::vec weights;
 
+  //! Locally-stored fitting object; in case the user did not pass one.
+  FittingType* fitter;
+  //! Whether or not we own the fitter.
+  bool ownsFitter;
+
  public:
   /**
    * Create an empty Gaussian Mixture Model, with zero gaussians.
@@ -95,8 +96,8 @@ class GMM
   GMM() :
       gaussians(0),
       dimensionality(0),
-      localFitter(FittingType()),
-      fitter(localFitter)
+      fitter(new FittingType()),
+      ownsFitter(true)
   {
     // Warn the user.  They probably don't want to do this.  If this constructor
     // is being used (because it is required by some template classes), the user
@@ -140,8 +141,8 @@ class GMM
       dimensionality((!dists.empty()) ? dists[0].Mean().n_elem : 0),
       dists(dists),
       weights(weights),
-      localFitter(FittingType()),
-      fitter(localFitter) { /* Nothing to do. */ }
+      fitter(new FittingType()),
+      ownsFitter(true) { /* Nothing to do. */ }
 
   /**
    * Create a GMM with the given means, covariances, and weights, and use the
@@ -159,7 +160,8 @@ class GMM
       dimensionality((!dists.empty()) ? dists[0].Mean().n_elem : 0),
       dists(dists),
       weights(weights),
-      fitter(fitter) { /* Nothing to do. */ }
+      fitter(&fitter),
+      ownsFitter(false) { /* Nothing to do. */ }
 
   /**
    * Copy constructor for GMMs which use different fitting types.
@@ -173,6 +175,9 @@ class GMM
    */
   GMM(const GMM& other);
 
+  //! Destructor to clean up memory if necessary.
+  ~GMM();
+
   /**
    * Copy operator for GMMs which use different fitting types.
    */
@@ -185,36 +190,6 @@ class GMM
    */
   GMM& operator=(const GMM& other);
 
-  /**
-   * Load a GMM from an XML file.  The format of the XML file should be the same
-   * as is generated by the Save() method.
-   *
-   * @param filename Name of XML file containing model to be loaded.
-   */
-  void Load(const std::string& filename);
-
-  /**
-   * Save a GMM to an XML file.
-   *
-   * @param filename Name of XML file to write to.
-   */
-  void Save(const std::string& filename) const;
-
-  /**
-   * Load a GMM from a SaveRestoreUtility.  The format should be the same
-   * as is generated by the Save() method.
-   *
-   * @param filename Name of SaveRestoreUtility containing model to be loaded.
-   */
-  void Load(const util::SaveRestoreUtility& sr);
-
-  /**
-   * Save a GMM to a SaveRestoreUtility.
-   *
-   * @param SaveRestoreUtility object to save to.
-   */
-  void Save(util::SaveRestoreUtility& sr) const;
-
   //! Return the number of gaussians in the model.
   size_t Gaussians() const { return gaussians; }
   //! Modify the number of gaussians in the model.  Careful!  You will have to
@@ -357,9 +332,10 @@ class GMM
   std::string ToString() const;
 
   /**
-   * Returns a string indicating the type.
+   * Serialize the GMM.
    */
-  static std::string const Type() { return "GMM"; }
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
 
  private:
   /**
@@ -371,15 +347,10 @@ class GMM
    * @param covars Covariances of the given mixture model.
    * @param weights Weights of the given mixture model.
    */
-  double LogLikelihood(const arma::mat& dataPoints,
-                       const std::vector<distribution::GaussianDistribution>& distsL,
-                       const arma::vec& weights) const;
-
-  //! Locally-stored fitting object; in case the user did not pass one.
-  FittingType localFitter;
-
-  //! Reference to the fitting object we should use.
-  FittingType& fitter;
+  double LogLikelihood(
+      const arma::mat& dataPoints,
+      const std::vector<distribution::GaussianDistribution>& distsL,
+      const arma::vec& weights) const;
 };
 
 }; // namespace gmm
diff --git a/src/mlpack/methods/gmm/gmm_impl.hpp b/src/mlpack/methods/gmm/gmm_impl.hpp
index c75b2d3..d1eccad 100644
--- a/src/mlpack/methods/gmm/gmm_impl.hpp
+++ b/src/mlpack/methods/gmm/gmm_impl.hpp
@@ -12,8 +12,6 @@
 // In case it hasn't already been included.
 #include "gmm.hpp"
 
-#include <mlpack/core/util/save_restore_utility.hpp>
-
 namespace mlpack {
 namespace gmm {
 
@@ -30,8 +28,8 @@ GMM<FittingType>::GMM(const size_t gaussians, const size_t dimensionality) :
     dimensionality(dimensionality),
     dists(gaussians, distribution::GaussianDistribution(dimensionality)),
     weights(gaussians),
-    localFitter(FittingType()),
-    fitter(localFitter)
+    fitter(new FittingType()),
+    ownsFitter(true)
 {
   // Set equal weights.  Technically this model is still valid, but only barely.
   weights.fill(1.0 / gaussians);
@@ -55,7 +53,8 @@ GMM<FittingType>::GMM(const size_t gaussians,
     dimensionality(dimensionality),
     dists(gaussians, distribution::GaussianDistribution(dimensionality)),
     weights(gaussians),
-    fitter(fitter)
+    fitter(&fitter),
+    ownsFitter(false)
 {
   // Set equal weights.  Technically this model is still valid, but only barely.
   weights.fill(1.0 / gaussians);
@@ -70,8 +69,8 @@ GMM<FittingType>::GMM(const GMM<OtherFittingType>& other) :
     dimensionality(other.dimensionality),
     dists(other.dists),
     weights(other.weights),
-    localFitter(FittingType()),
-    fitter(localFitter) { /* Nothing to do. */ }
+    fitter(new FittingType()),
+    ownsFitter(true) { /* Nothing to do. */ }
 
 // Copy constructor for when the other GMM uses the same fitting type.
 template<typename FittingType>
@@ -80,8 +79,15 @@ GMM<FittingType>::GMM(const GMM<FittingType>& other) :
     dimensionality(other.dimensionality),
     dists(other.dists),
     weights(other.weights),
-    localFitter(other.fitter),
-    fitter(localFitter) { /* Nothing to do. */ }
+    fitter(new FittingType(*other.fitter)),
+    ownsFitter(true) { /* Nothing to do. */ }
+
+template<typename FittingType>
+GMM<FittingType>::~GMM()
+{
+  if (ownsFitter)
+    delete fitter;
+}
 
 template<typename FittingType>
 template<typename OtherFittingType>
@@ -103,81 +109,13 @@ GMM<FittingType>& GMM<FittingType>::operator=(const GMM<FittingType>& other)
   dimensionality = other.dimensionality;
   dists = other.dists;
   weights = other.weights;
-  localFitter = other.fitter;
-
-  return *this;
-}
-
-// Load a GMM from file.
-template<typename FittingType>
-void GMM<FittingType>::Load(const std::string& filename)
-{
-  util::SaveRestoreUtility load;
-
-  if (!load.ReadFile(filename))
-    Log::Fatal << "GMM::Load(): could not read file '" << filename << "'!\n";
-  Load(load);
-}
-
-// Save a GMM to a file.
-template<typename FittingType>
-void GMM<FittingType>::Save(const std::string& filename) const
-{
-  util::SaveRestoreUtility save;
-  Save(save);
-
-  if (!save.WriteFile(filename))
-    Log::Warn << "GMM::Save(): error saving to '" << filename << "'.\n";
-}
-
-
-// Save a GMM to a SaveRestoreUtility.
-template<typename FittingType>
-void GMM<FittingType>::Save(util::SaveRestoreUtility& sr) const
-{
-  sr.SaveParameter(Type(), "type");
-  sr.SaveParameter(gaussians, "gaussians");
-  sr.SaveParameter(dimensionality, "dimensionality");
-  sr.SaveParameter(weights, "weights");
 
-  util::SaveRestoreUtility child;
-  for (size_t i = 0; i < gaussians; ++i)
-  {
-    // Generate names for the XML nodes.
-    std::stringstream o;
-    o << i;
-    std::string gaussianName = "gaussian" + o.str();
-
-    // Now save them.
-    dists[i].Save(child);
-    sr.AddChild(child, gaussianName);
-  }
-}
+  if (fitter && ownsFitter)
+    delete fitter;
+  fitter = new FittingType(other.fitter);
+  ownsFitter = true;
 
-// Load a GMM from SaveRestoreUtility.
-template<typename FittingType>
-void GMM<FittingType>::Load(const util::SaveRestoreUtility& sr)
-{
-    sr.LoadParameter(gaussians, "gaussians");
-    sr.LoadParameter(dimensionality, "dimensionality");
-    sr.LoadParameter(weights, "weights");
-
-    // We need to do a little error checking here.
-    if (weights.n_elem != gaussians)
-    {
-      Log::Fatal << "GMM::Load reports " << gaussians
-      << " gaussians but weights vector only contains " << weights.n_elem
-      << " elements!" << std::endl;
-    }
-
-    dists.resize(gaussians);
-
-    for (size_t i = 0; i < gaussians; ++i)
-    {
-      std::stringstream o;
-      o << "gaussian" << i;
-      dists[i].Load(sr.Children().at(o.str()));
-    }
+  return *this;
 }
 
 /**
@@ -249,7 +187,7 @@ double GMM<FittingType>::Estimate(const arma::mat& observations,
   {
     // Train the model.  The user will have been warned earlier if the GMM was
     // initialized with no parameters (0 gaussians, dimensionality of 0).
-    fitter.Estimate(observations, dists, weights,
+    fitter->Estimate(observations, dists, weights,
         useExistingModel);
     bestLikelihood = LogLikelihood(observations, dists, weights);
   }
@@ -269,7 +207,7 @@ double GMM<FittingType>::Estimate(const arma::mat& observations,
 
     // We need to keep temporary copies.  We'll do the first training into the
     // actual model position, so that if it's the best we don't need to copy it.
-    fitter.Estimate(observations, dists, weights,
+    fitter->Estimate(observations, dists, weights,
         useExistingModel);
 
     bestLikelihood = LogLikelihood(observations, dists, weights);
@@ -290,7 +228,8 @@ double GMM<FittingType>::Estimate(const arma::mat& observations,
         weightsTrial = weightsOrig;
       }
 
-      fitter.Estimate(observations, distsTrial, weightsTrial, useExistingModel);
+      fitter->Estimate(observations, distsTrial, weightsTrial,
+          useExistingModel);
 
       // Check to see if the log-likelihood of this one is better.
       double newLikelihood = LogLikelihood(observations, distsTrial,
@@ -333,7 +272,7 @@ double GMM<FittingType>::Estimate(const arma::mat& observations,
   {
     // Train the model.  The user will have been warned earlier if the GMM was
     // initialized with no parameters (0 gaussians, dimensionality of 0).
-    fitter.Estimate(observations, probabilities, dists, weights,
+    fitter->Estimate(observations, probabilities, dists, weights,
         useExistingModel);
     bestLikelihood = LogLikelihood(observations, dists, weights);
   }
@@ -353,7 +292,7 @@ double GMM<FittingType>::Estimate(const arma::mat& observations,
 
     // We need to keep temporary copies.  We'll do the first training into the
     // actual model position, so that if it's the best we don't need to copy it.
-    fitter.Estimate(observations, probabilities, dists, weights,
+    fitter->Estimate(observations, probabilities, dists, weights,
         useExistingModel);
 
     bestLikelihood = LogLikelihood(observations, dists, weights);
@@ -374,7 +313,7 @@ double GMM<FittingType>::Estimate(const arma::mat& observations,
         weightsTrial = weightsOrig;
       }
 
-      fitter.Estimate(observations, distsTrial, weightsTrial,
+      fitter->Estimate(observations, distsTrial, weightsTrial,
           useExistingModel);
 
       // Check to see if the log-likelihood of this one is better.
@@ -484,9 +423,45 @@ std::string GMM<FittingType>::ToString() const
   return convert.str();
 }
 
+/**
+ * Serialize the object.
+ */
+template<typename FittingType>
+template<typename Archive>
+void GMM<FittingType>::Serialize(Archive& ar, const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  ar & CreateNVP(gaussians, "gaussians");
+  ar & CreateNVP(dimensionality, "dimensionality");
+
+  // Load (or save) the gaussians.  Not going to use the default std::vector
+  // serialize here because it won't call out correctly to Serialize() for each
+  // Gaussian distribution.
+  if (Archive::is_loading::value)
+    dists.resize(gaussians);
+  for (size_t i = 0; i < gaussians; ++i)
+  {
+    std::ostringstream oss;
+    oss << "dist" << i;
+    ar & CreateNVP(dists[i], oss.str());
+  }
+
+  ar & CreateNVP(weights, "weights");
+
+  if (Archive::is_loading::value)
+  {
+    if (fitter && ownsFitter)
+      delete fitter;
+
+    ownsFitter = true;
+  }
+
+  ar & CreateNVP(fitter, "fitter");
+}
 
-}; // namespace gmm
-}; // namespace mlpack
+} // namespace gmm
+} // namespace mlpack
 
 #endif
 



More information about the mlpack-git mailing list