[mlpack-svn] r14804 - mlpack/trunk/src/mlpack/methods/gmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Apr 1 20:58:15 EDT 2013


Author: rcurtin
Date: 2013-04-01 20:58:15 -0400 (Mon, 01 Apr 2013)
New Revision: 14804

Modified:
   mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
   mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp
Log:
Add Load() and Save() functionality for GMMs.


Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp	2013-04-02 00:58:01 UTC (rev 14803)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp	2013-04-02 00:58:15 UTC (rev 14804)
@@ -200,6 +200,21 @@
    */
   GMM& operator=(const GMM& other);
 
+  /**
+   * Load a GMM from an XML file.  The format of the XML file should be the same
+   * as is generated by the Save() method.  Failures are reported via Log::Fatal.
+   *
+   * @param filename Name of XML file containing model to be loaded.
+   */
+  void Load(const std::string& filename);
+
+  /**
+   * Save a GMM to an XML file (readable by Load()).
+   *
+   * @param filename Name of XML file to write to.  Write errors only warn.
+   */
+  void Save(const std::string& filename) const;
+
   //! Return the number of gaussians in the model.
   size_t Gaussians() const { return gaussians; }
   //! Modify the number of gaussians in the model.  Careful!  You will have to

Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp	2013-04-02 00:58:01 UTC (rev 14803)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp	2013-04-02 00:58:15 UTC (rev 14804)
@@ -11,6 +11,8 @@
 // In case it hasn't already been included.
 #include "gmm.hpp"
 
+#include <mlpack/core/util/save_restore_utility.hpp>
+
 namespace mlpack {
 namespace gmm {
 
@@ -64,6 +66,67 @@
   return *this;
 }
 
+// Load a GMM from an XML file in the format written by Save(); errors are fatal.
+template<typename FittingType>
+void GMM<FittingType>::Load(const std::string& filename)
+{
+  util::SaveRestoreUtility load;
+
+  if (!load.ReadFile(filename))
+    Log::Fatal << "GMM::Load(): could not read file '" << filename << "'!\n";
+
+  load.LoadParameter(gaussians, "gaussians");
+  load.LoadParameter(dimensionality, "dimensionality");
+  load.LoadParameter(weights, "weights");
+
+  // Sanity check: the file must contain exactly one weight per component.
+  if (weights.n_elem != gaussians)
+  {
+    Log::Fatal << "GMM::Load('" << filename << "'): file reports " << gaussians
+        << " gaussians but weights vector only contains " << weights.n_elem
+        << " elements!" << std::endl;
+  }
+
+  means.resize(gaussians);  // One mean and one covariance per component.
+  covariances.resize(gaussians);
+
+  for (size_t i = 0; i < gaussians; ++i)
+  {
+    std::stringstream o;  // Builds node names "mean<i>" / "covariance<i>".
+    o << i;
+    std::string meanName = "mean" + o.str();
+    std::string covName = "covariance" + o.str();
+
+    load.LoadParameter(means[i], meanName);
+    load.LoadParameter(covariances[i], covName);  // NOTE(review): sizes are not checked against dimensionality.
+  }
+}
+
+// Save a GMM to an XML file in the format that Load() reads back.
+template<typename FittingType>
+void GMM<FittingType>::Save(const std::string& filename) const
+{
+  util::SaveRestoreUtility save;
+  save.SaveParameter(gaussians, "gaussians");
+  save.SaveParameter(dimensionality, "dimensionality");
+  save.SaveParameter(weights, "weights");
+  for (size_t i = 0; i < gaussians; ++i)
+  {
+    // Generate per-component XML node names: "mean<i>" and "covariance<i>".
+    std::stringstream o;
+    o << i;
+    std::string meanName = "mean" + o.str();
+    std::string covName = "covariance" + o.str();
+
+    // Store this component's mean and covariance under those names.
+    save.SaveParameter(means[i], meanName);
+    save.SaveParameter(covariances[i], covName);
+  }
+
+  if (!save.WriteFile(filename))  // A write failure only warns; it is not fatal.
+    Log::Warn << "GMM::Save(): error saving to '" << filename << "'.\n";
+}
+
 /**
  * Return the probability of the given observation being from this GMM.
  */




More information about the mlpack-svn mailing list