[mlpack-git] master: Refactor HMM programs to use boost::serialization. (7f3c228)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Jul 13 04:05:03 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/8b2ca720828224607c70d2b539c43aecf8f4ec32...b4659b668021db631b3c8a48e3d735b513706fdc

>---------------------------------------------------------------

commit 7f3c2286ff182b128ae667edf392d43a865072e7
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Jul 12 13:34:44 2015 +0000

    Refactor HMM programs to use boost::serialization.


>---------------------------------------------------------------

7f3c2286ff182b128ae667edf392d43a865072e7
 src/mlpack/methods/hmm/CMakeLists.txt        |   2 +
 src/mlpack/methods/hmm/hmm_generate_main.cpp | 107 +++-----
 src/mlpack/methods/hmm/hmm_loglik_main.cpp   |  96 +++----
 src/mlpack/methods/hmm/hmm_train_main.cpp    | 360 ++++++++++++---------------
 src/mlpack/methods/hmm/hmm_util.hpp          |  39 +++
 src/mlpack/methods/hmm/hmm_util_impl.hpp     | 161 ++++++++++++
 src/mlpack/methods/hmm/hmm_viterbi_main.cpp  |  89 +++----
 7 files changed, 459 insertions(+), 395 deletions(-)

diff --git a/src/mlpack/methods/hmm/CMakeLists.txt b/src/mlpack/methods/hmm/CMakeLists.txt
index 179a32b..fd52249 100644
--- a/src/mlpack/methods/hmm/CMakeLists.txt
+++ b/src/mlpack/methods/hmm/CMakeLists.txt
@@ -5,6 +5,8 @@ set(SOURCES
   hmm_impl.hpp
   hmm_regression.hpp
   hmm_regression_impl.hpp
+  hmm_util.hpp
+  hmm_util_impl.hpp
 )
 
 # Add directory name to sources.
diff --git a/src/mlpack/methods/hmm/hmm_generate_main.cpp b/src/mlpack/methods/hmm/hmm_generate_main.cpp
index 2a50d06..d0de3b0 100644
--- a/src/mlpack/methods/hmm/hmm_generate_main.cpp
+++ b/src/mlpack/methods/hmm/hmm_generate_main.cpp
@@ -19,7 +19,7 @@ PROGRAM_INFO("Hidden Markov Model (HMM) Sequence Generator", "This "
     "parameters, saving them to the specified files (--output_file and "
     "--state_file)");
 
-PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
+PARAM_STRING_REQ("model_file", "File containing HMM.", "m");
 PARAM_INT_REQ("length", "Length of sequence to generate.", "l");
 
 PARAM_INT("start_state", "Starting state of sequence.", "t", 0);
@@ -38,92 +38,58 @@ using namespace mlpack::math;
 using namespace arma;
 using namespace std;
 
-int main(int argc, char** argv)
+// Because we don't know what the type of our HMM is, we need to write a
+// function which can take arbitrary HMM types.
+struct Generate
 {
-  // Parse command line options.
-  CLI::ParseCommandLine(argc, argv);
-
-  // Set random seed.
-  if (CLI::GetParam<int>("seed") != 0)
-    RandomSeed((size_t) CLI::GetParam<int>("seed"));
-  else
-    RandomSeed((size_t) time(NULL));
-
-  // Load observations.
-  const string modelFile = CLI::GetParam<string>("model_file");
-  const int length = CLI::GetParam<int>("length");
-  const int startState = CLI::GetParam<int>("start_state");
-
-  if (length <= 0)
+  template<typename HMMType>
+  static void Apply(HMMType& hmm, void* /* extraInfo */)
   {
-    Log::Fatal << "Invalid sequence length (" << length << "); must be greater "
-        << "than or equal to 0!" << endl;
-  }
+    mat observations;
+    Col<size_t> sequence;
 
-  // Load model, but first we have to determine its type.
-  SaveRestoreUtility sr;
-  sr.ReadFile(modelFile);
-  string emissionType;
-  sr.LoadParameter(emissionType, "emission_type");
-
-  mat observations;
-  Col<size_t> sequence;
-  if (emissionType == "DiscreteDistribution")
-  {
-    HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
-    hmm.Load(sr);
+    // Load the parameters.
+    const size_t startState = (size_t) CLI::GetParam<int>("start_state");
+    const size_t length = (size_t) CLI::GetParam<int>("length");
 
-    if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
-    {
+    Log::Info << "Generating sequence of length " << length << "..." << endl;
+    if (startState >= hmm.Transition().n_rows)
       Log::Fatal << "Invalid start state (" << startState << "); must be "
           << "between 0 and number of states (" << hmm.Transition().n_rows
           << ")!" << endl;
-    }
 
-    hmm.Generate(size_t(length), observations, sequence, size_t(startState));
-  }
-  else if (emissionType == "GaussianDistribution")
-  {
-    HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
-    hmm.Load(sr);
+    hmm.Generate(length, observations, sequence, startState);
 
-    if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
-    {
-      Log::Fatal << "Invalid start state (" << startState << "); must be "
-          << "between 0 and number of states (" << hmm.Transition().n_rows
-          << ")!" << endl;
-    }
+    // Now save the output.
+    const string outputFile = CLI::GetParam<string>("output_file");
+    data::Save(outputFile, observations, true);
 
-    hmm.Generate(size_t(length), observations, sequence, size_t(startState));
+    // Do we want to save the hidden sequence?
+    const string sequenceFile = CLI::GetParam<string>("state_file");
+    if (sequenceFile != "")
+      data::Save(sequenceFile, sequence, true);
   }
-  else if (emissionType == "GMM")
-  {
-    HMM<GMM<> > hmm(1, GMM<>(1, 1));
-    hmm.Load(sr);
+};
 
-    if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
-    {
-      Log::Fatal << "Invalid start state (" << startState << "); must be "
-          << "between 0 and number of states (" << hmm.Transition().n_rows
-          << ")!" << endl;
-    }
+int main(int argc, char** argv)
+{
+  // Parse command line options.
+  CLI::ParseCommandLine(argc, argv);
 
-    hmm.Generate(size_t(length), observations, sequence, size_t(startState));
-  }
+  // Set random seed.
+  if (CLI::GetParam<int>("seed") != 0)
+    RandomSeed((size_t) CLI::GetParam<int>("seed"));
   else
-  {
-    Log::Fatal << "Unknown HMM type '" << emissionType << "' in file '" << modelFile
-        << "'!" << endl;
-  }
-
-  // Save observations.
-  const string outputFile = CLI::GetParam<string>("output_file");
-  data::Save(outputFile, observations, true);
-
-  // Do we want to save the hidden sequence?
-  const string sequenceFile = CLI::GetParam<string>("state_file");
-  if (sequenceFile != "")
-    data::Save(sequenceFile, sequence, true);
+    RandomSeed((size_t) time(NULL));
 
-  return 0;
+  // Validate the sequence length before loading the model; a negative value
+  // would otherwise wrap around when cast to size_t in Generate::Apply().
+  const int length = CLI::GetParam<int>("length");
+  if (length <= 0)
+    Log::Fatal << "Invalid sequence length (" << length << "); must be greater "
+        << "than 0!" << endl;
+
+  // Load model, and perform the generation.
+  const string modelFile = CLI::GetParam<string>("model_file");
+  LoadHMMAndPerformAction<Generate>(modelFile);
 }
diff --git a/src/mlpack/methods/hmm/hmm_loglik_main.cpp b/src/mlpack/methods/hmm/hmm_loglik_main.cpp
index 6946567..2692a17 100644
--- a/src/mlpack/methods/hmm/hmm_loglik_main.cpp
+++ b/src/mlpack/methods/hmm/hmm_loglik_main.cpp
@@ -17,7 +17,7 @@ PROGRAM_INFO("Hidden Markov Model (HMM) Sequence Log-Likelihood", "This "
     "computed log-likelihood is given directly to stdout.");
 
 PARAM_STRING_REQ("input_file", "File containing observations,", "i");
-PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
+PARAM_STRING_REQ("model_file", "File containing HMM.", "m");
 
 using namespace mlpack;
 using namespace mlpack::hmm;
@@ -27,74 +27,44 @@ using namespace mlpack::gmm;
 using namespace arma;
 using namespace std;
 
-int main(int argc, char** argv)
+// Because we don't know what the type of our HMM is, we need to write a
+// function that can take arbitrary HMM types.
+struct Loglik
 {
-  // Parse command line options.
-  CLI::ParseCommandLine(argc, argv);
-
-  // Load observations.
-  const string inputFile = CLI::GetParam<string>("input_file");
-  const string modelFile = CLI::GetParam<string>("model_file");
-
-  mat dataSeq;
-  data::Load(inputFile, dataSeq, true);
-
-  // Load model, but first we have to determine its type.
-  SaveRestoreUtility sr;
-  sr.ReadFile(modelFile);
-  string type;
-  sr.LoadParameter(type, "hmm_type");
-
-  double loglik = 0;
-  if (type == "discrete")
+  template<typename HMMType>
+  static void Apply(HMMType& hmm, void* /* extraInfo */)
   {
-    HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
-
-    LoadHMM(hmm, sr);
-
-    // Verify only one row in observations.
-    if (dataSeq.n_cols == 1)
-      dataSeq = trans(dataSeq);
-
-    if (dataSeq.n_rows > 1)
-      Log::Fatal << "Only one-dimensional discrete observations allowed for "
-          << "discrete HMMs!" << endl;
-
-    loglik = hmm.LogLikelihood(dataSeq);
-  }
-  else if (type == "gaussian")
-  {
-    HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
-
-    LoadHMM(hmm, sr);
+    // Load the data sequence.
+    const string inputFile = CLI::GetParam<string>("input_file");
+    mat dataSeq;
+    data::Load(inputFile, dataSeq, true);
+
+    // Detect if we need to transpose the data, in the case where the input data
+    // has one dimension.
+    if ((dataSeq.n_cols == 1) && (hmm.Emission()[0].Dimensionality() == 1))
+    {
+      Log::Info << "Data sequence appears to be transposed; correcting."
+          << endl;
+      dataSeq = dataSeq.t();
+    }
 
-    // Verify correct dimensionality.
-    if (dataSeq.n_rows != hmm.Emission()[0].Mean().n_elem)
-      Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
-          << "does not match HMM Gaussian dimensionality ("
-          << hmm.Emission()[0].Mean().n_elem << ")!" << endl;
-
-    loglik = hmm.LogLikelihood(dataSeq);
-  }
-  else if (type == "gmm")
-  {
-    HMM<GMM<> > hmm(1, GMM<>(1, 1));
-
-    LoadHMM(hmm, sr);
-
-    // Verify correct dimensionality.
     if (dataSeq.n_rows != hmm.Emission()[0].Dimensionality())
-      Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
-          << "does not match HMM Gaussian dimensionality ("
+      Log::Fatal << "Dimensionality of sequence (" << dataSeq.n_rows << ") is "
+          << "not equal to the dimensionality of the HMM ("
           << hmm.Emission()[0].Dimensionality() << ")!" << endl;
 
-    loglik = hmm.LogLikelihood(dataSeq);
-  }
-  else
-  {
-    Log::Fatal << "Unknown HMM type '" << type << "' in file '" << modelFile
-        << "'!" << endl;
+    const double loglik = hmm.LogLikelihood(dataSeq);
+
+    cout << loglik << endl;
   }
+};
 
-  cout << loglik << endl;
+int main(int argc, char** argv)
+{
+  // Parse command line options.
+  CLI::ParseCommandLine(argc, argv);
+
+  // Load model, and calculate the log-likelihood of the sequence.
+  const string modelFile = CLI::GetParam<string>("model_file");
+  LoadHMMAndPerformAction<Loglik>(modelFile);
 }
diff --git a/src/mlpack/methods/hmm/hmm_train_main.cpp b/src/mlpack/methods/hmm/hmm_train_main.cpp
index 02c9df6..87d5c04 100644
--- a/src/mlpack/methods/hmm/hmm_train_main.cpp
+++ b/src/mlpack/methods/hmm/hmm_train_main.cpp
@@ -43,7 +43,7 @@ PARAM_INT("gaussians", "Number of gaussians in each GMM (necessary when type is"
 PARAM_STRING("model_file", "Pre-existing HMM model (optional).", "m", "");
 PARAM_STRING("labels_file", "Optional file of hidden states, used for "
     "labeled training.", "l", "");
-PARAM_STRING("output_file", "File to save trained HMM to (XML).", "o",
+PARAM_STRING("output_file", "File to save trained HMM to.", "o",
     "output_hmm.xml");
 PARAM_INT("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
 PARAM_DOUBLE("tolerance", "Tolerance of the Baum-Welch algorithm.", "T", 1e-5);
@@ -57,6 +57,106 @@ using namespace mlpack::math;
 using namespace arma;
 using namespace std;
 
+// Because we don't know what the type of our HMM is, we need to write a
+// function that can take arbitrary HMM types.
+struct Train
+{
+  template<typename HMMType>
+  static void Apply(HMMType& hmm, vector<mat>* trainSeqPtr)
+  {
+    const bool batch = CLI::HasParam("batch");
+    const double tolerance = CLI::GetParam<double>("tolerance");
+
+    // Do we need to replace the tolerance?
+    if (CLI::HasParam("tolerance"))
+      hmm.Tolerance() = tolerance;
+
+    const string labelsFile = CLI::GetParam<string>("labels_file");
+
+    // Verify that the dimensionality of our observations is the same as the
+    // dimensionality of our HMM's emissions.
+    vector<mat>& trainSeq = *trainSeqPtr;
+    for (size_t i = 0; i < trainSeq.size(); ++i)
+      if (trainSeq[i].n_rows != hmm.Emission()[0].Dimensionality())
+        Log::Fatal << "Dimensionality of training sequence " << i << " ("
+            << trainSeq[i].n_rows << ") is not equal to the dimensionality of "
+            << "the HMM (" << hmm.Emission()[0].Dimensionality() << ")!"
+            << endl;
+
+    vector<arma::Col<size_t> > labelSeq; // May be empty.
+    if (labelsFile != "")
+    {
+      // Do we have multiple label files to load?
+      char lineBuf[1024];
+      if (batch)
+      {
+        fstream f(labelsFile.c_str(), ios_base::in); // Read-only open.
+
+        if (!f.is_open())
+          Log::Fatal << "Could not open '" << labelsFile << "' for reading."
+              << endl;
+
+        // Now read each line in.
+        f.getline(lineBuf, 1024, '\n');
+        while (!f.eof())
+        {
+          Log::Info << "Adding training sequence labels from '" << lineBuf
+              << "'." << endl;
+
+          // Now read the matrix.
+          Mat<size_t> label;
+          data::Load(lineBuf, label, true); // Fatal on failure.
+
+          // Ensure that matrix only has one column.
+          if (label.n_rows == 1)
+            label = trans(label);
+
+          if (label.n_cols > 1)
+            Log::Fatal << "Invalid labels; must be one-dimensional." << endl;
+
+          labelSeq.push_back(label.col(0));
+
+          f.getline(lineBuf, 1024, '\n');
+        }
+
+        f.close();
+      }
+      else
+      {
+        Mat<size_t> label;
+        data::Load(labelsFile, label, true);
+
+        // Ensure that matrix only has one column.
+        if (label.n_rows == 1)
+          label = trans(label);
+
+        if (label.n_cols > 1)
+          Log::Fatal << "Invalid labels; must be one-dimensional." << endl;
+
+        // Verify the same number of observations as the data.
+        if (label.n_elem != trainSeq[labelSeq.size()].n_cols)
+          Log::Fatal << "Label sequence " << labelSeq.size() << " does not have"
+              << " the same number of points as observation sequence "
+              << labelSeq.size() << "!" << endl;
+
+        labelSeq.push_back(label.col(0));
+      }
+
+      // Now perform the training with labels.
+      hmm.Train(trainSeq, labelSeq);
+    }
+    else
+    {
+      // Perform unsupervised training.
+      hmm.Train(trainSeq);
+    }
+
+    // Save the trained model to --output_file (not the input --model_file).
+    const string outputFile = CLI::GetParam<string>("output_file");
+    SaveHMM(hmm, outputFile);
+  }
+};
+
 int main(int argc, char** argv)
 {
   // Parse command line options.
@@ -69,31 +169,33 @@ int main(int argc, char** argv)
     RandomSeed((size_t) time(NULL));
 
   // Validate parameters.
-  const string inputFile = CLI::GetParam<string>("input_file");
-  const string labelsFile = CLI::GetParam<string>("labels_file");
   const string modelFile = CLI::GetParam<string>("model_file");
-  const string outputFile = CLI::GetParam<string>("output_file");
+  const string inputFile = CLI::GetParam<string>("input_file");
   const string type = CLI::GetParam<string>("type");
-  const int states = CLI::GetParam<int>("states");
-  const bool batch = CLI::HasParam("batch");
+  const int states = CLI::GetParam<int>("states");
   const double tolerance = CLI::GetParam<double>("tolerance");
+  const bool batch = CLI::HasParam("batch");
 
-  // Validate number of states.
-  if (states == 0 && modelFile == "")
-  {
-    Log::Fatal << "Must specify number of states if model file is not "
-        << "specified!" << endl;
-  }
+  // Verify that either a model or a type was given.
+  if (modelFile == "" && type == "")
+    Log::Fatal << "No model file specified and no HMM type given!  At least "
+        << "one is required." << endl;
 
-  if (states < 0 && modelFile == "")
+  // If no model is specified, make sure we are training with valid parameters.
+  if (modelFile == "")
   {
-    Log::Fatal << "Invalid number of states (" << states << "); must be greater"
-        << " than or equal to 1." << endl;
+    // Validate number of states; negative values are rejected as well.
+    if (states <= 0)
+      Log::Fatal << "Must specify a positive number of states (got "
+          << states << ") if model file is not specified!" << endl;
   }
 
-  // Load the dataset(s) and labels.
+  if (modelFile != "" && CLI::HasParam("tolerance"))
+    Log::Info << "Tolerance of existing model in '" << modelFile << "' will be "
+        << "replaced with specified tolerance of " << tolerance << "." << endl;
+
+  // Load the input data.
   vector<mat> trainSeq;
-  vector<arma::Col<size_t> > labelSeq; // May be empty.
   if (batch)
   {
     // The input file contains a list of files to read.
@@ -103,30 +205,20 @@ int main(int argc, char** argv)
     fstream f(inputFile.c_str(), ios_base::in);
 
     if (!f.is_open())
-      Log::Fatal << "Could not open '" << inputFile << "' for reading." << endl;
+      Log::Fatal << "Could not open '" << inputFile << "' for reading."
+          << endl;
 
     // Now read each line in.
-    char lineBuf[1024]; // Max 1024 characters... hopefully that is long enough.
+    char lineBuf[1024]; // Max 1024 characters... hopefully long enough.
     f.getline(lineBuf, 1024, '\n');
     while (!f.eof())
     {
-      Log::Info << "Adding training sequence from '" << lineBuf << "'." << endl;
+      Log::Info << "Adding training sequence from '" << lineBuf << "'."
+          << endl;
 
       // Now read the matrix.
       trainSeq.push_back(mat());
-      if (labelsFile == "") // Nonfatal in this case.
-      {
-        if (!data::Load(lineBuf, trainSeq.back(), false))
-        {
-          Log::Warn << "Loading training sequence from '" << lineBuf << "' "
-              << "failed.  Sequence ignored." << endl;
-          trainSeq.pop_back(); // Remove last element which we did not use.
-        }
-      }
-      else
-      {
-        data::Load(lineBuf, trainSeq.back(), true);
-      }
+      data::Load(lineBuf, trainSeq.back(), true); // Fatal on failure.
 
       // See if we need to transpose the data.
       if (type == "discrete")
@@ -139,89 +231,25 @@ int main(int argc, char** argv)
     }
 
     f.close();
-
-    // Now load labels, if we need to.
-    if (labelsFile != "")
-    {
-      f.open(labelsFile.c_str(), ios_base::in);
-
-      if (!f.is_open())
-        Log::Fatal << "Could not open '" << labelsFile << "' for reading."
-            << endl;
-
-      // Now read each line in.
-      f.getline(lineBuf, 1024, '\n');
-      while (!f.eof())
-      {
-        Log::Info << "Adding training sequence labels from '" << lineBuf
-            << "'." << endl;
-
-        // Now read the matrix.
-        Mat<size_t> label;
-        data::Load(lineBuf, label, true); // Fatal on failure.
-
-        // Ensure that matrix only has one column.
-        if (label.n_rows == 1)
-          label = trans(label);
-
-        if (label.n_cols > 1)
-          Log::Fatal << "Invalid labels; must be one-dimensional." << endl;
-
-        labelSeq.push_back(label.col(0));
-
-        f.getline(lineBuf, 1024, '\n');
-      }
-    }
   }
   else
   {
     // Only one input file.
     trainSeq.resize(1);
     data::Load(inputFile, trainSeq[0], true);
-
-    // Do we need to load labels?
-    if (labelsFile != "")
-    {
-      Mat<size_t> label;
-      data::Load(labelsFile, label, true);
-
-      // Ensure that matrix only has one column.
-      if (label.n_rows == 1)
-        label = trans(label);
-
-      if (label.n_cols > 1)
-        Log::Fatal << "Invalid labels; must be one-dimensional." << endl;
-
-      // Verify the same number of observations as the data.
-      if (label.n_elem != trainSeq[labelSeq.size()].n_cols)
-        Log::Fatal << "Label sequence " << labelSeq.size() << " does not have "
-            << "the same number of points as observation sequence "
-            << labelSeq.size() << "!" << endl;
-
-      labelSeq.push_back(label.col(0));
-    }
   }
 
-  // Now, train the HMM, since we have loaded the input data.
-  if (type == "discrete")
+  // If we have a model file, we can autodetect the type.
+  if (modelFile != "")
   {
-    // Verify observations are valid.
-    for (size_t i = 0; i < trainSeq.size(); ++i)
-      if (trainSeq[i].n_rows > 1)
-        Log::Fatal << "Error in training sequence " << i << ": only "
-            << "one-dimensional discrete observations allowed for discrete "
-            << "HMMs!" << endl;
-
-    // Do we have a model to preload?
-    HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1), tolerance);
+    LoadHMMAndPerformAction<Train>(modelFile, &trainSeq);
+  }
+  else
+  {
+    // We need to read in the type and build the HMM by hand.
+    const string type = CLI::GetParam<string>("type");
 
-    if (modelFile != "")
-    {
-      SaveRestoreUtility loader;
-      loader.ReadFile(modelFile);
-      LoadHMM(hmm, loader);
-    }
-    else // New model.
+    if (type == "discrete")
     {
       // Maximum observation is necessary so we know how to train the discrete
       // distribution.
@@ -238,84 +266,34 @@ int main(int argc, char** argv)
           << endl;
 
       // Create HMM object.
-      hmm = HMM<DiscreteDistribution>(size_t(states),
+      HMM<DiscreteDistribution> hmm(size_t(states),
           DiscreteDistribution(maxEmission), tolerance);
-    }
-
-    // Do we have labels?
-    if (labelsFile == "")
-      hmm.Train(trainSeq); // Unsupervised training.
-    else
-      hmm.Train(trainSeq, labelSeq); // Supervised training.
-
-    // Finally, save the model.  This should later be integrated into the HMM
-    // class itself.
-    SaveRestoreUtility sr;
-    SaveHMM(hmm, sr);
-    sr.WriteFile(outputFile);
-  }
-  else if (type == "gaussian")
-  {
-    // Create HMM object.
-    HMM<GaussianDistribution> hmm(1, GaussianDistribution(1), tolerance);
 
-    // Do we have a model to load?
-    size_t dimensionality = 0;
-    if (modelFile != "")
-    {
-      SaveRestoreUtility loader;
-      loader.ReadFile(modelFile);
-      LoadHMM(hmm, loader);
-
-      dimensionality = hmm.Emission()[0].Mean().n_elem;
+      // Now train it.  Pass the already-loaded training data.
+      Train::Apply(hmm, &trainSeq);
     }
-    else
+    else if (type == "gaussian")
     {
       // Find dimension of the data.
-      dimensionality = trainSeq[0].n_rows;
-
-      hmm = HMM<GaussianDistribution>(size_t(states),
-          GaussianDistribution(dimensionality), tolerance);
-    }
-
-    // Verify dimensionality of data.
-    for (size_t i = 0; i < trainSeq.size(); ++i)
-      if (trainSeq[i].n_rows != dimensionality)
-        Log::Fatal << "Observation sequence " << i << " dimensionality ("
-            << trainSeq[i].n_rows << " is incorrect (should be "
-            << dimensionality << ")!" << endl;
-
-    // Now run the training.
-    if (labelsFile == "")
-      hmm.Train(trainSeq); // Unsupervised training.
-    else
-      hmm.Train(trainSeq, labelSeq); // Supervised training.
+      const size_t dimensionality = trainSeq[0].n_rows;
 
-    // Finally, save the model.  This should later be integrated into th HMM
-    // class itself.
-    SaveRestoreUtility sr;
-    SaveHMM(hmm, sr);
-    sr.WriteFile(outputFile);
-  }
-  else if (type == "gmm")
-  {
-    // Create HMM object.
-    HMM<GMM<> > hmm(1, GMM<>(1, 1));
+      // Verify dimensionality of data.
+      for (size_t i = 0; i < trainSeq.size(); ++i)
+        if (trainSeq[i].n_rows != dimensionality)
+          Log::Fatal << "Observation sequence " << i << " dimensionality ("
+              << trainSeq[i].n_rows << " is incorrect (should be "
+              << dimensionality << ")!" << endl;
 
-    // Do we have a model to load?
-    size_t dimensionality = 0;
-    if (modelFile != "")
-    {
-      SaveRestoreUtility loader;
-      loader.ReadFile(modelFile);
-      LoadHMM(hmm, loader);
+      HMM<GaussianDistribution> hmm(size_t(states),
+          GaussianDistribution(dimensionality), tolerance);
 
-      dimensionality = hmm.Emission()[0].Dimensionality();
+      // Now train it.
+      Train::Apply(hmm, &trainSeq);
     }
-    else
+    else if (type == "gmm")
     {
       // Find dimension of the data.
-      dimensionality = trainSeq[0].n_rows;
+      const size_t dimensionality = trainSeq[0].n_rows;
 
       const int gaussians = CLI::GetParam<int>("gaussians");
 
@@ -327,37 +305,21 @@ int main(int argc, char** argv)
         Log::Fatal << "Invalid number of gaussians (" << gaussians << "); must "
             << "be greater than or equal to 1." << endl;
 
-      hmm = HMM<GMM<> >(size_t(states), GMM<>(size_t(gaussians),
-          dimensionality), tolerance);
-    }
+      // Create HMM object.
+      HMM<GMM<>> hmm(size_t(states), GMM<>(size_t(gaussians), dimensionality),
+          tolerance);
 
-    // Verify dimensionality of data.
-    for (size_t i = 0; i < trainSeq.size(); ++i)
-      if (trainSeq[i].n_rows != dimensionality)
-        Log::Fatal << "Observation sequence " << i << " dimensionality ("
-            << trainSeq[i].n_rows << " is incorrect (should be "
-            << dimensionality << ")!" << endl;
+      // Issue a warning if the user didn't give labels (--labels_file).
+      if (!CLI::HasParam("labels_file"))
+        Log::Warn << "Unlabeled training of GMM HMMs is almost certainly not "
+            << "going to produce good results!" << endl;
 
-    // Now run the training.
-    if (labelsFile == "")
-    {
-      Log::Warn << "Unlabeled training of GMM HMMs is almost certainly not "
-          << "going to produce good results!" << endl;
-      hmm.Train(trainSeq);
+      Train::Apply(hmm, &trainSeq);
     }
     else
     {
-      hmm.Train(trainSeq, labelSeq);
+      Log::Fatal << "Unknown HMM type: " << type << "; must be 'discrete', "
+          << "'gaussian', or 'gmm'." << endl;
     }
-
-    // Save model.
-    SaveRestoreUtility sr;
-    SaveHMM(hmm, sr);
-    sr.WriteFile(outputFile);
-  }
-  else
-  {
-    Log::Fatal << "Unknown HMM type: " << type << "; must be 'discrete', "
-        << "'gaussian', or 'gmm'." << endl;
   }
 }
diff --git a/src/mlpack/methods/hmm/hmm_util.hpp b/src/mlpack/methods/hmm/hmm_util.hpp
new file mode 100644
index 0000000..46ec818
--- /dev/null
+++ b/src/mlpack/methods/hmm/hmm_util.hpp
@@ -0,0 +1,39 @@
+/**
+ * @file hmm_util.hpp
+ * @author Ryan Curtin
+ *
+ * Utilities to load and save HMMs whose emission type is recorded in the file.
+ */
+#ifndef __MLPACK_METHODS_HMM_HMM_UTIL_HPP
+#define __MLPACK_METHODS_HMM_HMM_UTIL_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace hmm {
+
+//! HMMType, to be stored on disk.  This is of type char, which is one byte.
+//! (I'm not sure what will happen on systems where one byte is not eight bits.)
+enum HMMType : char
+{
+  DiscreteHMM = 0,
+  GaussianHMM,
+  GaussianMixtureModelHMM
+};
+
+//! ActionType must implement static void Apply(HMMType&, ExtraInfoType*).
+template<typename ActionType, typename ExtraInfoType = void>
+void LoadHMMAndPerformAction(const std::string& modelFile,
+                             ExtraInfoType* x = NULL);
+
+//! Save an HMM to a file (format chosen by extension), also storing its type.
+//! The type tag lets LoadHMMAndPerformAction() rebuild the right HMM class.
+template<typename HMMType>
+void SaveHMM(HMMType& hmm, const std::string& modelFile);
+
+} // namespace hmm
+} // namespace mlpack
+
+#include "hmm_util_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/hmm/hmm_util_impl.hpp b/src/mlpack/methods/hmm/hmm_util_impl.hpp
new file mode 100644
index 0000000..e1b2de7
--- /dev/null
+++ b/src/mlpack/methods/hmm/hmm_util_impl.hpp
@@ -0,0 +1,176 @@
+/**
+ * @file hmm_util_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of HMM utilities to load arbitrary HMM types.
+ */
+#ifndef __MLPACK_METHODS_HMM_HMM_UTIL_IMPL_HPP
+#define __MLPACK_METHODS_HMM_HMM_UTIL_IMPL_HPP
+
+#include <mlpack/core.hpp>
+
+#include <mlpack/methods/hmm/hmm.hpp>
+#include <mlpack/methods/gmm/gmm.hpp>
+
+namespace mlpack {
+namespace hmm {
+
+// Forward declarations of utility functions.
+
+// Set up the archive for deserialization.
+template<typename ActionType, typename ArchiveType, typename ExtraInfoType>
+void LoadHMMAndPerformActionHelper(const std::string& modelFile,
+                                   ExtraInfoType* x = NULL);
+
+// Actually deserialize into the correct type.
+template<typename ActionType,
+         typename ArchiveType,
+         typename HMMType,
+         typename ExtraInfoType>
+void DeserializeHMMAndPerformAction(ArchiveType& ar, ExtraInfoType* x = NULL);
+
+template<typename ActionType, typename ExtraInfoType>
+void LoadHMMAndPerformAction(const std::string& modelFile,
+                             ExtraInfoType* x)
+{
+  using namespace boost::archive;
+
+  // The archive format is selected by the model file's extension.
+  const std::string extension = data::Extension(modelFile);
+  if (extension == "xml")
+    LoadHMMAndPerformActionHelper<ActionType, xml_iarchive>(modelFile, x);
+  else if (extension == "bin")
+    LoadHMMAndPerformActionHelper<ActionType, binary_iarchive>(modelFile, x);
+  else if (extension == "txt")
+    LoadHMMAndPerformActionHelper<ActionType, text_iarchive>(modelFile, x);
+  else
+    Log::Fatal << "Unknown extension '" << extension << "' for HMM model file "
+        << "(known: 'xml', 'txt', 'bin')." << std::endl;
+}
+
+template<typename ActionType,
+         typename ArchiveType,
+         typename ExtraInfoType>
+void LoadHMMAndPerformActionHelper(const std::string& modelFile,
+                                   ExtraInfoType* x)
+{
+  std::ifstream ifs(modelFile);
+  if (!ifs.is_open())
+    Log::Fatal << "Could not open HMM model file '" << modelFile << "' for "
+        << "reading." << std::endl;
+  ArchiveType ar(ifs);
+
+  // Read in the char that denotes the type of the model.
+  char type;
+  ar >> data::CreateNVP(type, "hmm_type");
+
+  using namespace mlpack::distribution;
+
+  // Every case must break: falling through would deserialize the archive
+  // again as a different HMM type and then always reach the default case.
+  switch (type)
+  {
+    case HMMType::DiscreteHMM:
+      DeserializeHMMAndPerformAction<ActionType, ArchiveType,
+          HMM<DiscreteDistribution>>(ar, x);
+      break;
+
+    case HMMType::GaussianHMM:
+      DeserializeHMMAndPerformAction<ActionType, ArchiveType,
+          HMM<GaussianDistribution>>(ar, x);
+      break;
+
+    case HMMType::GaussianMixtureModelHMM:
+      DeserializeHMMAndPerformAction<ActionType, ArchiveType,
+          HMM<gmm::GMM<>>>(ar, x);
+      break;
+
+    default:
+      Log::Fatal << "Unknown HMM type '" << (unsigned int) type << "'!"
+          << std::endl;
+  }
+}
+
+template<typename ActionType,
+         typename ArchiveType,
+         typename HMMType,
+         typename ExtraInfoType>
+void DeserializeHMMAndPerformAction(ArchiveType& ar, ExtraInfoType* x)
+{
+  // Extract the HMM and perform the action.
+  HMMType hmm;
+  ar >> data::CreateNVP(hmm, "hmm");
+  ActionType::Apply(hmm, x);
+}
+
+// Helper function.
+template<typename ArchiveType, typename HMMType>
+void SaveHMMHelper(HMMType& hmm, const std::string& modelFile);
+
+template<typename HMMType>
+char GetHMMType();
+
+template<typename HMMType>
+void SaveHMM(HMMType& hmm, const std::string& modelFile)
+{
+  using namespace boost::archive;
+
+  const std::string extension = data::Extension(modelFile);
+  if (extension == "xml")
+    SaveHMMHelper<xml_oarchive>(hmm, modelFile);
+  else if (extension == "bin")
+    SaveHMMHelper<binary_oarchive>(hmm, modelFile);
+  else if (extension == "txt")
+    SaveHMMHelper<text_oarchive>(hmm, modelFile);
+  else
+    Log::Fatal << "Unknown extension '" << extension << "' for HMM model file."
+        << std::endl;
+}
+
+template<typename ArchiveType, typename HMMType>
+void SaveHMMHelper(HMMType& hmm, const std::string& modelFile)
+{
+  // Determine the char that denotes the type of the model before touching the
+  // output file, so an unsupported type does not truncate an existing file.
+  char type = GetHMMType<HMMType>();
+  if (type == char(-1))
+    Log::Fatal << "Unknown HMM type given to SaveHMM()!" << std::endl;
+
+  std::ofstream ofs(modelFile);
+  if (!ofs.is_open())
+    Log::Fatal << "Could not open HMM model file '" << modelFile << "' for "
+        << "writing." << std::endl;
+  ArchiveType ar(ofs);
+
+  ar << data::CreateNVP(type, "hmm_type");
+  ar << data::CreateNVP(hmm, "hmm");
+}
+
+// Utility functions to turn a type into something we can store.  The explicit
+// specializations below are ordinary function definitions living in a header,
+// so they are marked inline to avoid multiple-definition link errors.
+template<typename HMMType>
+char GetHMMType() { return char(-1); }
+
+template<>
+inline char GetHMMType<HMM<distribution::DiscreteDistribution>>()
+{
+  return HMMType::DiscreteHMM;
+}
+
+template<>
+inline char GetHMMType<HMM<distribution::GaussianDistribution>>()
+{
+  return HMMType::GaussianHMM;
+}
+
+template<>
+inline char GetHMMType<HMM<gmm::GMM<>>>()
+{
+  return HMMType::GaussianMixtureModelHMM;
+}
+
+} // namespace hmm
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/hmm/hmm_viterbi_main.cpp b/src/mlpack/methods/hmm/hmm_viterbi_main.cpp
index de13bee..34933af 100644
--- a/src/mlpack/methods/hmm/hmm_viterbi_main.cpp
+++ b/src/mlpack/methods/hmm/hmm_viterbi_main.cpp
@@ -19,7 +19,7 @@ PROGRAM_INFO("Hidden Markov Model (HMM) Viterbi State Prediction", "This "
     "is saved to the specified output file (--output_file).");
 
 PARAM_STRING_REQ("input_file", "File containing observations,", "i");
-PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
+PARAM_STRING_REQ("model_file", "File containing HMM.", "m");
 PARAM_STRING("output_file", "File to save predicted state sequence to.", "o",
     "output.csv");
 
@@ -31,60 +31,26 @@ using namespace mlpack::gmm;
 using namespace arma;
 using namespace std;
 
-int main(int argc, char** argv)
+// The HMM type is not known until the model file is read, so Apply() is a
+// template that works for any of the HMM types the file may contain.
+struct Viterbi
 {
-  // Parse command line options.
-  CLI::ParseCommandLine(argc, argv);
-
-  // Load observations.
-  const string inputFile = CLI::GetParam<string>("input_file");
-  const string modelFile = CLI::GetParam<string>("model_file");
-
-  mat dataSeq;
-  data::Load(inputFile, dataSeq, true);
-
-  // Load model, but first we have to determine its type.
-  SaveRestoreUtility sr;
-  sr.ReadFile(modelFile);
-  string type;
-  sr.LoadParameter(type, "hmm_type");
-
-  arma::Col<size_t> sequence;
-  if (type == "discrete")
+  template<typename HMMType>
+  static void Apply(HMMType& hmm, void* /* extraInfo */)
   {
-    HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
-
-    LoadHMM(hmm, sr);
-
-    // Verify only one row in observations.
-    if (dataSeq.n_cols == 1)
-      dataSeq = trans(dataSeq);
-
-    if (dataSeq.n_rows > 1)
-      Log::Fatal << "Only one-dimensional discrete observations allowed for "
-          << "discrete HMMs!" << endl;
-
-    hmm.Predict(dataSeq, sequence);
-  }
-  else if (type == "gaussian")
-  {
-    HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
-
-    LoadHMM(hmm, sr);
+    // Load observations.
+    const string inputFile = CLI::GetParam<string>("input_file");
 
-    // Verify correct dimensionality.
-    if (dataSeq.n_rows != hmm.Emission()[0].Mean().n_elem)
-      Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
-          << "does not match HMM Gaussian dimensionality ("
-          << hmm.Emission()[0].Mean().n_elem << ")!" << endl;
-
-    hmm.Predict(dataSeq, sequence);
-  }
-  else if (type == "gmm")
-  {
-    HMM<GMM<> > hmm(1, GMM<>(1, 1));
+    mat dataSeq;
+    data::Load(inputFile, dataSeq, true);
 
-    LoadHMM(hmm, sr);
+    // See if transposing the data could make it the right dimensionality.
+    if ((dataSeq.n_cols == 1) && (hmm.Emission()[0].Dimensionality() == 1))
+    {
+      Log::Info << "Data sequence appears to be transposed; correcting."
+          << endl;
+      dataSeq = dataSeq.t();
+    }
 
     // Verify correct dimensionality.
     if (dataSeq.n_rows != hmm.Emission()[0].Dimensionality())
@@ -92,15 +58,20 @@ int main(int argc, char** argv)
-          << "does not match HMM Gaussian dimensionality ("
+          << "does not match HMM dimensionality ("
           << hmm.Emission()[0].Dimensionality() << ")!" << endl;
 
+    arma::Col<size_t> sequence;
     hmm.Predict(dataSeq, sequence);
+
+    // Save output.
+    const string outputFile = CLI::GetParam<string>("output_file");
+    data::Save(outputFile, sequence, true);
   }
-  else
-  {
-    Log::Fatal << "Unknown HMM type '" << type << "' in file '" << modelFile
-        << "'!" << endl;
-  }
+};
+
+int main(int argc, char** argv)
+{
+  // Parse command line options.
+  CLI::ParseCommandLine(argc, argv);
 
-  // Save output.
-  const string outputFile = CLI::GetParam<string>("output_file");
-  data::Save(outputFile, sequence, true);
+  const string modelFile = CLI::GetParam<string>("model_file");
+  LoadHMMAndPerformAction<Viterbi>(modelFile);
 }



More information about the mlpack-git mailing list