[mlpack-svn] r14907 - mlpack/trunk/src/mlpack/methods/hmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Apr 15 17:16:50 EDT 2013


Author: rcurtin
Date: 2013-04-15 17:16:50 -0400 (Mon, 15 Apr 2013)
New Revision: 14907

Modified:
   mlpack/trunk/src/mlpack/methods/hmm/hmm_train_main.cpp
Log:
Add a --tolerance option (short form -T, default 1e-5) to hmm_train_main.cpp for setting the tolerance of the Baum-Welch algorithm.


Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm_train_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_train_main.cpp	2013-04-15 21:16:37 UTC (rev 14906)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_train_main.cpp	2013-04-15 21:16:50 UTC (rev 14907)
@@ -14,14 +14,18 @@
 PROGRAM_INFO("Hidden Markov Model (HMM) Training", "This program allows a "
     "Hidden Markov Model to be trained on labeled or unlabeled data.  It "
     "support three types of HMMs: discrete HMMs, Gaussian HMMs, or GMM HMMs."
-    "\n"
+    "\n\n"
     "Either one input sequence can be specified (with --input_file), or, a "
     "file containing files in which input sequences can be found (when "
     "--input_file and --batch are used together).  In addition, labels can be "
     "provided in the file specified by --label_file, and if --batch is used, "
     "the file given to --label_file should contain a list of files of labels "
-    "corresponding to the sequences in the file given to --input_file.\n"
-    "\n"
+    "corresponding to the sequences in the file given to --input_file."
+    "\n\n"
+    "The HMM is trained with the Baum-Welch algorithm if no labels are "
+    "provided.  The tolerance of the Baum-Welch algorithm can be set with the "
+    "--tolerance option."
+    "\n\n"
     "Optionally, a pre-created HMM model can be used as a guess for the "
     "transition matrix and emission probabilities; this is specifiable with "
     "--model_file.");
@@ -42,6 +46,7 @@
 PARAM_STRING("output_file", "File to save trained HMM to (XML).", "o",
     "output_hmm.xml");
 PARAM_INT("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
+PARAM_DOUBLE("tolerance", "Tolerance of the Baum-Welch algorithm.", "T", 1e-5);
 
 using namespace mlpack;
 using namespace mlpack::hmm;
@@ -71,6 +76,7 @@
   const string type = CLI::GetParam<string>("type");
   const int states = CLI::GetParam<int>("states");
   const bool batch = CLI::HasParam("batch");
+  const double tolerance = CLI::GetParam<double>("tolerance");
 
   // Validate number of states.
   if (states == 0 && modelFile == "")
@@ -207,7 +213,7 @@
             << "HMMs!" << endl;
 
     // Do we have a model to preload?
-    HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
+    HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1), tolerance);
 
     if (modelFile != "")
     {
@@ -233,7 +239,7 @@
 
       // Create HMM object.
       hmm = HMM<DiscreteDistribution>(size_t(states),
-          DiscreteDistribution(maxEmission));
+          DiscreteDistribution(maxEmission), tolerance);
     }
 
     // Do we have labels?
@@ -251,7 +257,7 @@
   else if (type == "gaussian")
   {
     // Create HMM object.
-    HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
+    HMM<GaussianDistribution> hmm(1, GaussianDistribution(1), tolerance);
 
     // Do we have a model to load?
     size_t dimensionality = 0;
@@ -269,7 +275,7 @@
       dimensionality = trainSeq[0].n_rows;
 
       hmm = HMM<GaussianDistribution>(size_t(states),
-          GaussianDistribution(dimensionality));
+          GaussianDistribution(dimensionality), tolerance);
     }
 
     // Verify dimensionality of data.
@@ -322,7 +328,7 @@
             << "be greater than or equal to 1." << endl;
 
       hmm = HMM<GMM<> >(size_t(states), GMM<>(size_t(gaussians),
-          dimensionality));
+          dimensionality), tolerance);
     }
 
     // Verify dimensionality of data.




More information about the mlpack-svn mailing list