[mlpack-svn] r14804 - mlpack/trunk/src/mlpack/methods/gmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Apr 1 20:58:15 EDT 2013
Author: rcurtin
Date: 2013-04-01 20:58:15 -0400 (Mon, 01 Apr 2013)
New Revision: 14804
Modified:
mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp
Log:
Add Load() and Save() functionality for GMMs.
Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp 2013-04-02 00:58:01 UTC (rev 14803)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp 2013-04-02 00:58:15 UTC (rev 14804)
@@ -200,6 +200,21 @@
*/
GMM& operator=(const GMM& other);
+ /**
+ * Load a GMM from an XML file. The format of the XML file should be the same
+ * as is generated by the Save() method.
+ *
+ * The number of gaussians, the dimensionality, the weights, and every mean
+ * and covariance currently held by this model are overwritten with the
+ * values read from the file. A message is sent to Log::Fatal if the file
+ * cannot be read, or if the number of weights in the file differs from the
+ * number of gaussians it reports.
+ *
+ * @param filename Name of XML file containing model to be loaded.
+ */
+ void Load(const std::string& filename);
+
+ /**
+ * Save a GMM to an XML file.
+ *
+ * The number of gaussians, the dimensionality, the weights, and the mean and
+ * covariance of each gaussian are written. If writing the file fails, a
+ * message is sent to Log::Warn; the failure is not reported to the caller in
+ * any other way.
+ *
+ * @param filename Name of XML file to write to.
+ */
+ void Save(const std::string& filename) const;
+
//! Return the number of gaussians in the model.
size_t Gaussians() const { return gaussians; }
//! Modify the number of gaussians in the model. Careful! You will have to
Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp 2013-04-02 00:58:01 UTC (rev 14803)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm_impl.hpp 2013-04-02 00:58:15 UTC (rev 14804)
@@ -11,6 +11,8 @@
// In case it hasn't already been included.
#include "gmm.hpp"
+#include <mlpack/core/util/save_restore_utility.hpp>
+
namespace mlpack {
namespace gmm {
@@ -64,6 +66,67 @@
return *this;
}
+/**
+ * Load a GMM from file, replacing the parameters currently held by this
+ * model: the number of gaussians, the dimensionality, the weights, and one
+ * mean and one covariance per gaussian.  The per-gaussian parameters are read
+ * from nodes named "mean<i>" and "covariance<i>", where <i> is the zero-based
+ * index of the gaussian.
+ *
+ * @param filename Name of the file containing the model to be loaded.
+ */
+template<typename FittingType>
+void GMM<FittingType>::Load(const std::string& filename)
+{
+  util::SaveRestoreUtility load;
+
+  // NOTE(review): the statements after this check assume Log::Fatal does not
+  // hand control back once its message is complete -- confirm against the Log
+  // implementation.
+  if (!load.ReadFile(filename))
+    Log::Fatal << "GMM::Load(): could not read file '" << filename << "'!\n";
+
+  load.LoadParameter(gaussians, "gaussians");
+  load.LoadParameter(dimensionality, "dimensionality");
+  load.LoadParameter(weights, "weights");
+
+  // There must be exactly one weight per gaussian.  The mismatch can be in
+  // either direction (too few or too many weights), so the message reports
+  // both counts without implying that the vector is too short.
+  if (weights.n_elem != gaussians)
+  {
+    Log::Fatal << "GMM::Load('" << filename << "'): file reports " << gaussians
+        << " gaussians but weights vector contains " << weights.n_elem
+        << " elements!" << std::endl;
+  }
+
+  // Make room for one mean and one covariance per gaussian before loading.
+  means.resize(gaussians);
+  covariances.resize(gaussians);
+
+  for (size_t i = 0; i < gaussians; ++i)
+  {
+    // Build the node names for this gaussian from its index.
+    std::stringstream indexStream;
+    indexStream << i;
+    const std::string suffix = indexStream.str();
+    std::string meanName = "mean" + suffix;
+    std::string covName = "covariance" + suffix;
+
+    load.LoadParameter(means[i], meanName);
+    load.LoadParameter(covariances[i], covName);
+  }
+}
+
+
+/**
+ * Save a GMM to a file.  The number of gaussians, the dimensionality and the
+ * weights are stored first, followed by the mean and covariance of each
+ * gaussian under nodes named "mean<i>" and "covariance<i>", where <i> is the
+ * zero-based index of the gaussian.  A failure to write the file is reported
+ * through Log::Warn only.
+ *
+ * @param filename Name of the file to write to.
+ */
+template<typename FittingType>
+void GMM<FittingType>::Save(const std::string& filename) const
+{
+  util::SaveRestoreUtility save;
+
+  // Model-wide parameters.
+  save.SaveParameter(gaussians, "gaussians");
+  save.SaveParameter(dimensionality, "dimensionality");
+  save.SaveParameter(weights, "weights");
+
+  // Per-gaussian parameters; each node name carries the gaussian's index.
+  for (size_t g = 0; g < gaussians; ++g)
+  {
+    std::stringstream indexStream;
+    indexStream << g;
+    const std::string suffix = indexStream.str();
+
+    std::string meanNode = "mean" + suffix;
+    std::string covNode = "covariance" + suffix;
+
+    save.SaveParameter(means[g], meanNode);
+    save.SaveParameter(covariances[g], covNode);
+  }
+
+  // Everything is queued; now actually write the file out.
+  if (!save.WriteFile(filename))
+  {
+    Log::Warn << "GMM::Save(): error saving to '" << filename << "'.\n";
+  }
+}
+
+
/**
* Return the probability of the given observation being from this GMM.
*/
More information about the mlpack-svn
mailing list