[mlpack-svn] r10879 - mlpack/trunk/src/mlpack/methods/hmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Dec 17 02:19:02 EST 2011
Author: rcurtin
Date: 2011-12-17 02:19:01 -0500 (Sat, 17 Dec 2011)
New Revision: 10879
Added:
mlpack/trunk/src/mlpack/methods/hmm/hmm_util.hpp
mlpack/trunk/src/mlpack/methods/hmm/hmm_util_impl.hpp
Log:
Utilities for loading/saving HMMs; to be deprecated later.
Added: mlpack/trunk/src/mlpack/methods/hmm/hmm_util.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_util.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_util.hpp 2011-12-17 07:19:01 UTC (rev 10879)
@@ -0,0 +1,42 @@
+/**
+ * @file hmm_util.hpp
+ * @author Ryan Curtin
+ *
+ * Save/load utilities for HMMs. These should eventually be merged into the HMM
+ * class itself.
+ */
+#ifndef __MLPACK_METHODS_HMM_HMM_UTIL_HPP
+#define __MLPACK_METHODS_HMM_HMM_UTIL_HPP
+
+#include "hmm.hpp"
+
+namespace mlpack {
+namespace hmm {
+
+/**
+ * Save an HMM's parameters into the given SaveRestoreUtility. Only HMMs whose
+ * emission type is gmm::GMM, distribution::DiscreteDistribution, or
+ * distribution::GaussianDistribution are supported; for any other
+ * Distribution type the generic implementation reports an error through
+ * Log::Fatal.
+ *
+ * @tparam Distribution Distribution type of HMM.
+ * @param hmm HMM whose transition matrix and emissions are saved.
+ * @param sr SaveRestoreUtility to use.
+ */
+template<typename Distribution>
+void SaveHMM(const HMM<Distribution>& hmm, utilities::SaveRestoreUtility& sr);
+
+/**
+ * Load an HMM's parameters from the given SaveRestoreUtility, overwriting the
+ * HMM's transition matrix and emission distributions and setting its
+ * dimensionality. Only HMMs whose emission type is gmm::GMM,
+ * distribution::DiscreteDistribution, or distribution::GaussianDistribution
+ * are supported. The stored "hmm_type" must match the Distribution type, and
+ * an error is reported through Log::Fatal otherwise (and for unsupported
+ * types).
+ *
+ * @tparam Distribution Distribution type of HMM.
+ * @param hmm HMM to load into; its previous contents are replaced.
+ * @param sr SaveRestoreUtility to use.
+ */
+template<typename Distribution>
+void LoadHMM(HMM<Distribution>& hmm, utilities::SaveRestoreUtility& sr);
+
+}; // namespace hmm
+}; // namespace mlpack
+
+// Include implementation.
+#include "hmm_util_impl.hpp"
+
+#endif
Added: mlpack/trunk/src/mlpack/methods/hmm/hmm_util_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_util_impl.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_util_impl.hpp 2011-12-17 07:19:01 UTC (rev 10879)
@@ -0,0 +1,240 @@
+/**
+ * @file hmm_util_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of HMM load/save functions.
+ */
+#ifndef __MLPACK_METHODS_HMM_HMM_UTIL_IMPL_HPP
+#define __MLPACK_METHODS_HMM_HMM_UTIL_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "hmm_util.hpp"
+
+#include <mlpack/methods/gmm/gmm.hpp>
+
+namespace mlpack {
+namespace hmm {
+
+/**
+ * Generic fallback for SaveHMM(). Saving is only implemented by the
+ * DiscreteDistribution, GaussianDistribution, and gmm::GMM specializations
+ * in this file, so any other emission type is reported through Log::Fatal.
+ * The parameters are intentionally unnamed because they are never used here.
+ */
+template<typename Distribution>
+void SaveHMM(const HMM<Distribution>& /* hmm */,
+             utilities::SaveRestoreUtility& /* sr */)
+{
+  Log::Fatal << "HMM save not implemented for arbitrary distributions."
+             << std::endl;
+}
+
+/**
+ * Save an HMM with discrete emissions. This writes the type tag "discrete",
+ * the number of states, the transition matrix, and one probability vector per
+ * state under the key "hmm_emission_distribution_<i>".
+ *
+ * A full specialization of a function template is an ordinary function, not
+ * a template. Because this definition lives in a header, it must be `inline`;
+ * otherwise every translation unit that includes the header emits its own
+ * definition, violating the one-definition rule (multiple-definition link
+ * errors).
+ */
+template<>
+inline void SaveHMM(const HMM<distribution::DiscreteDistribution>& hmm,
+                    utilities::SaveRestoreUtility& sr)
+{
+  const std::string type = "discrete";
+  const size_t states = hmm.Transition().n_rows;
+
+  sr.SaveParameter(type, "hmm_type");
+  sr.SaveParameter(states, "hmm_states");
+  sr.SaveParameter(hmm.Transition(), "hmm_transition");
+
+  // Now the emissions: one probability vector per state.
+  for (size_t i = 0; i < states; ++i)
+  {
+    std::stringstream s;
+    s << "hmm_emission_distribution_" << i;
+    sr.SaveParameter(hmm.Emission()[i].Probabilities(), s.str());
+  }
+}
+
+/**
+ * Save an HMM with Gaussian emissions. This writes the type tag "gaussian",
+ * the number of states, the transition matrix, and, for each state i, the
+ * emission mean ("hmm_emission_mean_<i>") and covariance
+ * ("hmm_emission_covariance_<i>").
+ *
+ * Marked `inline` because this full specialization is an ordinary function
+ * defined in a header; without it, including the header in more than one
+ * translation unit violates the one-definition rule.
+ */
+template<>
+inline void SaveHMM(const HMM<distribution::GaussianDistribution>& hmm,
+                    utilities::SaveRestoreUtility& sr)
+{
+  const std::string type = "gaussian";
+  const size_t states = hmm.Transition().n_rows;
+
+  sr.SaveParameter(type, "hmm_type");
+  sr.SaveParameter(states, "hmm_states");
+  sr.SaveParameter(hmm.Transition(), "hmm_transition");
+
+  // Now the emissions: mean and covariance of each state's Gaussian.
+  for (size_t i = 0; i < states; ++i)
+  {
+    std::stringstream s;
+    s << "hmm_emission_mean_" << i;
+    sr.SaveParameter(hmm.Emission()[i].Mean(), s.str());
+
+    s.str("");  // Reset the buffer before building the next key.
+    s << "hmm_emission_covariance_" << i;
+    sr.SaveParameter(hmm.Emission()[i].Covariance(), s.str());
+  }
+}
+
+/**
+ * Save an HMM whose emissions are Gaussian mixture models. This writes the
+ * type tag "gmm", the number of states, and the transition matrix. For each
+ * state i it then writes the number of Gaussians ("hmm_emission_<i>_gaussians"),
+ * the mixture weights ("hmm_emission_<i>_weights"), and each component's mean
+ * and covariance ("hmm_emission_<i>_gaussian_<g>_mean" / "..._covariance").
+ *
+ * Marked `inline` because this full specialization is an ordinary function
+ * defined in a header; without it, including the header in more than one
+ * translation unit violates the one-definition rule.
+ */
+template<>
+inline void SaveHMM(const HMM<gmm::GMM>& hmm,
+                    utilities::SaveRestoreUtility& sr)
+{
+  const std::string type = "gmm";
+  const size_t states = hmm.Transition().n_rows;
+
+  sr.SaveParameter(type, "hmm_type");
+  sr.SaveParameter(states, "hmm_states");
+  sr.SaveParameter(hmm.Transition(), "hmm_transition");
+
+  // Now the emissions.
+  for (size_t i = 0; i < states; ++i)
+  {
+    // Look the emission up once instead of on every access below.
+    const gmm::GMM& emission = hmm.Emission()[i];
+
+    std::stringstream s;
+    s << "hmm_emission_" << i << "_gaussians";
+    sr.SaveParameter(emission.Gaussians(), s.str());
+
+    s.str("");
+    s << "hmm_emission_" << i << "_weights";
+    sr.SaveParameter(emission.Weights(), s.str());
+
+    for (size_t g = 0; g < emission.Gaussians(); ++g)
+    {
+      s.str("");
+      s << "hmm_emission_" << i << "_gaussian_" << g << "_mean";
+      sr.SaveParameter(emission.Means()[g], s.str());
+
+      s.str("");
+      s << "hmm_emission_" << i << "_gaussian_" << g << "_covariance";
+      sr.SaveParameter(emission.Covariances()[g], s.str());
+    }
+  }
+}
+
+/**
+ * Generic fallback for LoadHMM(). Loading is only implemented by the
+ * DiscreteDistribution, GaussianDistribution, and gmm::GMM specializations
+ * in this file, so any other emission type is reported through Log::Fatal.
+ * The parameters are intentionally unnamed because they are never used here.
+ */
+template<typename Distribution>
+void LoadHMM(HMM<Distribution>& /* hmm */,
+             utilities::SaveRestoreUtility& /* sr */)
+{
+  Log::Fatal << "HMM load not implemented for arbitrary distributions."
+             << std::endl;
+}
+
+/**
+ * Load an HMM with discrete emissions. The stored "hmm_type" must be
+ * "discrete", or an error is reported through Log::Fatal. Reads the number of
+ * states, the transition matrix, and one probability vector per state. The
+ * HMM's dimensionality is then set to 1.
+ *
+ * Marked `inline` because this full specialization is an ordinary function
+ * defined in a header; without it, including the header in more than one
+ * translation unit violates the one-definition rule.
+ */
+template<>
+inline void LoadHMM(HMM<distribution::DiscreteDistribution>& hmm,
+                    utilities::SaveRestoreUtility& sr)
+{
+  std::string type;
+  size_t states = 0;  // Initialized so it is never read indeterminate.
+
+  sr.LoadParameter(type, "hmm_type");
+  if (type != "discrete")
+  {
+    Log::Fatal << "Cannot load non-discrete HMM (of type " << type << ") as "
+        << "discrete HMM!" << std::endl;
+  }
+
+  sr.LoadParameter(states, "hmm_states");
+
+  // Load transition matrix.
+  sr.LoadParameter(hmm.Transition(), "hmm_transition");
+
+  // Now each emission distribution.
+  hmm.Emission().resize(states);
+  for (size_t i = 0; i < states; ++i)
+  {
+    std::stringstream s;
+    s << "hmm_emission_distribution_" << i;
+    sr.LoadParameter(hmm.Emission()[i].Probabilities(), s.str());
+  }
+
+  hmm.Dimensionality() = 1;
+}
+
+/**
+ * Load an HMM with Gaussian emissions. The stored "hmm_type" must be
+ * "gaussian", or an error is reported through Log::Fatal. Reads the number of
+ * states, the transition matrix, and each state's emission mean and
+ * covariance. The HMM's dimensionality is then set to the length of the first
+ * state's mean, or to 0 if there are no states.
+ *
+ * Marked `inline` because this full specialization is an ordinary function
+ * defined in a header; without it, including the header in more than one
+ * translation unit violates the one-definition rule.
+ */
+template<>
+inline void LoadHMM(HMM<distribution::GaussianDistribution>& hmm,
+                    utilities::SaveRestoreUtility& sr)
+{
+  std::string type;
+  size_t states = 0;  // Initialized so it is never read indeterminate.
+
+  sr.LoadParameter(type, "hmm_type");
+  if (type != "gaussian")
+  {
+    Log::Fatal << "Cannot load non-Gaussian HMM (of type " << type << ") as "
+        << "a Gaussian HMM!" << std::endl;
+  }
+
+  sr.LoadParameter(states, "hmm_states");
+
+  // Load transition matrix.
+  sr.LoadParameter(hmm.Transition(), "hmm_transition");
+
+  // Now each emission distribution.
+  hmm.Emission().resize(states);
+  for (size_t i = 0; i < states; ++i)
+  {
+    std::stringstream s;
+    s << "hmm_emission_mean_" << i;
+    sr.LoadParameter(hmm.Emission()[i].Mean(), s.str());
+
+    s.str("");
+    s << "hmm_emission_covariance_" << i;
+    sr.LoadParameter(hmm.Emission()[i].Covariance(), s.str());
+  }
+
+  // With zero states Emission() is empty, so indexing element 0 would be an
+  // out-of-bounds access.
+  if (states > 0)
+    hmm.Dimensionality() = hmm.Emission()[0].Mean().n_elem;
+  else
+    hmm.Dimensionality() = 0;
+}
+
+/**
+ * Load an HMM whose emissions are Gaussian mixture models. The stored
+ * "hmm_type" must be "gmm", or an error is reported through Log::Fatal. Reads
+ * the number of states and the transition matrix. For each state it then
+ * reads the number of Gaussians, rebuilds the GMM using the dimensionality of
+ * the first component's mean, and loads every component's mean and
+ * covariance followed by the mixture weights. The HMM's dimensionality is set
+ * from the first state's GMM, or to 0 if there are no states.
+ *
+ * Marked `inline` because this full specialization is an ordinary function
+ * defined in a header; without it, including the header in more than one
+ * translation unit violates the one-definition rule.
+ */
+template<>
+inline void LoadHMM(HMM<gmm::GMM>& hmm,
+                    utilities::SaveRestoreUtility& sr)
+{
+  std::string type;
+  size_t states = 0;  // Initialized so it is never read indeterminate.
+
+  sr.LoadParameter(type, "hmm_type");
+  if (type != "gmm")
+  {
+    Log::Fatal << "Cannot load non-GMM HMM (of type " << type << ") as "
+        << "a Gaussian Mixture Model HMM!" << std::endl;
+  }
+
+  sr.LoadParameter(states, "hmm_states");
+
+  // Load transition matrix.
+  sr.LoadParameter(hmm.Transition(), "hmm_transition");
+
+  // Now each emission distribution. The placeholder GMMs are replaced below
+  // once each state's size is known.
+  hmm.Emission().resize(states, gmm::GMM(1, 1));
+  for (size_t i = 0; i < states; ++i)
+  {
+    std::stringstream s;
+    s << "hmm_emission_" << i << "_gaussians";
+    size_t gaussians = 0;
+    sr.LoadParameter(gaussians, s.str());
+
+    // Extract dimensionality from the first component's mean.
+    s.str("");
+    arma::vec meanzero;
+    s << "hmm_emission_" << i << "_gaussian_0_mean";
+    sr.LoadParameter(meanzero, s.str());
+    const size_t dimensionality = meanzero.n_elem;
+
+    // Initialize GMM correctly, then fill it in through a reference.
+    hmm.Emission()[i] = gmm::GMM(gaussians, dimensionality);
+    gmm::GMM& emission = hmm.Emission()[i];
+
+    for (size_t g = 0; g < gaussians; ++g)
+    {
+      s.str("");
+      s << "hmm_emission_" << i << "_gaussian_" << g << "_mean";
+      sr.LoadParameter(emission.Means()[g], s.str());
+
+      s.str("");
+      s << "hmm_emission_" << i << "_gaussian_" << g << "_covariance";
+      sr.LoadParameter(emission.Covariances()[g], s.str());
+    }
+
+    s.str("");
+    s << "hmm_emission_" << i << "_weights";
+    sr.LoadParameter(emission.Weights(), s.str());
+  }
+
+  // With zero states Emission() is empty, so indexing element 0 would be an
+  // out-of-bounds access.
+  if (states > 0)
+    hmm.Dimensionality() = hmm.Emission()[0].Dimensionality();
+  else
+    hmm.Dimensionality() = 0;
+}
+
+}; // namespace hmm
+}; // namespace mlpack
+
+#endif
More information about the mlpack-svn
mailing list