[mlpack-svn] r10879 - mlpack/trunk/src/mlpack/methods/hmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Dec 17 02:19:02 EST 2011


Author: rcurtin
Date: 2011-12-17 02:19:01 -0500 (Sat, 17 Dec 2011)
New Revision: 10879

Added:
   mlpack/trunk/src/mlpack/methods/hmm/hmm_util.hpp
   mlpack/trunk/src/mlpack/methods/hmm/hmm_util_impl.hpp
Log:
Utilities for loading/saving HMMs; to be deprecated later.


Added: mlpack/trunk/src/mlpack/methods/hmm/hmm_util.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_util.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_util.hpp	2011-12-17 07:19:01 UTC (rev 10879)
@@ -0,0 +1,42 @@
+/**
+ * @file hmm_util.hpp
+ * @author Ryan Curtin
+ *
+ * Save/load utilities for HMMs.  This should be eventually merged into the HMM
+ * class itself.
+ */
+#ifndef __MLPACK_METHODS_HMM_HMM_UTIL_HPP
+#define __MLPACK_METHODS_HMM_HMM_UTIL_HPP
+
+#include "hmm.hpp"
+
+namespace mlpack {
+namespace hmm {
+
+/**
+ * Save an HMM's parameters into the given SaveRestoreUtility.  Only HMMs whose
+ * emission distributions are GMMs, DiscreteDistributions, or
+ * GaussianDistributions are supported; for any other distribution type an
+ * error is reported through Log::Fatal.
+ *
+ * @tparam Distribution Distribution type of HMM.
+ * @param hmm HMM to save.
+ * @param sr SaveRestoreUtility to write the HMM's parameters into.
+ */
+template<typename Distribution>
+void SaveHMM(const HMM<Distribution>& hmm, utilities::SaveRestoreUtility& sr);
+
+/**
+ * Load an HMM's parameters from the given SaveRestoreUtility.  Only HMMs whose
+ * emission distributions are GMMs, DiscreteDistributions, or
+ * GaussianDistributions are supported.  The HMM type recorded in the
+ * SaveRestoreUtility must match Distribution; otherwise (and for unsupported
+ * distribution types) an error is reported through Log::Fatal.
+ *
+ * @tparam Distribution Distribution type of HMM.
+ * @param hmm HMM to load into; its transition matrix, emission distributions,
+ *     and dimensionality are overwritten.
+ * @param sr SaveRestoreUtility to read the HMM's parameters from.
+ */
+template<typename Distribution>
+void LoadHMM(HMM<Distribution>& hmm, utilities::SaveRestoreUtility& sr);
+
+} // namespace hmm
+} // namespace mlpack
+
+// Include implementation.
+#include "hmm_util_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/hmm/hmm_util_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_util_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_util_impl.hpp	2011-12-17 07:19:01 UTC (rev 10879)
@@ -0,0 +1,240 @@
+/**
+ * @file hmm_util_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of HMM load/save functions.
+ */
+#ifndef __MLPACK_METHODS_HMM_HMM_UTIL_IMPL_HPP
+#define __MLPACK_METHODS_HMM_HMM_UTIL_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "hmm_util.hpp"
+
+#include <sstream>
+#include <string>
+
+#include <mlpack/methods/gmm/gmm.hpp>
+
+namespace mlpack {
+namespace hmm {
+
+/**
+ * Fallback for emission distribution types that have no save implementation.
+ * Only the explicit specializations that follow (DiscreteDistribution,
+ * GaussianDistribution, and GMM) can serialize an HMM; any other type is
+ * reported through Log::Fatal.
+ */
+template<typename Distribution>
+void SaveHMM(const HMM<Distribution>& /* hmm */,
+             utilities::SaveRestoreUtility& /* sr */)
+{
+  Log::Fatal
+      << "HMM save not implemented for arbitrary distributions."
+      << std::endl;
+}
+
+/**
+ * Save an HMM with discrete emission distributions.  The following parameters
+ * are written to the SaveRestoreUtility:
+ *
+ *  - "hmm_type": the string "discrete"
+ *  - "hmm_states": the number of states (rows of the transition matrix)
+ *  - "hmm_transition": the transition matrix
+ *  - "hmm_emission_distribution_N": the probabilities of state N's emission
+ *
+ * This is an explicit specialization defined in a header, and explicit
+ * specializations of function templates are not implicitly inline; it is
+ * marked inline so that including this header from more than one translation
+ * unit does not produce multiple-definition link errors.
+ */
+template<>
+inline void SaveHMM(const HMM<distribution::DiscreteDistribution>& hmm,
+                    utilities::SaveRestoreUtility& sr)
+{
+  std::string type = "discrete";
+  size_t states = hmm.Transition().n_rows;
+
+  sr.SaveParameter(type, "hmm_type");
+  sr.SaveParameter(states, "hmm_states");
+  sr.SaveParameter(hmm.Transition(), "hmm_transition");
+
+  // Now the emissions, one named parameter per state.
+  for (size_t i = 0; i < states; ++i)
+  {
+    // Generate name.
+    std::stringstream s;
+    s << "hmm_emission_distribution_" << i;
+    sr.SaveParameter(hmm.Emission()[i].Probabilities(), s.str());
+  }
+}
+
+/**
+ * Save an HMM with Gaussian emission distributions.  The following parameters
+ * are written to the SaveRestoreUtility:
+ *
+ *  - "hmm_type": the string "gaussian"
+ *  - "hmm_states": the number of states (rows of the transition matrix)
+ *  - "hmm_transition": the transition matrix
+ *  - "hmm_emission_mean_N": the mean of state N's emission
+ *  - "hmm_emission_covariance_N": the covariance of state N's emission
+ *
+ * Marked inline because it is an explicit specialization defined in a header
+ * (such specializations are not implicitly inline); without it, including this
+ * header from several translation units breaks linking.
+ */
+template<>
+inline void SaveHMM(const HMM<distribution::GaussianDistribution>& hmm,
+                    utilities::SaveRestoreUtility& sr)
+{
+  std::string type = "gaussian";
+  size_t states = hmm.Transition().n_rows;
+
+  sr.SaveParameter(type, "hmm_type");
+  sr.SaveParameter(states, "hmm_states");
+  sr.SaveParameter(hmm.Transition(), "hmm_transition");
+
+  // Now the emissions.
+  for (size_t i = 0; i < states; ++i)
+  {
+    // Generate name.
+    std::stringstream s;
+    s << "hmm_emission_mean_" << i;
+    sr.SaveParameter(hmm.Emission()[i].Mean(), s.str());
+
+    // Reuse the stream for the next name.
+    s.str("");
+    s << "hmm_emission_covariance_" << i;
+    sr.SaveParameter(hmm.Emission()[i].Covariance(), s.str());
+  }
+}
+
+/**
+ * Save an HMM with Gaussian mixture model emission distributions.  The
+ * following parameters are written to the SaveRestoreUtility:
+ *
+ *  - "hmm_type": the string "gmm"
+ *  - "hmm_states": the number of states (rows of the transition matrix)
+ *  - "hmm_transition": the transition matrix
+ *  - "hmm_emission_N_gaussians": number of Gaussians in state N's GMM
+ *  - "hmm_emission_N_weights": mixture weights of state N's GMM
+ *  - "hmm_emission_N_gaussian_G_mean": mean of Gaussian G of state N
+ *  - "hmm_emission_N_gaussian_G_covariance": covariance of Gaussian G of
+ *    state N
+ *
+ * Marked inline because it is an explicit specialization defined in a header
+ * (such specializations are not implicitly inline); without it, including this
+ * header from several translation units breaks linking.
+ */
+template<>
+inline void SaveHMM(const HMM<gmm::GMM>& hmm,
+                    utilities::SaveRestoreUtility& sr)
+{
+  std::string type = "gmm";
+  size_t states = hmm.Transition().n_rows;
+
+  sr.SaveParameter(type, "hmm_type");
+  sr.SaveParameter(states, "hmm_states");
+  sr.SaveParameter(hmm.Transition(), "hmm_transition");
+
+  // Now the emissions.
+  for (size_t i = 0; i < states; ++i)
+  {
+    // Generate name.
+    std::stringstream s;
+    s << "hmm_emission_" << i << "_gaussians";
+    sr.SaveParameter(hmm.Emission()[i].Gaussians(), s.str());
+
+    s.str("");
+    s << "hmm_emission_" << i << "_weights";
+    sr.SaveParameter(hmm.Emission()[i].Weights(), s.str());
+
+    // Each mixture component is stored as its own mean/covariance pair.
+    for (size_t g = 0; g < hmm.Emission()[i].Gaussians(); ++g)
+    {
+      s.str("");
+      s << "hmm_emission_" << i << "_gaussian_" << g << "_mean";
+      sr.SaveParameter(hmm.Emission()[i].Means()[g], s.str());
+
+      s.str("");
+      s << "hmm_emission_" << i << "_gaussian_" << g << "_covariance";
+      sr.SaveParameter(hmm.Emission()[i].Covariances()[g], s.str());
+    }
+  }
+}
+
+/**
+ * Fallback for emission distribution types that have no load implementation.
+ * Only the explicit specializations that follow (DiscreteDistribution,
+ * GaussianDistribution, and GMM) can deserialize an HMM; any other type is
+ * reported through Log::Fatal.
+ */
+template<typename Distribution>
+void LoadHMM(HMM<Distribution>& /* hmm */,
+             utilities::SaveRestoreUtility& /* sr */)
+{
+  Log::Fatal
+      << "HMM load not implemented for arbitrary distributions."
+      << std::endl;
+}
+
+/**
+ * Load an HMM with discrete emission distributions, reading the parameters
+ * written by the DiscreteDistribution specialization of SaveHMM().  If the
+ * stored "hmm_type" is not "discrete", an error is reported through
+ * Log::Fatal.  The HMM's dimensionality is set to 1.
+ *
+ * Marked inline because it is an explicit specialization defined in a header
+ * (such specializations are not implicitly inline); without it, including this
+ * header from several translation units breaks linking.
+ */
+template<>
+inline void LoadHMM(HMM<distribution::DiscreteDistribution>& hmm,
+                    utilities::SaveRestoreUtility& sr)
+{
+  std::string type;
+  size_t states;
+
+  sr.LoadParameter(type, "hmm_type");
+  if (type != "discrete")
+  {
+    Log::Fatal << "Cannot load non-discrete HMM (of type " << type << ") as "
+        << "discrete HMM!" << std::endl;
+  }
+
+  sr.LoadParameter(states, "hmm_states");
+
+  // Load transition matrix.
+  sr.LoadParameter(hmm.Transition(), "hmm_transition");
+
+  // Now each emission distribution.
+  hmm.Emission().resize(states);
+  for (size_t i = 0; i < states; ++i)
+  {
+    std::stringstream s;
+    s << "hmm_emission_distribution_" << i;
+    sr.LoadParameter(hmm.Emission()[i].Probabilities(), s.str());
+  }
+
+  hmm.Dimensionality() = 1;
+}
+
+/**
+ * Load an HMM with Gaussian emission distributions, reading the parameters
+ * written by the GaussianDistribution specialization of SaveHMM().  If the
+ * stored "hmm_type" is not "gaussian", or the stored HMM has no states, an
+ * error is reported through Log::Fatal.  The HMM's dimensionality is set to
+ * the length of the first emission's mean.
+ *
+ * Marked inline because it is an explicit specialization defined in a header
+ * (such specializations are not implicitly inline); without it, including this
+ * header from several translation units breaks linking.
+ */
+template<>
+inline void LoadHMM(HMM<distribution::GaussianDistribution>& hmm,
+                    utilities::SaveRestoreUtility& sr)
+{
+  std::string type;
+  size_t states;
+
+  sr.LoadParameter(type, "hmm_type");
+  if (type != "gaussian")
+  {
+    Log::Fatal << "Cannot load non-Gaussian HMM (of type " << type << ") as "
+        << "a Gaussian HMM!" << std::endl;
+  }
+
+  sr.LoadParameter(states, "hmm_states");
+
+  // The dimensionality is read from the first emission distribution below;
+  // indexing an empty emission vector would be undefined behavior.
+  if (states == 0)
+  {
+    Log::Fatal << "Cannot load Gaussian HMM with zero states!" << std::endl;
+  }
+
+  // Load transition matrix.
+  sr.LoadParameter(hmm.Transition(), "hmm_transition");
+
+  // Now each emission distribution.
+  hmm.Emission().resize(states);
+  for (size_t i = 0; i < states; ++i)
+  {
+    std::stringstream s;
+    s << "hmm_emission_mean_" << i;
+    sr.LoadParameter(hmm.Emission()[i].Mean(), s.str());
+
+    s.str("");
+    s << "hmm_emission_covariance_" << i;
+    sr.LoadParameter(hmm.Emission()[i].Covariance(), s.str());
+  }
+
+  hmm.Dimensionality() = hmm.Emission()[0].Mean().n_elem;
+}
+
+/**
+ * Load an HMM with Gaussian mixture model emission distributions, reading the
+ * parameters written by the GMM specialization of SaveHMM().  An error is
+ * reported through Log::Fatal in any of these cases:
+ *
+ *  - the stored "hmm_type" is not "gmm";
+ *  - the stored HMM has no states;
+ *  - some state's GMM has no Gaussians.
+ *
+ * Each state's GMM dimensionality is taken from the mean of its first
+ * Gaussian.  The HMM's dimensionality is taken from the first state's GMM.
+ *
+ * Marked inline because it is an explicit specialization defined in a header
+ * (such specializations are not implicitly inline); without it, including this
+ * header from several translation units breaks linking.
+ */
+template<>
+inline void LoadHMM(HMM<gmm::GMM>& hmm,
+                    utilities::SaveRestoreUtility& sr)
+{
+  std::string type;
+  size_t states;
+
+  sr.LoadParameter(type, "hmm_type");
+  if (type != "gmm")
+  {
+    Log::Fatal << "Cannot load non-GMM HMM (of type " << type << ") as "
+        << "a Gaussian Mixture Model HMM!" << std::endl;
+  }
+
+  sr.LoadParameter(states, "hmm_states");
+
+  // The HMM dimensionality is read from the first emission distribution below;
+  // indexing an empty emission vector would be undefined behavior.
+  if (states == 0)
+  {
+    Log::Fatal << "Cannot load GMM HMM with zero states!" << std::endl;
+  }
+
+  // Load transition matrix.
+  sr.LoadParameter(hmm.Transition(), "hmm_transition");
+
+  // Now each emission distribution.  The GMM(1, 1) placeholders are replaced
+  // by correctly-sized GMMs inside the loop.
+  hmm.Emission().resize(states, gmm::GMM(1, 1));
+  for (size_t i = 0; i < states; ++i)
+  {
+    std::stringstream s;
+    s << "hmm_emission_" << i << "_gaussians";
+    size_t gaussians;
+    sr.LoadParameter(gaussians, s.str());
+
+    // The dimensionality comes from the mean of Gaussian 0, which only exists
+    // if the mixture has at least one component.
+    if (gaussians == 0)
+    {
+      Log::Fatal << "Emission distribution " << i << " of GMM HMM has no "
+          << "Gaussians!" << std::endl;
+    }
+
+    s.str("");
+    // Extract dimensionality.
+    arma::vec meanzero;
+    s << "hmm_emission_" << i << "_gaussian_0_mean";
+    sr.LoadParameter(meanzero, s.str());
+    size_t dimensionality = meanzero.n_elem;
+
+    // Initialize GMM correctly.
+    hmm.Emission()[i] = gmm::GMM(gaussians, dimensionality);
+
+    for (size_t g = 0; g < gaussians; ++g)
+    {
+      s.str("");
+      s << "hmm_emission_" << i << "_gaussian_" << g << "_mean";
+      sr.LoadParameter(hmm.Emission()[i].Means()[g], s.str());
+
+      s.str("");
+      s << "hmm_emission_" << i << "_gaussian_" << g << "_covariance";
+      sr.LoadParameter(hmm.Emission()[i].Covariances()[g], s.str());
+    }
+
+    s.str("");
+    s << "hmm_emission_" << i << "_weights";
+    sr.LoadParameter(hmm.Emission()[i].Weights(), s.str());
+  }
+
+  hmm.Dimensionality() = hmm.Emission()[0].Dimensionality();
+}
+
+}; // namespace hmm
+}; // namespace mlpack
+
+#endif




More information about the mlpack-svn mailing list