[mlpack-svn] r10874 - mlpack/trunk/src/mlpack/methods/hmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Dec 17 01:39:15 EST 2011
Author: rcurtin
Date: 2011-12-17 01:39:15 -0500 (Sat, 17 Dec 2011)
New Revision: 10874
Added:
mlpack/trunk/src/mlpack/methods/hmm/hmm_generate_main.cpp
mlpack/trunk/src/mlpack/methods/hmm/hmm_loglik_main.cpp
mlpack/trunk/src/mlpack/methods/hmm/hmm_train_main.cpp
mlpack/trunk/src/mlpack/methods/hmm/hmm_viterbi_main.cpp
Modified:
mlpack/trunk/src/mlpack/methods/hmm/CMakeLists.txt
mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp
Log:
Add main executables for HMMs.
Modified: mlpack/trunk/src/mlpack/methods/hmm/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/CMakeLists.txt 2011-12-17 06:37:23 UTC (rev 10873)
+++ mlpack/trunk/src/mlpack/methods/hmm/CMakeLists.txt 2011-12-17 06:39:15 UTC (rev 10874)
@@ -5,6 +5,8 @@
set(SOURCES
hmm.hpp
hmm_impl.hpp
+ hmm_util.hpp
+ hmm_util_impl.hpp
distributions/discrete_distribution.hpp
distributions/discrete_distribution.cpp
distributions/gaussian_distribution.hpp
@@ -19,3 +21,31 @@
# Append sources (with directory name) to list of all MLPACK sources (used at
# the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+
+# Command-line HMM tools.  Each hmm_<tool> executable is built from the single
+# source file hmm_<tool>_main.cpp and linked against the mlpack library.
+foreach(tool train loglik viterbi generate)
+  add_executable(hmm_${tool} hmm_${tool}_main.cpp)
+  target_link_libraries(hmm_${tool} mlpack)
+endforeach()
Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp 2011-12-17 06:37:23 UTC (rev 10873)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp 2011-12-17 06:39:15 UTC (rev 10874)
@@ -237,6 +237,11 @@
*/
std::vector<Distribution>& Emission() { return emission; }
+ //! Read-only access to the dimensionality of the observations.
+ size_t Dimensionality() const
+ {
+   return dimensionality;
+ }
+ //! Mutable access to the dimensionality of the observations; assign through
+ //! the returned reference to change it.
+ size_t& Dimensionality()
+ {
+   return dimensionality;
+ }
+
private:
// Helper functions.
Added: mlpack/trunk/src/mlpack/methods/hmm/hmm_generate_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_generate_main.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_generate_main.cpp 2011-12-17 06:39:15 UTC (rev 10874)
@@ -0,0 +1,122 @@
+/**
+ * @file hmm_generate_main.cpp
+ * @author Ryan Curtin
+ *
+ * Generate a random observation sequence and hidden state sequence from a
+ * given HMM.
+ */
+#include <mlpack/core.hpp>
+
+#include "hmm.hpp"
+#include "hmm_util.hpp"
+
+#include "distributions/gaussian_distribution.hpp"
+#include <mlpack/methods/gmm/gmm.hpp>
+
+// Program description shown in the help output.
+PROGRAM_INFO("Hidden Markov Model (HMM) Sequence Generator", "This "
+    "utility takes an already-trained HMM (--model_file) and generates a "
+    "random observation sequence and hidden state sequence based on its "
+    "parameters, saving them to the specified files (--output_file and "
+    "--state_file).");
+
+// Required parameters.
+PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
+PARAM_INT_REQ("length", "Length of sequence to generate.", "l");
+
+// Optional parameters; an empty state_file means the hidden state sequence is
+// not saved.
+PARAM_INT("start_state", "Starting state of sequence.", "S", 0);
+PARAM_STRING("output_file", "File to save observation sequence to.", "o",
+    "output.csv");
+PARAM_STRING("state_file", "File to save hidden state sequence to (may be left "
+    "unspecified).", "s", "");
+
+using namespace mlpack;
+using namespace mlpack::hmm;
+using namespace mlpack::distribution;
+using namespace mlpack::utilities;
+using namespace mlpack::gmm;
+using namespace arma;
+using namespace std;
+
+/**
+ * Restore an HMM from an already-read model file, validate the requested start
+ * state against it, and generate a random sequence.
+ *
+ * Errors are reported through Log::Fatal, which is assumed to terminate the
+ * program (its implementation is not part of this file).
+ *
+ * @param hmm HMM object of the right emission type; passed to LoadHMM().
+ * @param sr Utility object on which ReadFile() has already been called.
+ * @param length Number of observations to generate (already validated).
+ * @param startState Hidden state to start from; checked against the number of
+ *     rows of the loaded transition matrix.
+ * @param observations Matrix the generated observations are stored in.
+ * @param sequence Vector the generated hidden states are stored in.
+ */
+template<typename HMMType>
+static void LoadAndGenerate(HMMType& hmm,
+                            SaveRestoreUtility& sr,
+                            const int length,
+                            const int startState,
+                            mat& observations,
+                            Col<size_t>& sequence)
+{
+  LoadHMM(hmm, sr);
+
+  const size_t numStates = hmm.Transition().n_rows;
+  if (startState < 0 || startState >= static_cast<int>(numStates))
+  {
+    Log::Fatal << "Invalid start state (" << startState << "); must be "
+        << "between 0 and number of states (" << numStates << ")!" << endl;
+  }
+
+  hmm.Generate(size_t(length), observations, sequence, size_t(startState));
+}
+
+/**
+ * Entry point: read an HMM from --model_file, generate a random sequence of
+ * --length observations starting in --start_state, and save the observations
+ * (and optionally the hidden states) to file.
+ */
+int main(int argc, char** argv)
+{
+  // Parse command line options.
+  CLI::ParseCommandLine(argc, argv);
+
+  // Fetch and validate parameters.
+  const string modelFile = CLI::GetParam<string>("model_file");
+  const int length = CLI::GetParam<int>("length");
+  const int startState = CLI::GetParam<int>("start_state");
+
+  if (length <= 0)
+  {
+    // The message now matches the check: a length of 0 is rejected too.
+    Log::Fatal << "Invalid sequence length (" << length << "); must be greater "
+        << "than 0!" << endl;
+  }
+
+  // Load model, but first we have to determine its type.
+  SaveRestoreUtility sr;
+  sr.ReadFile(modelFile);
+  string type;
+  sr.LoadParameter(type, "hmm_type");
+
+  // The placeholder HMMs below only fix the emission type; their contents are
+  // handed to LoadHMM() inside LoadAndGenerate().
+  mat observations;
+  Col<size_t> sequence;
+  if (type == "discrete")
+  {
+    HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
+    LoadAndGenerate(hmm, sr, length, startState, observations, sequence);
+  }
+  else if (type == "gaussian")
+  {
+    HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
+    LoadAndGenerate(hmm, sr, length, startState, observations, sequence);
+  }
+  else if (type == "gmm")
+  {
+    HMM<GMM> hmm(1, GMM(1, 1));
+    LoadAndGenerate(hmm, sr, length, startState, observations, sequence);
+  }
+  else
+  {
+    Log::Fatal << "Unknown HMM type '" << type << "' in file '" << modelFile
+        << "'!" << endl;
+  }
+
+  // Save observations.
+  const string outputFile = CLI::GetParam<string>("output_file");
+  data::Save(outputFile, observations, true);
+
+  // Do we want to save the hidden sequence?
+  const string sequenceFile = CLI::GetParam<string>("state_file");
+  if (sequenceFile != "")
+    data::Save(sequenceFile, sequence, true);
+}
Added: mlpack/trunk/src/mlpack/methods/hmm/hmm_loglik_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_loglik_main.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_loglik_main.cpp 2011-12-17 06:39:15 UTC (rev 10874)
@@ -0,0 +1,101 @@
+/**
+ * @file hmm_loglik_main.cpp
+ * @author Ryan Curtin
+ *
+ * Compute the log-likelihood of a given sequence for a given HMM.
+ */
+#include <mlpack/core.hpp>
+
+#include "hmm.hpp"
+#include "hmm_util.hpp"
+
+#include "distributions/gaussian_distribution.hpp"
+#include <mlpack/methods/gmm/gmm.hpp>
+
+// Program description shown in the help output.
+PROGRAM_INFO("Hidden Markov Model (HMM) Sequence Log-Likelihood", "This "
+    "utility takes an already-trained HMM (--model_file) and evaluates the "
+    "log-likelihood of a given sequence of observations (--input_file).  The "
+    "computed log-likelihood is given directly to stdout.");
+
+// Both parameters are required.
+PARAM_STRING_REQ("input_file", "File containing observations.", "i");
+PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
+
+using namespace mlpack;
+using namespace mlpack::hmm;
+using namespace mlpack::distribution;
+using namespace mlpack::utilities;
+using namespace mlpack::gmm;
+using namespace arma;
+using namespace std;
+
+/**
+ * Entry point: load an observation sequence and an HMM, then print the
+ * log-likelihood of the sequence under that HMM to standard output.
+ */
+int main(int argc, char** argv)
+{
+  // Everything below reads its parameters from CLI, so parse first.
+  CLI::ParseCommandLine(argc, argv);
+
+  const string observationFile = CLI::GetParam<string>("input_file");
+  const string hmmFile = CLI::GetParam<string>("model_file");
+
+  // Read the observation sequence.
+  mat observations;
+  data::Load(observationFile, observations, true);
+
+  // The model file stores which kind of HMM it holds under "hmm_type"; that
+  // decides which HMM class the rest of the file is loaded into.
+  SaveRestoreUtility modelReader;
+  modelReader.ReadFile(hmmFile);
+  string hmmType;
+  modelReader.LoadParameter(hmmType, "hmm_type");
+
+  // Report anything that is not one of the three supported types.
+  if (hmmType != "discrete" && hmmType != "gaussian" && hmmType != "gmm")
+  {
+    Log::Fatal << "Unknown HMM type '" << hmmType << "' in file '" << hmmFile
+        << "'!" << endl;
+  }
+
+  double logLikelihood = 0;
+  if (hmmType == "discrete")
+  {
+    HMM<DiscreteDistribution> model(1, DiscreteDistribution(1));
+    LoadHMM(model, modelReader);
+
+    // A single column is taken to be a sequence stored the other way round.
+    if (observations.n_cols == 1)
+    {
+      observations = trans(observations);
+    }
+
+    // Discrete observations must occupy exactly one row.
+    if (observations.n_rows > 1)
+    {
+      Log::Fatal << "Only one-dimensional discrete observations allowed for "
+          << "discrete HMMs!" << endl;
+    }
+
+    logLikelihood = model.LogLikelihood(observations);
+  }
+  else if (hmmType == "gaussian")
+  {
+    HMM<GaussianDistribution> model(1, GaussianDistribution(1));
+    LoadHMM(model, modelReader);
+
+    // The data must have as many rows as the emission means have elements.
+    const size_t modelDim = model.Emission()[0].Mean().n_elem;
+    if (observations.n_rows != modelDim)
+    {
+      Log::Fatal << "Observation dimensionality (" << observations.n_rows
+          << ") " << "does not match HMM Gaussian dimensionality ("
+          << modelDim << ")!" << endl;
+    }
+
+    logLikelihood = model.LogLikelihood(observations);
+  }
+  else if (hmmType == "gmm")
+  {
+    HMM<GMM> model(1, GMM(1, 1));
+    LoadHMM(model, modelReader);
+
+    // The data must have as many rows as the GMM has dimensions.
+    const size_t modelDim = model.Emission()[0].Dimensionality();
+    if (observations.n_rows != modelDim)
+    {
+      Log::Fatal << "Observation dimensionality (" << observations.n_rows
+          << ") " << "does not match HMM Gaussian dimensionality ("
+          << modelDim << ")!" << endl;
+    }
+
+    logLikelihood = model.LogLikelihood(observations);
+  }
+
+  cout << logLikelihood << endl;
+}
Added: mlpack/trunk/src/mlpack/methods/hmm/hmm_train_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_train_main.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_train_main.cpp 2011-12-17 06:39:15 UTC (rev 10874)
@@ -0,0 +1,349 @@
+/**
+ * @file hmm_train_main.cpp
+ * @author Ryan Curtin
+ *
+ * Executable which trains an HMM and saves the trained HMM to file.
+ */
+#include <mlpack/core.hpp>
+
+#include "hmm.hpp"
+#include "hmm_util.hpp"
+
+#include "distributions/gaussian_distribution.hpp"
+#include <mlpack/methods/gmm/gmm.hpp>
+
+// Program description shown in the help output.  The option names mentioned
+// here must match the PARAM_* declarations below (note: --labels_file).
+PROGRAM_INFO("Hidden Markov Model (HMM) Training", "This program allows a "
+    "Hidden Markov Model to be trained on labeled or unlabeled data.  It "
+    "supports three types of HMMs: discrete HMMs, Gaussian HMMs, or GMM HMMs."
+    "\n"
+    "Either one input sequence can be specified (with --input_file), or, a "
+    "file containing files in which input sequences can be found (when "
+    "--input_file and --batch are used together).  In addition, labels can be "
+    "provided in the file specified by --labels_file, and if --batch is used, "
+    "the file given to --labels_file should contain a list of files of labels "
+    "corresponding to the sequences in the file given to --input_file.\n"
+    "\n"
+    "Optionally, a pre-created HMM model can be used as a guess for the "
+    "transition matrix and emission probabilities; this is specifiable with "
+    "--model_file.");
+
+// Required parameters.
+PARAM_STRING_REQ("input_file", "File containing input observations.", "i");
+PARAM_STRING_REQ("type", "Type of HMM: discrete | gaussian | gmm", "t");
+
+// Optional parameters.
+PARAM_FLAG("batch", "If true, input_file (and if passed, labels_file) are "
+    "expected to contain a list of files to use as input observation sequences"
+    " (and label sequences).", "b");
+PARAM_INT("states", "Number of hidden states in HMM (necessary, unless "
+    "model_file is specified).", "s", 0);
+PARAM_INT("gaussians", "Number of gaussians in each GMM (necessary when type is"
+    " 'gmm').", "g", 0);
+PARAM_STRING("model_file", "Pre-existing HMM model (optional).", "m", "");
+PARAM_STRING("labels_file", "Optional file of hidden states, used for "
+    "labeled training.", "l", "");
+PARAM_STRING("output_file", "File to save trained HMM to (XML).", "o",
+    "output_hmm.xml");
+
+using namespace mlpack;
+using namespace mlpack::hmm;
+using namespace mlpack::distribution;
+using namespace mlpack::utilities;
+using namespace mlpack::gmm;
+using namespace arma;
+using namespace std;
+
+int main(int argc, char** argv)
+{
+ // Parse command line options.
+ CLI::ParseCommandLine(argc, argv);
+
+ // Validate parameters.
+ const string inputFile = CLI::GetParam<string>("input_file");
+ const string labelsFile = CLI::GetParam<string>("labels_file");
+ const string modelFile = CLI::GetParam<string>("model_file");
+ const string outputFile = CLI::GetParam<string>("output_file");
+ const string type = CLI::GetParam<string>("type");
+ const int states = CLI::GetParam<int>("states");
+ const bool batch = CLI::HasParam("batch");
+
+ // Validate number of states.
+ if (states == 0 && modelFile == "")
+ {
+ Log::Fatal << "Must specify number of states if model file is not "
+ << "specified!" << endl;
+ }
+
+ if (states < 0 && modelFile == "")
+ {
+ Log::Fatal << "Invalid number of states (" << states << "); must be greater"
+ << " than or equal to 1." << endl;
+ }
+
+ // Load the dataset(s) and labels.
+ vector<mat> trainSeq;
+ vector<arma::Col<size_t> > labelSeq; // May be empty.
+ if (batch)
+ {
+ // The input file contains a list of files to read.
+ Log::Info << "Reading list of training sequences from '" << inputFile
+ << "'." << endl;
+
+ fstream f(inputFile.c_str(), ios_base::in);
+
+ if (!f.is_open())
+ Log::Fatal << "Could not open '" << inputFile << "' for reading." << endl;
+
+ // Now read each line in.
+ char lineBuf[1024]; // Max 1024 characters... hopefully that is long enough.
+ f.getline(lineBuf, 1024, '\n');
+ while (!f.eof())
+ {
+ Log::Info << "Adding training sequence from '" << lineBuf << "'." << endl;
+
+ // Now read the matrix.
+ trainSeq.push_back(mat());
+ if (labelsFile == "") // Nonfatal in this case.
+ {
+ if (!data::Load(lineBuf, trainSeq.back(), false))
+ {
+ Log::Warn << "Loading training sequence from '" << lineBuf << "' "
+ << "failed. Sequence ignored." << endl;
+ trainSeq.pop_back(); // Remove last element which we did not use.
+ }
+ }
+ else
+ {
+ data::Load(lineBuf, trainSeq.back(), true);
+ }
+
+ // See if we need to transpose the data.
+ if (type == "discrete")
+ {
+ if (trainSeq.back().n_cols == 1)
+ trainSeq.back() = trans(trainSeq.back());
+ }
+
+ f.getline(lineBuf, 1024, '\n');
+ }
+
+ f.close();
+
+ // Now load labels, if we need to.
+ if (labelsFile != "")
+ {
+ f.open(labelsFile.c_str(), ios_base::in);
+
+ if (!f.is_open())
+ Log::Fatal << "Could not open '" << labelsFile << "' for reading."
+ << endl;
+
+ // Now read each line in.
+ f.getline(lineBuf, 1024, '\n');
+ while (!f.eof())
+ {
+ Log::Info << "Adding training sequence labels from '" << lineBuf
+ << "'." << endl;
+
+ // Now read the matrix.
+ Mat<size_t> label;
+ data::Load(lineBuf, label, true); // Fatal on failure.
+
+ // Ensure that matrix only has one column.
+ if (label.n_rows == 1)
+ label = trans(label);
+
+ if (label.n_cols > 1)
+ Log::Fatal << "Invalid labels; must be one-dimensional." << endl;
+
+ labelSeq.push_back(label.col(0));
+
+ f.getline(lineBuf, 1024, '\n');
+ }
+ }
+ }
+ else
+ {
+ // Only one input file.
+ trainSeq.resize(1);
+ data::Load(inputFile.c_str(), trainSeq[0], true);
+
+ // Do we need to load labels?
+ if (labelsFile != "")
+ {
+ Mat<size_t> label;
+ data::Load(labelsFile, label, true);
+
+ // Ensure that matrix only has one column.
+ if (label.n_rows == 1)
+ label = trans(label);
+
+ if (label.n_cols > 1)
+ Log::Fatal << "Invalid labels; must be one-dimensional." << endl;
+
+ // Verify the same number of observations as the data.
+ if (label.n_elem != trainSeq[labelSeq.size()].n_cols)
+ Log::Fatal << "Label sequence " << labelSeq.size() << " does not have "
+ << "the same number of points as observation sequence "
+ << labelSeq.size() << "!" << endl;
+
+ labelSeq.push_back(label.col(0));
+ }
+ }
+
+ // Now, train the HMM, since we have loaded the input data.
+ if (type == "discrete")
+ {
+ // Verify observations are valid.
+ for (size_t i = 0; i < trainSeq.size(); ++i)
+ if (trainSeq[i].n_rows > 1)
+ Log::Fatal << "Error in training sequence " << i << ": only "
+ << "one-dimensional discrete observations allowed for discrete "
+ << "HMMs!" << endl;
+
+ // Do we have a model to preload?
+ HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
+
+ if (modelFile != "")
+ {
+ SaveRestoreUtility loader;
+ loader.ReadFile(modelFile);
+ LoadHMM(hmm, loader);
+ }
+ else // New model.
+ {
+ // Maximum observation is necessary so we know how to train the discrete
+ // distribution.
+ size_t maxEmission = 0;
+ for (vector<mat>::iterator it = trainSeq.begin(); it != trainSeq.end();
+ ++it)
+ {
+ size_t maxSeq = size_t(as_scalar(max(trainSeq[0], 1))) + 1;
+ if (maxSeq > maxEmission)
+ maxEmission = maxSeq;
+ }
+
+ Log::Info << maxEmission << " discrete observations in the input data."
+ << endl;
+
+ // Create HMM object.
+ hmm = HMM<DiscreteDistribution>(size_t(states),
+ DiscreteDistribution(maxEmission));
+ }
+
+ // Do we have labels?
+ if (labelsFile == "")
+ hmm.Train(trainSeq); // Unsupervised training.
+ else
+ hmm.Train(trainSeq, labelSeq); // Supervised training.
+
+ // Finally, save the model. This should later be integrated into the HMM
+ // class itself.
+ SaveRestoreUtility sr;
+ SaveHMM(hmm, sr);
+ sr.WriteFile(outputFile);
+ }
+ else if (type == "gaussian")
+ {
+ // Create HMM object.
+ HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
+
+ // Do we have a model to load?
+ size_t dimensionality = 0;
+ if (modelFile != "")
+ {
+ SaveRestoreUtility loader;
+ loader.ReadFile(modelFile);
+ LoadHMM(hmm, loader);
+
+ dimensionality = hmm.Emission()[0].Mean().n_elem;
+ }
+ else
+ {
+ // Find dimension of the data.
+ dimensionality = trainSeq[0].n_rows;
+
+ hmm = HMM<GaussianDistribution>(size_t(states),
+ GaussianDistribution(dimensionality));
+ }
+
+ // Verify dimensionality of data.
+ for (size_t i = 0; i < trainSeq.size(); ++i)
+ if (trainSeq[i].n_rows != dimensionality)
+ Log::Fatal << "Observation sequence " << i << " dimensionality ("
+ << trainSeq[i].n_rows << " is incorrect (should be "
+ << dimensionality << ")!" << endl;
+
+ // Now run the training.
+ if (labelsFile == "")
+ hmm.Train(trainSeq); // Unsupervised training.
+ else
+ hmm.Train(trainSeq, labelSeq); // Supervised training.
+
+ // Finally, save the model. This should later be integrated into th HMM
+ // class itself.
+ SaveRestoreUtility sr;
+ SaveHMM(hmm, sr);
+ sr.WriteFile(outputFile);
+ }
+ else if (type == "gmm")
+ {
+ // Create HMM object.
+ HMM<GMM> hmm(1, GMM(1, 1));
+
+ // Do we have a model to load?
+ size_t dimensionality = 0;
+ if (modelFile != "")
+ {
+ SaveRestoreUtility loader;
+ loader.ReadFile(modelFile);
+ LoadHMM(hmm, loader);
+
+ dimensionality = hmm.Emission()[0].Dimensionality();
+ }
+ else
+ {
+ // Find dimension of the data.
+ dimensionality = trainSeq[0].n_rows;
+
+ const int gaussians = CLI::GetParam<int>("gaussians");
+
+ if (gaussians == 0)
+ Log::Fatal << "Number of gaussians for each GMM must be specified (-g) "
+ << "when type = 'gmm'!" << endl;
+
+ if (gaussians < 0)
+ Log::Fatal << "Invalid number of gaussians (" << gaussians << "); must "
+ << "be greater than or equal to 1." << endl;
+
+ hmm = HMM<GMM>(size_t(states), GMM(size_t(gaussians), dimensionality));
+ }
+
+ // Verify dimensionality of data.
+ for (size_t i = 0; i < trainSeq.size(); ++i)
+ if (trainSeq[i].n_rows != dimensionality)
+ Log::Fatal << "Observation sequence " << i << " dimensionality ("
+ << trainSeq[i].n_rows << " is incorrect (should be "
+ << dimensionality << ")!" << endl;
+
+ // Now run the training.
+ if (labelsFile == "")
+ {
+ Log::Warn << "Unlabeled training of GMM HMMs is almost certainly not "
+ << "going to produce good results!" << endl;
+ hmm.Train(trainSeq);
+ }
+ else
+ {
+ hmm.Train(trainSeq, labelSeq);
+ }
+
+ // Save model.
+ SaveRestoreUtility sr;
+ SaveHMM(hmm, sr);
+ sr.WriteFile(outputFile);
+ }
+ else
+ {
+ Log::Fatal << "Unknown HMM type: " << type << "; must be 'discrete', "
+ << "'gaussian', or 'gmm'." << endl;
+ }
+}
Added: mlpack/trunk/src/mlpack/methods/hmm/hmm_viterbi_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_viterbi_main.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_viterbi_main.cpp 2011-12-17 06:39:15 UTC (rev 10874)
@@ -0,0 +1,107 @@
+/**
+ * @file hmm_viterbi_main.cpp
+ * @author Ryan Curtin
+ *
+ * Compute the most probable hidden state sequence of a given observation
+ * sequence for a given HMM.
+ */
+#include <mlpack/core.hpp>
+
+#include "hmm.hpp"
+#include "hmm_util.hpp"
+
+#include "distributions/gaussian_distribution.hpp"
+#include <mlpack/methods/gmm/gmm.hpp>
+
+// Program description shown in the help output.
+PROGRAM_INFO("Hidden Markov Model (HMM) Viterbi State Prediction", "This "
+    "utility takes an already-trained HMM (--model_file) and evaluates the "
+    "most probable hidden state sequence of a given sequence of observations "
+    "(--input_file), using the Viterbi algorithm.  The computed state sequence "
+    "is saved to the specified output file (--output_file).");
+
+// Required parameters.
+PARAM_STRING_REQ("input_file", "File containing observations.", "i");
+PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
+// Optional parameter.
+PARAM_STRING("output_file", "File to save predicted state sequence to.", "o",
+    "output.csv");
+
+using namespace mlpack;
+using namespace mlpack::hmm;
+using namespace mlpack::distribution;
+using namespace mlpack::utilities;
+using namespace mlpack::gmm;
+using namespace arma;
+using namespace std;
+
+/**
+ * Entry point: load an observation sequence and an HMM, predict the hidden
+ * state sequence for the observations, and save it to --output_file.
+ */
+int main(int argc, char** argv)
+{
+  // Everything below reads its parameters from CLI, so parse first.
+  CLI::ParseCommandLine(argc, argv);
+
+  const string observationFile = CLI::GetParam<string>("input_file");
+  const string hmmFile = CLI::GetParam<string>("model_file");
+
+  // Read the observation sequence.
+  mat observations;
+  data::Load(observationFile, observations, true);
+
+  // The model file stores which kind of HMM it holds under "hmm_type"; that
+  // decides which HMM class the rest of the file is loaded into.
+  SaveRestoreUtility modelReader;
+  modelReader.ReadFile(hmmFile);
+  string hmmType;
+  modelReader.LoadParameter(hmmType, "hmm_type");
+
+  // Report anything that is not one of the three supported types.
+  if (hmmType != "discrete" && hmmType != "gaussian" && hmmType != "gmm")
+  {
+    Log::Fatal << "Unknown HMM type '" << hmmType << "' in file '" << hmmFile
+        << "'!" << endl;
+  }
+
+  arma::Col<size_t> states;
+  if (hmmType == "discrete")
+  {
+    HMM<DiscreteDistribution> model(1, DiscreteDistribution(1));
+    LoadHMM(model, modelReader);
+
+    // A single column is taken to be a sequence stored the other way round.
+    if (observations.n_cols == 1)
+    {
+      observations = trans(observations);
+    }
+
+    // Discrete observations must occupy exactly one row.
+    if (observations.n_rows > 1)
+    {
+      Log::Fatal << "Only one-dimensional discrete observations allowed for "
+          << "discrete HMMs!" << endl;
+    }
+
+    model.Predict(observations, states);
+  }
+  else if (hmmType == "gaussian")
+  {
+    HMM<GaussianDistribution> model(1, GaussianDistribution(1));
+    LoadHMM(model, modelReader);
+
+    // The data must have as many rows as the emission means have elements.
+    const size_t modelDim = model.Emission()[0].Mean().n_elem;
+    if (observations.n_rows != modelDim)
+    {
+      Log::Fatal << "Observation dimensionality (" << observations.n_rows
+          << ") " << "does not match HMM Gaussian dimensionality ("
+          << modelDim << ")!" << endl;
+    }
+
+    model.Predict(observations, states);
+  }
+  else if (hmmType == "gmm")
+  {
+    HMM<GMM> model(1, GMM(1, 1));
+    LoadHMM(model, modelReader);
+
+    // The data must have as many rows as the GMM has dimensions.
+    const size_t modelDim = model.Emission()[0].Dimensionality();
+    if (observations.n_rows != modelDim)
+    {
+      Log::Fatal << "Observation dimensionality (" << observations.n_rows
+          << ") " << "does not match HMM Gaussian dimensionality ("
+          << modelDim << ")!" << endl;
+    }
+
+    model.Predict(observations, states);
+  }
+
+  // Write the predicted state sequence.
+  const string predictionFile = CLI::GetParam<string>("output_file");
+  data::Save(predictionFile, states, true);
+}
More information about the mlpack-svn
mailing list