[mlpack-git] master: Refactor HMM programs to use boost::serialization. (7f3c228)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Jul 13 04:05:03 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/8b2ca720828224607c70d2b539c43aecf8f4ec32...b4659b668021db631b3c8a48e3d735b513706fdc
>---------------------------------------------------------------
commit 7f3c2286ff182b128ae667edf392d43a865072e7
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Jul 12 13:34:44 2015 +0000
Refactor HMM programs to use boost::serialization.
>---------------------------------------------------------------
7f3c2286ff182b128ae667edf392d43a865072e7
src/mlpack/methods/hmm/CMakeLists.txt | 2 +
src/mlpack/methods/hmm/hmm_generate_main.cpp | 107 +++-----
src/mlpack/methods/hmm/hmm_loglik_main.cpp | 96 +++----
src/mlpack/methods/hmm/hmm_train_main.cpp | 360 ++++++++++++---------------
src/mlpack/methods/hmm/hmm_util.hpp | 39 +++
src/mlpack/methods/hmm/hmm_util_impl.hpp | 161 ++++++++++++
src/mlpack/methods/hmm/hmm_viterbi_main.cpp | 89 +++----
7 files changed, 459 insertions(+), 395 deletions(-)
diff --git a/src/mlpack/methods/hmm/CMakeLists.txt b/src/mlpack/methods/hmm/CMakeLists.txt
index 179a32b..fd52249 100644
--- a/src/mlpack/methods/hmm/CMakeLists.txt
+++ b/src/mlpack/methods/hmm/CMakeLists.txt
@@ -5,6 +5,8 @@ set(SOURCES
hmm_impl.hpp
hmm_regression.hpp
hmm_regression_impl.hpp
+ hmm_util.hpp
+ hmm_util_impl.hpp
)
# Add directory name to sources.
diff --git a/src/mlpack/methods/hmm/hmm_generate_main.cpp b/src/mlpack/methods/hmm/hmm_generate_main.cpp
index 2a50d06..d0de3b0 100644
--- a/src/mlpack/methods/hmm/hmm_generate_main.cpp
+++ b/src/mlpack/methods/hmm/hmm_generate_main.cpp
@@ -19,7 +19,7 @@ PROGRAM_INFO("Hidden Markov Model (HMM) Sequence Generator", "This "
"parameters, saving them to the specified files (--output_file and "
"--state_file)");
-PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
+PARAM_STRING_REQ("model_file", "File containing HMM.", "m");
PARAM_INT_REQ("length", "Length of sequence to generate.", "l");
PARAM_INT("start_state", "Starting state of sequence.", "t", 0);
@@ -38,92 +38,51 @@ using namespace mlpack::math;
using namespace arma;
using namespace std;
-int main(int argc, char** argv)
+// Because we don't know what the type of our HMM is, we need to write a
+// function which can take arbitrary HMM types.
+struct Generate
{
- // Parse command line options.
- CLI::ParseCommandLine(argc, argv);
-
- // Set random seed.
- if (CLI::GetParam<int>("seed") != 0)
- RandomSeed((size_t) CLI::GetParam<int>("seed"));
- else
- RandomSeed((size_t) time(NULL));
-
- // Load observations.
- const string modelFile = CLI::GetParam<string>("model_file");
- const int length = CLI::GetParam<int>("length");
- const int startState = CLI::GetParam<int>("start_state");
-
- if (length <= 0)
+ template<typename HMMType>
+ static void Apply(HMMType& hmm, void* /* extraInfo */)
{
- Log::Fatal << "Invalid sequence length (" << length << "); must be greater "
- << "than or equal to 0!" << endl;
- }
+ mat observations;
+ Col<size_t> sequence;
- // Load model, but first we have to determine its type.
- SaveRestoreUtility sr;
- sr.ReadFile(modelFile);
- string emissionType;
- sr.LoadParameter(emissionType, "emission_type");
-
- mat observations;
- Col<size_t> sequence;
- if (emissionType == "DiscreteDistribution")
- {
- HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
- hmm.Load(sr);
+ // Load the parameters.
+ const size_t startState = (size_t) CLI::GetParam<int>("start_state");
+ const size_t length = (size_t) CLI::GetParam<int>("length");
- if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
- {
+ Log::Info << "Generating sequence of length " << length << "..." << endl;
+ if (startState >= hmm.Transition().n_rows)
Log::Fatal << "Invalid start state (" << startState << "); must be "
<< "between 0 and number of states (" << hmm.Transition().n_rows
<< ")!" << endl;
- }
- hmm.Generate(size_t(length), observations, sequence, size_t(startState));
- }
- else if (emissionType == "GaussianDistribution")
- {
- HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
- hmm.Load(sr);
+ hmm.Generate(length, observations, sequence, startState);
- if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
- {
- Log::Fatal << "Invalid start state (" << startState << "); must be "
- << "between 0 and number of states (" << hmm.Transition().n_rows
- << ")!" << endl;
- }
+ // Now save the output.
+ const string outputFile = CLI::GetParam<string>("output_file");
+ data::Save(outputFile, observations, true);
- hmm.Generate(size_t(length), observations, sequence, size_t(startState));
+ // Do we want to save the hidden sequence?
+ const string sequenceFile = CLI::GetParam<string>("state_file");
+ if (sequenceFile != "")
+ data::Save(sequenceFile, sequence, true);
}
- else if (emissionType == "GMM")
- {
- HMM<GMM<> > hmm(1, GMM<>(1, 1));
- hmm.Load(sr);
+};
- if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
- {
- Log::Fatal << "Invalid start state (" << startState << "); must be "
- << "between 0 and number of states (" << hmm.Transition().n_rows
- << ")!" << endl;
- }
+int main(int argc, char** argv)
+{
+ // Parse command line options.
+ CLI::ParseCommandLine(argc, argv);
- hmm.Generate(size_t(length), observations, sequence, size_t(startState));
- }
+ // Set random seed.
+ if (CLI::GetParam<int>("seed") != 0)
+ RandomSeed((size_t) CLI::GetParam<int>("seed"));
else
- {
- Log::Fatal << "Unknown HMM type '" << emissionType << "' in file '" << modelFile
- << "'!" << endl;
- }
-
- // Save observations.
- const string outputFile = CLI::GetParam<string>("output_file");
- data::Save(outputFile, observations, true);
-
- // Do we want to save the hidden sequence?
- const string sequenceFile = CLI::GetParam<string>("state_file");
- if (sequenceFile != "")
- data::Save(sequenceFile, sequence, true);
+ RandomSeed((size_t) time(NULL));
- return 0;
+ // Load model, and perform the generation.
+ const string modelFile = CLI::GetParam<string>("model_file");
+ LoadHMMAndPerformAction<Generate>(modelFile);
}
diff --git a/src/mlpack/methods/hmm/hmm_loglik_main.cpp b/src/mlpack/methods/hmm/hmm_loglik_main.cpp
index 6946567..2692a17 100644
--- a/src/mlpack/methods/hmm/hmm_loglik_main.cpp
+++ b/src/mlpack/methods/hmm/hmm_loglik_main.cpp
@@ -17,7 +17,7 @@ PROGRAM_INFO("Hidden Markov Model (HMM) Sequence Log-Likelihood", "This "
"computed log-likelihood is given directly to stdout.");
PARAM_STRING_REQ("input_file", "File containing observations,", "i");
-PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
+PARAM_STRING_REQ("model_file", "File containing HMM.", "m");
using namespace mlpack;
using namespace mlpack::hmm;
@@ -27,74 +27,44 @@ using namespace mlpack::gmm;
using namespace arma;
using namespace std;
-int main(int argc, char** argv)
+// Because we don't know what the type of our HMM is, we need to write a
+// function that can take arbitrary HMM types.
+struct Loglik
{
- // Parse command line options.
- CLI::ParseCommandLine(argc, argv);
-
- // Load observations.
- const string inputFile = CLI::GetParam<string>("input_file");
- const string modelFile = CLI::GetParam<string>("model_file");
-
- mat dataSeq;
- data::Load(inputFile, dataSeq, true);
-
- // Load model, but first we have to determine its type.
- SaveRestoreUtility sr;
- sr.ReadFile(modelFile);
- string type;
- sr.LoadParameter(type, "hmm_type");
-
- double loglik = 0;
- if (type == "discrete")
+ template<typename HMMType>
+ static void Apply(HMMType& hmm, void* /* extraInfo */)
{
- HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
-
- LoadHMM(hmm, sr);
-
- // Verify only one row in observations.
- if (dataSeq.n_cols == 1)
- dataSeq = trans(dataSeq);
-
- if (dataSeq.n_rows > 1)
- Log::Fatal << "Only one-dimensional discrete observations allowed for "
- << "discrete HMMs!" << endl;
-
- loglik = hmm.LogLikelihood(dataSeq);
- }
- else if (type == "gaussian")
- {
- HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
-
- LoadHMM(hmm, sr);
+ // Load the data sequence.
+ const string inputFile = CLI::GetParam<string>("input_file");
+ mat dataSeq;
+ data::Load(inputFile, dataSeq, true);
+
+ // Detect if we need to transpose the data, in the case where the input data
+ // has one dimension.
+ if ((dataSeq.n_cols == 1) && (hmm.Emission()[0].Dimensionality() == 1))
+ {
+ Log::Info << "Data sequence appears to be transposed; correcting."
+ << endl;
+ dataSeq = dataSeq.t();
+ }
- // Verify correct dimensionality.
- if (dataSeq.n_rows != hmm.Emission()[0].Mean().n_elem)
- Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
- << "does not match HMM Gaussian dimensionality ("
- << hmm.Emission()[0].Mean().n_elem << ")!" << endl;
-
- loglik = hmm.LogLikelihood(dataSeq);
- }
- else if (type == "gmm")
- {
- HMM<GMM<> > hmm(1, GMM<>(1, 1));
-
- LoadHMM(hmm, sr);
-
- // Verify correct dimensionality.
if (dataSeq.n_rows != hmm.Emission()[0].Dimensionality())
- Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
- << "does not match HMM Gaussian dimensionality ("
+ Log::Fatal << "Dimensionality of sequence (" << dataSeq.n_rows << ") is "
+ << "not equal to the dimensionality of the HMM ("
<< hmm.Emission()[0].Dimensionality() << ")!" << endl;
- loglik = hmm.LogLikelihood(dataSeq);
- }
- else
- {
- Log::Fatal << "Unknown HMM type '" << type << "' in file '" << modelFile
- << "'!" << endl;
+ const double loglik = hmm.LogLikelihood(dataSeq);
+
+ cout << loglik << endl;
}
+};
- cout << loglik << endl;
+int main(int argc, char** argv)
+{
+ // Parse command line options.
+ CLI::ParseCommandLine(argc, argv);
+
+ // Load model, and calculate the log-likelihood of the sequence.
+ const string modelFile = CLI::GetParam<string>("model_file");
+ LoadHMMAndPerformAction<Loglik>(modelFile);
}
diff --git a/src/mlpack/methods/hmm/hmm_train_main.cpp b/src/mlpack/methods/hmm/hmm_train_main.cpp
index 02c9df6..87d5c04 100644
--- a/src/mlpack/methods/hmm/hmm_train_main.cpp
+++ b/src/mlpack/methods/hmm/hmm_train_main.cpp
@@ -43,7 +43,7 @@ PARAM_INT("gaussians", "Number of gaussians in each GMM (necessary when type is"
PARAM_STRING("model_file", "Pre-existing HMM model (optional).", "m", "");
PARAM_STRING("labels_file", "Optional file of hidden states, used for "
"labeled training.", "l", "");
-PARAM_STRING("output_file", "File to save trained HMM to (XML).", "o",
+PARAM_STRING("output_file", "File to save trained HMM to.", "o",
"output_hmm.xml");
PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
PARAM_DOUBLE("tolerance", "Tolerance of the Baum-Welch algorithm.", "T", 1e-5);
@@ -57,6 +57,106 @@ using namespace mlpack::math;
using namespace arma;
using namespace std;
+// Because we don't know what the type of our HMM is, we need to write a
+// function that can take arbitrary HMM types.
+struct Train
+{
+  /**
+   * Train the given HMM on the already-loaded observation sequences, then
+   * save the trained model.
+   *
+   * The "batch", "tolerance", "labels_file" and "output_file" command-line
+   * parameters are read through CLI.  If a labels file is given, supervised
+   * training is performed; otherwise unsupervised training is performed.
+   *
+   * @param hmm HMM to train; it is modified in place.
+   * @param trainSeqPtr Pointer to the observation sequences to train on.  It
+   *     is dereferenced unconditionally, so it must not be NULL.
+   */
+  template<typename HMMType>
+  static void Apply(HMMType& hmm, vector<mat>* trainSeqPtr)
+  {
+    const bool batch = CLI::HasParam("batch");
+    const double tolerance = CLI::GetParam<double>("tolerance");
+
+    // Only override the model's tolerance if the user explicitly gave one.
+    if (CLI::HasParam("tolerance"))
+      hmm.Tolerance() = tolerance;
+
+    const string labelsFile = CLI::GetParam<string>("labels_file");
+
+    // Verify that the dimensionality of our observations is the same as the
+    // dimensionality of our HMM's emissions.
+    vector<mat>& trainSeq = *trainSeqPtr;
+    for (size_t i = 0; i < trainSeq.size(); ++i)
+      if (trainSeq[i].n_rows != hmm.Emission()[0].Dimensionality())
+        Log::Fatal << "Dimensionality of training sequence " << i << " ("
+            << trainSeq[i].n_rows << ") is not equal to the dimensionality of "
+            << "the HMM (" << hmm.Emission()[0].Dimensionality() << ")!"
+            << endl;
+
+    vector<arma::Col<size_t> > labelSeq; // May be empty.
+    if (labelsFile != "")
+    {
+      // Do we have multiple label files to load?
+      if (batch)
+      {
+        // Open for reading only; the default fstream mode (in | out) fails on
+        // files the user cannot write to.
+        fstream f(labelsFile.c_str(), ios_base::in);
+
+        if (!f.is_open())
+          Log::Fatal << "Could not open '" << labelsFile << "' for reading."
+              << endl;
+
+        // Each line of the labels file names one file of labels.  Looping on
+        // the result of getline() (instead of testing eof() afterwards) makes
+        // sure a final line without a trailing newline is still processed.
+        string line;
+        while (std::getline(f, line))
+        {
+          Log::Info << "Adding training sequence labels from '" << line
+              << "'." << endl;
+
+          // Now read the matrix.
+          Mat<size_t> label;
+          data::Load(line, label, true); // Fatal on failure.
+
+          // Ensure that matrix only has one column.
+          if (label.n_rows == 1)
+            label = trans(label);
+
+          if (label.n_cols > 1)
+            Log::Fatal << "Invalid labels; must be one-dimensional." << endl;
+
+          labelSeq.push_back(label.col(0));
+        }
+
+        f.close();
+      }
+      else
+      {
+        Mat<size_t> label;
+        data::Load(labelsFile, label, true);
+
+        // Ensure that matrix only has one column.
+        if (label.n_rows == 1)
+          label = trans(label);
+
+        if (label.n_cols > 1)
+          Log::Fatal << "Invalid labels; must be one-dimensional." << endl;
+
+        // Verify the same number of observations as the data.
+        if (label.n_elem != trainSeq[labelSeq.size()].n_cols)
+          Log::Fatal << "Label sequence " << labelSeq.size() << " does not have"
+              << " the same number of points as observation sequence "
+              << labelSeq.size() << "!" << endl;
+
+        labelSeq.push_back(label.col(0));
+      }
+
+      // Now perform the training with labels.
+      hmm.Train(trainSeq, labelSeq);
+    }
+    else
+    {
+      // Perform unsupervised training.
+      hmm.Train(trainSeq);
+    }
+
+    // Save the trained model to the file named by "output_file".  The
+    // "model_file" parameter is the optional pre-existing model to load and is
+    // empty when a new model is being trained, so it must not be used here.
+    const string outputFile = CLI::GetParam<string>("output_file");
+    SaveHMM(hmm, outputFile);
+  }
+};
+
int main(int argc, char** argv)
{
// Parse command line options.
@@ -69,31 +169,33 @@ int main(int argc, char** argv)
RandomSeed((size_t) time(NULL));
// Validate parameters.
- const string inputFile = CLI::GetParam<string>("input_file");
- const string labelsFile = CLI::GetParam<string>("labels_file");
const string modelFile = CLI::GetParam<string>("model_file");
- const string outputFile = CLI::GetParam<string>("output_file");
+ const string inputFile = CLI::GetParam<string>("input_file");
const string type = CLI::GetParam<string>("type");
- const int states = CLI::GetParam<int>("states");
- const bool batch = CLI::HasParam("batch");
+ const size_t states = CLI::GetParam<int>("states");
const double tolerance = CLI::GetParam<double>("tolerance");
+ const bool batch = CLI::HasParam("batch");
- // Validate number of states.
- if (states == 0 && modelFile == "")
- {
- Log::Fatal << "Must specify number of states if model file is not "
- << "specified!" << endl;
- }
+ // Verify that either a model or a type was given.
+ if (modelFile == "" && type == "")
+ Log::Fatal << "No model file specified and no HMM type given! At least "
+ << "one is required." << endl;
- if (states < 0 && modelFile == "")
+ // If no model is specified, make sure we are training with valid parameters.
+ if (modelFile == "")
{
- Log::Fatal << "Invalid number of states (" << states << "); must be greater"
- << " than or equal to 1." << endl;
+ // Validate number of states.
+ if (states == 0)
+ Log::Fatal << "Must specify number of states if model file is not "
+ << "specified!" << endl;
}
- // Load the dataset(s) and labels.
+ if (modelFile != "" && CLI::HasParam("tolerance"))
+ Log::Info << "Tolerance of existing model in '" << modelFile << "' will be "
+ << "replaced with specified tolerance of " << tolerance << "." << endl;
+
+ // Load the input data.
vector<mat> trainSeq;
- vector<arma::Col<size_t> > labelSeq; // May be empty.
if (batch)
{
// The input file contains a list of files to read.
@@ -103,30 +205,20 @@ int main(int argc, char** argv)
fstream f(inputFile.c_str(), ios_base::in);
if (!f.is_open())
- Log::Fatal << "Could not open '" << inputFile << "' for reading." << endl;
+ Log::Fatal << "Could not open '" << inputFile << "' for reading."
+ << endl;
// Now read each line in.
- char lineBuf[1024]; // Max 1024 characters... hopefully that is long enough.
+ char lineBuf[1024]; // Max 1024 characters... hopefully long enough.
f.getline(lineBuf, 1024, '\n');
while (!f.eof())
{
- Log::Info << "Adding training sequence from '" << lineBuf << "'." << endl;
+ Log::Info << "Adding training sequence from '" << lineBuf << "'."
+ << endl;
// Now read the matrix.
trainSeq.push_back(mat());
- if (labelsFile == "") // Nonfatal in this case.
- {
- if (!data::Load(lineBuf, trainSeq.back(), false))
- {
- Log::Warn << "Loading training sequence from '" << lineBuf << "' "
- << "failed. Sequence ignored." << endl;
- trainSeq.pop_back(); // Remove last element which we did not use.
- }
- }
- else
- {
- data::Load(lineBuf, trainSeq.back(), true);
- }
+ data::Load(lineBuf, trainSeq.back(), true); // Fatal on failure.
// See if we need to transpose the data.
if (type == "discrete")
@@ -139,89 +231,25 @@ int main(int argc, char** argv)
}
f.close();
-
- // Now load labels, if we need to.
- if (labelsFile != "")
- {
- f.open(labelsFile.c_str(), ios_base::in);
-
- if (!f.is_open())
- Log::Fatal << "Could not open '" << labelsFile << "' for reading."
- << endl;
-
- // Now read each line in.
- f.getline(lineBuf, 1024, '\n');
- while (!f.eof())
- {
- Log::Info << "Adding training sequence labels from '" << lineBuf
- << "'." << endl;
-
- // Now read the matrix.
- Mat<size_t> label;
- data::Load(lineBuf, label, true); // Fatal on failure.
-
- // Ensure that matrix only has one column.
- if (label.n_rows == 1)
- label = trans(label);
-
- if (label.n_cols > 1)
- Log::Fatal << "Invalid labels; must be one-dimensional." << endl;
-
- labelSeq.push_back(label.col(0));
-
- f.getline(lineBuf, 1024, '\n');
- }
- }
}
else
{
// Only one input file.
trainSeq.resize(1);
data::Load(inputFile, trainSeq[0], true);
-
- // Do we need to load labels?
- if (labelsFile != "")
- {
- Mat<size_t> label;
- data::Load(labelsFile, label, true);
-
- // Ensure that matrix only has one column.
- if (label.n_rows == 1)
- label = trans(label);
-
- if (label.n_cols > 1)
- Log::Fatal << "Invalid labels; must be one-dimensional." << endl;
-
- // Verify the same number of observations as the data.
- if (label.n_elem != trainSeq[labelSeq.size()].n_cols)
- Log::Fatal << "Label sequence " << labelSeq.size() << " does not have "
- << "the same number of points as observation sequence "
- << labelSeq.size() << "!" << endl;
-
- labelSeq.push_back(label.col(0));
- }
}
- // Now, train the HMM, since we have loaded the input data.
- if (type == "discrete")
+ // If we have a model file, we can autodetect the type.
+ if (modelFile != "")
{
- // Verify observations are valid.
- for (size_t i = 0; i < trainSeq.size(); ++i)
- if (trainSeq[i].n_rows > 1)
- Log::Fatal << "Error in training sequence " << i << ": only "
- << "one-dimensional discrete observations allowed for discrete "
- << "HMMs!" << endl;
-
- // Do we have a model to preload?
- HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1), tolerance);
+ LoadHMMAndPerformAction<Train>(modelFile, &trainSeq);
+ }
+ else
+ {
+ // We need to read in the type and build the HMM by hand.
+ const string type = CLI::GetParam<string>("type");
- if (modelFile != "")
- {
- SaveRestoreUtility loader;
- loader.ReadFile(modelFile);
- LoadHMM(hmm, loader);
- }
- else // New model.
+ if (type == "discrete")
{
// Maximum observation is necessary so we know how to train the discrete
// distribution.
@@ -238,84 +266,34 @@ int main(int argc, char** argv)
<< endl;
// Create HMM object.
- hmm = HMM<DiscreteDistribution>(size_t(states),
+ HMM<DiscreteDistribution> hmm(size_t(states),
DiscreteDistribution(maxEmission), tolerance);
- }
-
- // Do we have labels?
- if (labelsFile == "")
- hmm.Train(trainSeq); // Unsupervised training.
- else
- hmm.Train(trainSeq, labelSeq); // Supervised training.
-
- // Finally, save the model. This should later be integrated into the HMM
- // class itself.
- SaveRestoreUtility sr;
- SaveHMM(hmm, sr);
- sr.WriteFile(outputFile);
- }
- else if (type == "gaussian")
- {
- // Create HMM object.
- HMM<GaussianDistribution> hmm(1, GaussianDistribution(1), tolerance);
- // Do we have a model to load?
- size_t dimensionality = 0;
- if (modelFile != "")
- {
- SaveRestoreUtility loader;
- loader.ReadFile(modelFile);
- LoadHMM(hmm, loader);
-
- dimensionality = hmm.Emission()[0].Mean().n_elem;
+ // Now train it. Pass the already-loaded training data.
+ Train::Apply(hmm, &trainSeq);
}
- else
+ else if (type == "gaussian")
{
// Find dimension of the data.
- dimensionality = trainSeq[0].n_rows;
-
- hmm = HMM<GaussianDistribution>(size_t(states),
- GaussianDistribution(dimensionality), tolerance);
- }
-
- // Verify dimensionality of data.
- for (size_t i = 0; i < trainSeq.size(); ++i)
- if (trainSeq[i].n_rows != dimensionality)
- Log::Fatal << "Observation sequence " << i << " dimensionality ("
- << trainSeq[i].n_rows << " is incorrect (should be "
- << dimensionality << ")!" << endl;
-
- // Now run the training.
- if (labelsFile == "")
- hmm.Train(trainSeq); // Unsupervised training.
- else
- hmm.Train(trainSeq, labelSeq); // Supervised training.
+ const size_t dimensionality = trainSeq[0].n_rows;
- // Finally, save the model. This should later be integrated into th HMM
- // class itself.
- SaveRestoreUtility sr;
- SaveHMM(hmm, sr);
- sr.WriteFile(outputFile);
- }
- else if (type == "gmm")
- {
- // Create HMM object.
- HMM<GMM<> > hmm(1, GMM<>(1, 1));
+ // Verify dimensionality of data.
+ for (size_t i = 0; i < trainSeq.size(); ++i)
+ if (trainSeq[i].n_rows != dimensionality)
+ Log::Fatal << "Observation sequence " << i << " dimensionality ("
+ << trainSeq[i].n_rows << " is incorrect (should be "
+ << dimensionality << ")!" << endl;
- // Do we have a model to load?
- size_t dimensionality = 0;
- if (modelFile != "")
- {
- SaveRestoreUtility loader;
- loader.ReadFile(modelFile);
- LoadHMM(hmm, loader);
+ HMM<GaussianDistribution> hmm(size_t(states),
+ GaussianDistribution(dimensionality), tolerance);
- dimensionality = hmm.Emission()[0].Dimensionality();
+ // Now train it.
+ Train::Apply(hmm, &trainSeq);
}
- else
+ else if (type == "gmm")
{
// Find dimension of the data.
- dimensionality = trainSeq[0].n_rows;
+ const size_t dimensionality = trainSeq[0].n_rows;
const int gaussians = CLI::GetParam<int>("gaussians");
@@ -327,37 +305,21 @@ int main(int argc, char** argv)
Log::Fatal << "Invalid number of gaussians (" << gaussians << "); must "
<< "be greater than or equal to 1." << endl;
- hmm = HMM<GMM<> >(size_t(states), GMM<>(size_t(gaussians),
- dimensionality), tolerance);
- }
+ // Create HMM object.
+ HMM<GMM<>> hmm(size_t(states), GMM<>(size_t(gaussians), dimensionality),
+ tolerance);
- // Verify dimensionality of data.
- for (size_t i = 0; i < trainSeq.size(); ++i)
- if (trainSeq[i].n_rows != dimensionality)
- Log::Fatal << "Observation sequence " << i << " dimensionality ("
- << trainSeq[i].n_rows << " is incorrect (should be "
- << dimensionality << ")!" << endl;
+ // Issue a warning if the user didn't give labels.
+ if (!CLI::HasParam("label_file"))
+ Log::Warn << "Unlabeled training of GMM HMMs is almost certainly not "
+ << "going to produce good results!" << endl;
- // Now run the training.
- if (labelsFile == "")
- {
- Log::Warn << "Unlabeled training of GMM HMMs is almost certainly not "
- << "going to produce good results!" << endl;
- hmm.Train(trainSeq);
+ Train::Apply(hmm, &trainSeq);
}
else
{
- hmm.Train(trainSeq, labelSeq);
+ Log::Fatal << "Unknown HMM type: " << type << "; must be 'discrete', "
+ << "'gaussian', or 'gmm'." << endl;
}
-
- // Save model.
- SaveRestoreUtility sr;
- SaveHMM(hmm, sr);
- sr.WriteFile(outputFile);
- }
- else
- {
- Log::Fatal << "Unknown HMM type: " << type << "; must be 'discrete', "
- << "'gaussian', or 'gmm'." << endl;
}
}
diff --git a/src/mlpack/methods/hmm/hmm_util.hpp b/src/mlpack/methods/hmm/hmm_util.hpp
new file mode 100644
index 0000000..46ec818
--- /dev/null
+++ b/src/mlpack/methods/hmm/hmm_util.hpp
@@ -0,0 +1,39 @@
+/**
+ * @file hmm_util.hpp
+ * @author Ryan Curtin
+ *
+ * Utility to read HMM type from file.
+ */
+#ifndef __MLPACK_METHODS_HMM_HMM_UTIL_HPP
+#define __MLPACK_METHODS_HMM_HMM_UTIL_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace hmm {
+
+//! Tag identifying the kind of HMM stored in a model file.  The underlying
+//! type is char so that the tag occupies a single byte on disk.  (It is
+//! unclear what will happen on systems where one byte is not eight bits.)
+enum HMMType : char
+{
+  DiscreteHMM = 0,
+  GaussianHMM = 1,
+  GaussianMixtureModelHMM = 2
+};
+
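+//! ActionType should implement a static method template
+//! Apply(HMMType& hmm, ExtraInfoType* x); it is called with the loaded HMM and
+//! the extra-information pointer passed to this function.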
+template<typename ActionType, typename ExtraInfoType = void>
+void LoadHMMAndPerformAction(const std::string& modelFile,
+ ExtraInfoType* x = NULL);
+
+//! Save an HMM to a file. The file must also encode what type of HMM is being
+//! stored.
+template<typename HMMType>
+void SaveHMM(HMMType& hmm, const std::string& modelFile);
+
+} // namespace hmm
+} // namespace mlpack
+
+#include "hmm_util_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/hmm/hmm_util_impl.hpp b/src/mlpack/methods/hmm/hmm_util_impl.hpp
new file mode 100644
index 0000000..e1b2de7
--- /dev/null
+++ b/src/mlpack/methods/hmm/hmm_util_impl.hpp
@@ -0,0 +1,161 @@
+/**
+ * @file hmm_util_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of HMM utilities to load arbitrary HMM types.
+ */
+#ifndef __MLPACK_METHODS_HMM_HMM_UTIL_IMPL_HPP
+#define __MLPACK_METHODS_HMM_HMM_UTIL_IMPL_HPP
+
+#include <mlpack/core.hpp>
+
+#include <mlpack/methods/hmm/hmm.hpp>
+#include <mlpack/methods/gmm/gmm.hpp>
+
+namespace mlpack {
+namespace hmm {
+
+// Forward declarations of utility functions.
+
+// Set up the archive for deserialization.
+template<typename ActionType, typename ArchiveType, typename ExtraInfoType>
+void LoadHMMAndPerformActionHelper(const std::string& modelFile,
+ ExtraInfoType* x = NULL);
+
+// Actually deserialize into the correct type.
+template<typename ActionType,
+ typename ArchiveType,
+ typename HMMType,
+ typename ExtraInfoType>
+void DeserializeHMMAndPerformAction(ArchiveType& ar, ExtraInfoType* x = NULL);
+
+/**
+ * Pick the boost::serialization input archive type from the extension of the
+ * model file ("xml", "bin" or "txt") and hand off to the helper that loads the
+ * HMM and runs ActionType on it.  Any other extension is a fatal error.
+ *
+ * @param modelFile Name of the file containing the serialized HMM.
+ * @param x Extra information forwarded untouched to the helper.
+ */
+template<typename ActionType, typename ExtraInfoType>
+void LoadHMMAndPerformAction(const std::string& modelFile,
+                             ExtraInfoType* x)
+{
+  const std::string ext = data::Extension(modelFile);
+
+  if (ext == "xml")
+  {
+    LoadHMMAndPerformActionHelper<ActionType,
+        boost::archive::xml_iarchive>(modelFile, x);
+    return;
+  }
+
+  if (ext == "bin")
+  {
+    LoadHMMAndPerformActionHelper<ActionType,
+        boost::archive::binary_iarchive>(modelFile, x);
+    return;
+  }
+
+  if (ext == "txt")
+  {
+    LoadHMMAndPerformActionHelper<ActionType,
+        boost::archive::text_iarchive>(modelFile, x);
+    return;
+  }
+
+  // None of the supported extensions matched.
+  Log::Fatal << "Unknown extension '" << ext << "' for HMM model file "
+      << "(known: 'xml', 'txt', 'bin')." << std::endl;
+}
+
+/**
+ * Open the model file with the given archive type, read the one-byte tag that
+ * says which kind of HMM the file holds, and then deserialize an HMM of the
+ * matching type and run ActionType on it.
+ *
+ * @param modelFile Name of the file containing the serialized HMM.
+ * @param x Extra information forwarded untouched to ActionType::Apply().
+ */
+template<typename ActionType,
+         typename ArchiveType,
+         typename ExtraInfoType>
+void LoadHMMAndPerformActionHelper(const std::string& modelFile,
+                                   ExtraInfoType* x)
+{
+  std::ifstream ifs(modelFile);
+  if (!ifs.is_open())
+    Log::Fatal << "Could not open HMM model file '" << modelFile
+        << "' for reading." << std::endl;
+
+  ArchiveType ar(ifs);
+
+  // Read in the char that denotes the type of the model (see HMMType).
+  char type;
+  ar >> data::CreateNVP(type, "hmm_type");
+
+  using namespace mlpack::distribution;
+
+  // Each case must end in a break: without it, a successfully handled model
+  // would fall through into the other cases (deserializing again from the
+  // same archive) and finally hit the fatal error in the default case.
+  switch (type)
+  {
+    case HMMType::DiscreteHMM:
+      DeserializeHMMAndPerformAction<ActionType, ArchiveType,
+          HMM<DiscreteDistribution>>(ar, x);
+      break;
+
+    case HMMType::GaussianHMM:
+      DeserializeHMMAndPerformAction<ActionType, ArchiveType,
+          HMM<GaussianDistribution>>(ar, x);
+      break;
+
+    case HMMType::GaussianMixtureModelHMM:
+      DeserializeHMMAndPerformAction<ActionType, ArchiveType,
+          HMM<gmm::GMM<>>>(ar, x);
+      break;
+
+    default:
+      Log::Fatal << "Unknown HMM type '" << (unsigned int) type << "'!"
+          << std::endl;
+      break;
+  }
+}
+
+/**
+ * Default-construct an HMM of type HMMType, fill it from the archive entry
+ * named "hmm", and pass it to ActionType::Apply() together with x.
+ *
+ * @param ar Input archive positioned at the serialized HMM.
+ * @param x Extra information forwarded untouched to ActionType::Apply().
+ */
+template<typename ActionType,
+         typename ArchiveType,
+         typename HMMType,
+         typename ExtraInfoType>
+void DeserializeHMMAndPerformAction(ArchiveType& ar, ExtraInfoType* x)
+{
+  HMMType model;
+
+  // Load the model from the archive...
+  ar >> data::CreateNVP(model, "hmm");
+
+  // ...and run the requested action on it.
+  ActionType::Apply(model, x);
+}
+
+// Helper function.
+template<typename ArchiveType, typename HMMType>
+void SaveHMMHelper(HMMType& hmm, const std::string& modelFile);
+
+template<typename HMMType>
+char GetHMMType();
+
+/**
+ * Pick the boost::serialization output archive type from the extension of the
+ * model file ("xml", "bin" or "txt") and hand off to SaveHMMHelper().  Any
+ * other extension is a fatal error.
+ *
+ * @param hmm HMM to save.
+ * @param modelFile Name of the file to write the HMM to.
+ */
+template<typename HMMType>
+void SaveHMM(HMMType& hmm, const std::string& modelFile)
+{
+  const std::string ext = data::Extension(modelFile);
+
+  if (ext == "xml")
+  {
+    SaveHMMHelper<boost::archive::xml_oarchive>(hmm, modelFile);
+    return;
+  }
+
+  if (ext == "bin")
+  {
+    SaveHMMHelper<boost::archive::binary_oarchive>(hmm, modelFile);
+    return;
+  }
+
+  if (ext == "txt")
+  {
+    SaveHMMHelper<boost::archive::text_oarchive>(hmm, modelFile);
+    return;
+  }
+
+  // None of the supported extensions matched.
+  Log::Fatal << "Unknown extension '" << ext << "' for HMM model file."
+      << std::endl;
+}
+
+/**
+ * Write the HMM to the given file using the given archive type.  The archive
+ * holds a one-byte tag named "hmm_type" (see GetHMMType()) followed by the
+ * model itself under the name "hmm".
+ *
+ * @param hmm HMM to save.
+ * @param modelFile Name of the file to write the HMM to.
+ */
+template<typename ArchiveType, typename HMMType>
+void SaveHMMHelper(HMMType& hmm, const std::string& modelFile)
+{
+  // Work out the type tag first, so that an unsupported HMM type is reported
+  // before the output file is opened (and thereby truncated).
+  char type = GetHMMType<HMMType>();
+  if (type == char(-1))
+    Log::Fatal << "Unknown HMM type given to SaveHMM()!" << std::endl;
+
+  std::ofstream ofs(modelFile);
+  if (!ofs.is_open())
+    Log::Fatal << "Could not open HMM model file '" << modelFile
+        << "' for writing." << std::endl;
+
+  ArchiveType ar(ofs);
+
+  // Write out the char that denotes the type of the model, then the model.
+  ar << data::CreateNVP(type, "hmm_type");
+  ar << data::CreateNVP(hmm, "hmm");
+}
+
+// Utility functions to turn a type into something we can store.
+
+//! Fallback for unsupported HMM types: returns char(-1), which SaveHMMHelper()
+//! treats as "unknown type".
+template<typename HMMType>
+char GetHMMType() { return char(-1); }
+
+// The explicit specializations below are ordinary (non-template) functions
+// defined in a header, so they are marked inline; otherwise including this
+// header from more than one translation unit gives multiple definitions.
+
+//! Type tag for HMMs with discrete emissions.
+template<>
+inline char GetHMMType<HMM<distribution::DiscreteDistribution>>()
+{
+  return HMMType::DiscreteHMM;
+}
+
+//! Type tag for HMMs with Gaussian emissions.
+template<>
+inline char GetHMMType<HMM<distribution::GaussianDistribution>>()
+{
+  return HMMType::GaussianHMM;
+}
+
+//! Type tag for HMMs with Gaussian mixture model emissions.
+template<>
+inline char GetHMMType<HMM<gmm::GMM<>>>()
+{
+  return HMMType::GaussianMixtureModelHMM;
+}
+
+} // namespace hmm
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/hmm/hmm_viterbi_main.cpp b/src/mlpack/methods/hmm/hmm_viterbi_main.cpp
index de13bee..34933af 100644
--- a/src/mlpack/methods/hmm/hmm_viterbi_main.cpp
+++ b/src/mlpack/methods/hmm/hmm_viterbi_main.cpp
@@ -19,7 +19,7 @@ PROGRAM_INFO("Hidden Markov Model (HMM) Viterbi State Prediction", "This "
"is saved to the specified output file (--output_file).");
PARAM_STRING_REQ("input_file", "File containing observations,", "i");
-PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
+PARAM_STRING_REQ("model_file", "File containing HMM.", "m");
PARAM_STRING("output_file", "File to save predicted state sequence to.", "o",
"output.csv");
@@ -31,60 +31,26 @@ using namespace mlpack::gmm;
using namespace arma;
using namespace std;
-int main(int argc, char** argv)
+// Because we don't know what the type of our HMM is, we need to write a
+// function that can take arbitrary HMM types.
+struct Viterbi
{
- // Parse command line options.
- CLI::ParseCommandLine(argc, argv);
-
- // Load observations.
- const string inputFile = CLI::GetParam<string>("input_file");
- const string modelFile = CLI::GetParam<string>("model_file");
-
- mat dataSeq;
- data::Load(inputFile, dataSeq, true);
-
- // Load model, but first we have to determine its type.
- SaveRestoreUtility sr;
- sr.ReadFile(modelFile);
- string type;
- sr.LoadParameter(type, "hmm_type");
-
- arma::Col<size_t> sequence;
- if (type == "discrete")
+ template<typename HMMType>
+ static void Apply(HMMType& hmm, void* /* extraInfo */)
{
- HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
-
- LoadHMM(hmm, sr);
-
- // Verify only one row in observations.
- if (dataSeq.n_cols == 1)
- dataSeq = trans(dataSeq);
-
- if (dataSeq.n_rows > 1)
- Log::Fatal << "Only one-dimensional discrete observations allowed for "
- << "discrete HMMs!" << endl;
-
- hmm.Predict(dataSeq, sequence);
- }
- else if (type == "gaussian")
- {
- HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
-
- LoadHMM(hmm, sr);
+ // Load observations.
+ const string inputFile = CLI::GetParam<string>("input_file");
- // Verify correct dimensionality.
- if (dataSeq.n_rows != hmm.Emission()[0].Mean().n_elem)
- Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
- << "does not match HMM Gaussian dimensionality ("
- << hmm.Emission()[0].Mean().n_elem << ")!" << endl;
-
- hmm.Predict(dataSeq, sequence);
- }
- else if (type == "gmm")
- {
- HMM<GMM<> > hmm(1, GMM<>(1, 1));
+ mat dataSeq;
+ data::Load(inputFile, dataSeq, true);
- LoadHMM(hmm, sr);
+ // See if transposing the data could make it the right dimensionality.
+ if ((dataSeq.n_cols == 1) && (hmm.Emission()[0].Dimensionality() == 1))
+ {
+ Log::Info << "Data sequence appears to be transposed; correcting."
+ << endl;
+ dataSeq = dataSeq.t();
+ }
// Verify correct dimensionality.
if (dataSeq.n_rows != hmm.Emission()[0].Dimensionality())
@@ -92,15 +58,20 @@ int main(int argc, char** argv)
<< "does not match HMM Gaussian dimensionality ("
<< hmm.Emission()[0].Dimensionality() << ")!" << endl;
+ arma::Col<size_t> sequence;
hmm.Predict(dataSeq, sequence);
+
+ // Save output.
+ const string outputFile = CLI::GetParam<string>("output_file");
+ data::Save(outputFile, sequence, true);
}
- else
- {
- Log::Fatal << "Unknown HMM type '" << type << "' in file '" << modelFile
- << "'!" << endl;
- }
+};
+
+int main(int argc, char** argv)
+{
+ // Parse command line options.
+ CLI::ParseCommandLine(argc, argv);
- // Save output.
- const string outputFile = CLI::GetParam<string>("output_file");
- data::Save(outputFile, sequence, true);
+ const string modelFile = CLI::GetParam<string>("model_file");
+ LoadHMMAndPerformAction<Viterbi>(modelFile);
}
More information about the mlpack-git
mailing list