[mlpack-git] master: Add --random_initialization for mlpack_hmm_train. (3e4f3ca)
gitdub at mlpack.org
Mon May 16 15:11:56 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1f562a1aba7ae55475afcc95659511c2b7f694e5...986620375ce84cdc75fdfd99f63f17b5c8ee507a
>---------------------------------------------------------------
commit 3e4f3cade8a59f1a49c07e08c39fe6af33e2da06
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon May 16 15:11:56 2016 -0400
Add --random_initialization for mlpack_hmm_train.
>---------------------------------------------------------------
3e4f3cade8a59f1a49c07e08c39fe6af33e2da06
HISTORY.md | 3 ++
src/mlpack/methods/hmm/hmm_train_main.cpp | 64 ++++++++++++++++++++++++++++++-
2 files changed, 66 insertions(+), 1 deletion(-)
diff --git a/HISTORY.md b/HISTORY.md
index ab7a041..cddb28a 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -22,6 +22,9 @@
mlpack_allknn and mlpack_allkfn programs will remain as copies until mlpack
3.0.0.
+ * Add --random_initialization option to mlpack_hmm_train, for use when no
+ labels are provided.
+
### mlpack 2.0.1
###### 2016-02-04
* Fix CMake to properly detect when MKL is being used with Armadillo.
diff --git a/src/mlpack/methods/hmm/hmm_train_main.cpp b/src/mlpack/methods/hmm/hmm_train_main.cpp
index e7820a6..546fb3c 100644
--- a/src/mlpack/methods/hmm/hmm_train_main.cpp
+++ b/src/mlpack/methods/hmm/hmm_train_main.cpp
@@ -24,7 +24,9 @@ PROGRAM_INFO("Hidden Markov Model (HMM) Training", "This program allows a "
"\n\n"
"The HMM is trained with the Baum-Welch algorithm if no labels are "
"provided. The tolerance of the Baum-Welch algorithm can be set with the "
- "--tolerance option."
+ "--tolerance option. In general it is a good idea to use random "
+ "initialization in this case, which can be specified with the "
+ "--random_initialization (-r) option."
"\n\n"
"Optionally, a pre-created HMM model can be used as a guess for the "
"transition matrix and emission probabilities; this is specifiable with "
@@ -47,6 +49,8 @@ PARAM_STRING("output_model_file", "File to save trained HMM to.", "o",
"output_hmm.xml");
PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
PARAM_DOUBLE("tolerance", "Tolerance of the Baum-Welch algorithm.", "T", 1e-5);
+PARAM_FLAG("random_initialization", "Initialize emissions and transition "
+ "matrices with a uniform random distribution.", "r");
using namespace mlpack;
using namespace mlpack::hmm;
@@ -296,6 +300,21 @@ int main(int argc, char** argv)
HMM<DiscreteDistribution> hmm(size_t(states),
DiscreteDistribution(maxEmission), tolerance);
+ // Initialize with random starting point.
+ if (CLI::HasParam("random_initialization"))
+ {
+ hmm.Transition().randu();
+ for (size_t c = 0; c < hmm.Transition().n_cols; ++c)
+ hmm.Transition().col(c) /= arma::accu(hmm.Transition().col(c));
+
+ for (size_t e = 0; e < hmm.Emission().size(); ++e)
+ {
+ hmm.Emission()[e].Probabilities().randu();
+ hmm.Emission()[e].Probabilities() /=
+ arma::accu(hmm.Emission()[e].Probabilities());
+ }
+ }
+
// Now train it. Pass the already-loaded training data.
Train::Apply(hmm, &trainSeq);
}
@@ -314,6 +333,22 @@ int main(int argc, char** argv)
HMM<GaussianDistribution> hmm(size_t(states),
GaussianDistribution(dimensionality), tolerance);
+ // Initialize with random starting point.
+ if (CLI::HasParam("random_initialization"))
+ {
+ hmm.Transition().randu();
+ for (size_t c = 0; c < hmm.Transition().n_cols; ++c)
+ hmm.Transition().col(c) /= arma::accu(hmm.Transition().col(c));
+
+ for (size_t e = 0; e < hmm.Emission().size(); ++e)
+ {
+ hmm.Emission()[e].Mean().randu();
+ // Generate random covariance.
+ arma::mat r = arma::randu<arma::mat>(dimensionality, dimensionality);
+ hmm.Emission()[e].Covariance(r * r.t());
+ }
+ }
+
// Now train it.
Train::Apply(hmm, &trainSeq);
}
@@ -336,6 +371,33 @@ int main(int argc, char** argv)
HMM<GMM> hmm(size_t(states), GMM(size_t(gaussians), dimensionality),
tolerance);
+ // Initialize with random starting point.
+ if (CLI::HasParam("random_initialization"))
+ {
+ hmm.Transition().randu();
+ for (size_t c = 0; c < hmm.Transition().n_cols; ++c)
+ hmm.Transition().col(c) /= arma::accu(hmm.Transition().col(c));
+
+ for (size_t e = 0; e < hmm.Emission().size(); ++e)
+ {
+ // Random weights.
+ hmm.Emission()[e].Weights().randu();
+ hmm.Emission()[e].Weights() /=
+ arma::accu(hmm.Emission()[e].Weights());
+
+ // Random means and covariances.
+ for (int g = 0; g < gaussians; ++g)
+ {
+ hmm.Emission()[e].Component(g).Mean().randu();
+
+ // Generate random covariance.
+ arma::mat r = arma::randu<arma::mat>(dimensionality,
+ dimensionality);
+ hmm.Emission()[e].Component(g).Covariance(r * r.t());
+ }
+ }
+ }
+
// Issue a warning if the user didn't give labels.
if (!CLI::HasParam("labels_file"))
Log::Warn << "Unlabeled training of GMM HMMs is almost certainly not "
More information about the mlpack-git
mailing list