[mlpack-svn] r10261 - mlpack/trunk/src/mlpack/methods/hmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Nov 12 19:29:57 EST 2011
Author: rcurtin
Date: 2011-11-12 19:29:57 -0500 (Sat, 12 Nov 2011)
New Revision: 10261
Modified:
mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp
mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
Log:
Change the names of a few methods, and implement supervised training/estimation.
Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp 2011-11-12 22:54:22 UTC (rev 10260)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp 2011-11-13 00:29:57 UTC (rev 10261)
@@ -53,13 +53,26 @@
HMM(const arma::mat& transition, const arma::mat& emission);
/**
- * Estimate the transition and emission matrices.
+ * Train the model using the Baum-Welch algorithm, with only the given
+ * unlabeled observations. Instead of giving a guess transition and emission
+ * matrix here, do that in the constructor.
+ *
+ * @param dataSeq Vector of observation sequences.
*/
- void EstimateModel(const std::vector<arma::vec>& data_seq);
- void EstimateModel(const std::vector<arma::vec>& data_seq,
- const std::vector<arma::vec>& state_seq);
+ void Train(const std::vector<arma::vec>& dataSeq);
/**
+ * Train the model using the given labeled observations; the transition and
+ * emission matrices are directly estimated.
+ *
+ * @param dataSeq Vector of observation sequences.
+ * @param stateSeq Vector of state sequences, corresponding to each
+ * observation.
+ */
+ void Train(const std::vector<arma::vec>& dataSeq,
+ const std::vector<arma::Col<size_t> >& stateSeq);
+
+ /**
* Generate a random data sequence of the given length. The data sequence is
* stored in the data_sequence parameter, and the state sequence is stored in
* the state_sequence parameter.
@@ -88,10 +101,16 @@
double LogLikelihood(const arma::vec& data_seq) const;
/**
- * Compute the most probable hidden state sequence for a given data sequence.
- * Needs a better name.
+ * Compute the most probable hidden state sequence for the given data
+ * sequence, using the Viterbi algorithm, returning the log-likelihood of the
+ * most likely state sequence.
+ *
+ * @param dataSeq Sequence of observations.
+ * @param stateSeq Vector in which the most probable state sequence will be
+ * stored.
+ * @return Log-likelihood of most probable state sequence.
*/
- double Viterbi(const arma::vec& data_seq, arma::Col<size_t>& state_seq) const;
+ double Predict(const arma::vec& data_seq, arma::Col<size_t>& stateSeq) const;
/**
* Return the transition matrix.
Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp 2011-11-12 22:54:22 UTC (rev 10260)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp 2011-11-13 00:29:57 UTC (rev 10261)
@@ -14,21 +14,32 @@
namespace mlpack {
namespace hmm {
+/**
+ * Create the Hidden Markov Model with the given number of hidden states and the
+ * given number of emission states.
+ */
template<typename Distribution>
HMM<Distribution>::HMM(const size_t states, const size_t emissions) :
transition(arma::ones<arma::mat>(states, states) / (double) states),
emission(arma::ones<arma::mat>(emissions, states) / (double) emissions)
{ /* nothing to do */ }
+/**
+ * Create the Hidden Markov Model with the given transition matrix and the given
+ * emission probability matrix.
+ */
template<typename Distribution>
HMM<Distribution>::HMM(const arma::mat& transition, const arma::mat& emission) :
transition(transition),
emission(emission)
{ /* nothing to do */ }
-// Generate the model if we only know the observations.
+/**
+ * Train the model using the Baum-Welch algorithm, with only the given unlabeled
+ * observations.
+ */
template<typename Distribution>
-void HMM<Distribution>::EstimateModel(const std::vector<arma::vec>& data_seq)
+void HMM<Distribution>::Train(const std::vector<arma::vec>& dataSeq)
{
// We should allow a guess at the transition and emission matrices.
@@ -53,7 +64,7 @@
loglik = 0;
// Loop over each sequence.
- for (size_t seq = 0; seq < data_seq.size(); seq++)
+ for (size_t seq = 0; seq < dataSeq.size(); seq++)
{
arma::mat stateProb;
arma::mat forward;
@@ -61,28 +72,28 @@
arma::vec scales;
// Add the log-likelihood of this sequence. This is the E-step.
- loglik += Estimate(data_seq[seq], stateProb, forward, backward, scales);
+ loglik += Estimate(dataSeq[seq], stateProb, forward, backward, scales);
// Now re-estimate the parameters. This is the M-step.
// T_ij = sum_d ((1 / P(seq[d])) sum_t (f(i, t) T_ij E_i(seq[d][t]) b(i,
// t + 1)))
// E_ij = sum_d ((1 / P(seq[d])) sum_{t | seq[d][t] = j} f(i, t) b(i, t)
// We store the new estimates in a different matrix.
- for (size_t t = 0; t < data_seq[seq].n_elem; t++)
+ for (size_t t = 0; t < dataSeq[seq].n_elem; t++)
{
for (size_t j = 0; j < transition.n_cols; j++)
{
- if (t < data_seq[seq].n_elem - 1)
+ if (t < dataSeq[seq].n_elem - 1)
{
// Estimate of T_ij (probability of transition from state j to state
// i). We postpone multiplication of the old T_ij until later.
for (size_t i = 0; i < transition.n_rows; i++)
newTransition(i, j) += forward(j, t) * backward(i, t + 1) *
- emission((size_t) data_seq[seq][t + 1], i) / scales[t + 1];
+ emission((size_t) dataSeq[seq][t + 1], i) / scales[t + 1];
}
// Estimate of E_ij (probability of emission i while in state j).
- newEmission((size_t) data_seq[seq][t], j) += stateProb(j, t);
+ newEmission((size_t) dataSeq[seq][t], j) += stateProb(j, t);
}
}
}
@@ -113,12 +124,57 @@
}
}
-// Generate the model.
+/**
+ * Train the model using the given labeled observations; the transition and
+ * emission matrices are directly estimated.
+ */
template<typename Distribution>
-void HMM<Distribution>::EstimateModel(const std::vector<arma::vec>& data_seq,
- const std::vector<arma::vec>& state_seq)
+void HMM<Distribution>::Train(const std::vector<arma::vec>& dataSeq,
+ const std::vector<arma::Col<size_t> >& stateSeq)
{
+ // Simple error checking.
+ if (dataSeq.size() != stateSeq.size())
+ Log::Fatal << "HMM::Train(): number of data sequences not equal to number "
+ "of state sequences." << std::endl;
+ transition.zeros();
+ emission.zeros();
+
+ // Estimate the transition and emission matrices directly from the
+ // observations.
+ for (size_t seq = 0; seq < dataSeq.size(); seq++)
+ {
+ // Simple error checking.
+ if (dataSeq[seq].n_elem != stateSeq[seq].n_elem)
+ Log::Fatal << "HMM::Train(): number of observations in sequence " << seq
+ << " not equal to number of states" << std::endl;
+
+ // Loop over each observation in the sequence. For estimation of the
+ // transition matrix, we must ignore the last observation.
+ for (size_t t = 0; t < dataSeq[seq].n_elem - 1; t++)
+ {
+ transition(stateSeq[seq][t + 1], stateSeq[seq][t])++;
+ emission(dataSeq[seq][t], stateSeq[seq][t])++;
+ }
+
+ // Last observation.
+ emission(dataSeq[seq][dataSeq[seq].n_elem - 1],
+ stateSeq[seq][stateSeq[seq].n_elem - 1])++;
+ }
+
+ // Normalize transition matrix and emission matrix.
+ for (size_t col = 0; col < transition.n_cols; col++)
+ {
+ // If the transition probability sum is greater than 0 in this column, the
+ // emission probability sum will also be greater than 0. We want to avoid
+ // division by 0.
+ double sum = accu(transition.col(col));
+ if (sum > 0)
+ {
+ transition.col(col) /= sum;
+ emission.col(col) /= accu(emission.col(col));
+ }
+ }
}
/**
@@ -145,11 +201,12 @@
}
/**
- * Compute the most probable hidden state sequence for the given observation.
- * Returns the log-likelihood of the most likely sequence.
+ * Compute the most probable hidden state sequence for the given observation
+ * using the Viterbi algorithm. Returns the log-likelihood of the most likely
+ * sequence.
*/
template<typename Distribution>
-double HMM<Distribution>::Viterbi(const arma::vec& dataSeq,
+double HMM<Distribution>::Predict(const arma::vec& dataSeq,
arma::Col<size_t>& stateSeq) const
{
// This is an implementation of the Viterbi algorithm for finding the most
More information about the mlpack-svn
mailing list