[mlpack-svn] r10257 - mlpack/trunk/src/mlpack/methods/hmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Nov 12 14:17:11 EST 2011


Author: rcurtin
Date: 2011-11-12 14:17:10 -0500 (Sat, 12 Nov 2011)
New Revision: 10257

Added:
   mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp
   mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
Removed:
   mlpack/trunk/src/mlpack/methods/hmm/discreteDST.cpp
   mlpack/trunk/src/mlpack/methods/hmm/discreteDST.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/hmm/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/hmm/discreteHMM.cpp
Log:
Modify HMM code.  We'll make a new class, HMM<Distribution>, which will serve
the purpose of all three of the previous classes.


Modified: mlpack/trunk/src/mlpack/methods/hmm/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/CMakeLists.txt	2011-11-12 19:16:25 UTC (rev 10256)
+++ mlpack/trunk/src/mlpack/methods/hmm/CMakeLists.txt	2011-11-12 19:17:10 UTC (rev 10257)
@@ -13,6 +13,8 @@
   mixtureDST.cpp
   support.hpp
   support.cpp
+  hmm.hpp
+  hmm_impl.hpp
 )
 
 # Add directory name to sources.

Deleted: mlpack/trunk/src/mlpack/methods/hmm/discreteDST.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/discreteDST.cpp	2011-11-12 19:16:25 UTC (rev 10256)
+++ mlpack/trunk/src/mlpack/methods/hmm/discreteDST.cpp	2011-11-12 19:17:10 UTC (rev 10257)
@@ -1,55 +0,0 @@
-#include <mlpack/core.h>
-#include "discreteDST.hpp"
-#include "support.hpp"
-
-namespace mlpack {
-namespace hmm {
-
-void DiscreteDST::Init(int n) {
-  p.set_size(n);
-  acc_p.set_size(n);
-
-  double s = 1;
-
-  for (int i = 0; i < n - 1; i++) {
-    p[i] = RAND_UNIFORM(s * 0.2, s * 0.8);
-    s -= p[i];
-  }
-
-  p[n - 1] = s;
-}
-
-void DiscreteDST::generate(int* v) {
-  int n = p.length();
-
-  double r = RAND_UNIFORM_01;
-  double s = 0;
-
-  for (int i = 0; i < N; i++) {
-    s += p[i];
-    if (s >= r) {
-      *v = i;
-      return;
-    }
-  }
-
-  *v = n - 1;
-}
-
-void DiscreteDST::end_accumulate() {
-  int n = p.length();
-
-  double s = 0;
-
-  for (int i = 0; i < N; i++)
-    s += acc_p[i];
-
-  if (s == 0)
-    s = -INFINITY;
-
-  for (int i = 0; i < N; i++)
-    p[i] = ACC_p[i] / s;
-}
-
-}; // namespace hmm
-}; // namespace mlpack

Deleted: mlpack/trunk/src/mlpack/methods/hmm/discreteDST.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/discreteDST.hpp	2011-11-12 19:16:25 UTC (rev 10256)
+++ mlpack/trunk/src/mlpack/methods/hmm/discreteDST.hpp	2011-11-12 19:17:10 UTC (rev 10257)
@@ -1,33 +0,0 @@
-#ifndef __MLPACK_METHODS_HMM_DISCRETE_DISTRIBUTION_HPP
-#define __MLPACK_METHODS_HMM_DISCRETE_DISTRIBUTION_HPP
-
-#include <mlpack/core.h>
-
-namespace mlpack {
-namespace hmm {
-
-class DiscreteDST {
- private:
-  arma::vec p;
-  arma::vec acc_p;
-
- public:
-  void Init(int n = 2);
-
-  void generate(int* v);
-
-  double get(int i) { return p[i]; }
-
-  void set(const arma::vec& p_) { p = p; }
-
-  void start_accumulate() { acc_p.zeros(); }
-
-  void accumulate(int i, double v) { acc_p[i] += v; }
-
-  void end_accumulate();
-};
-
-}; // namespace hmm
-}; // namespace mlpack
-
-#endif // __MLPACK_METHODS_HMM_DISCRETE_DISTRIBUTION_HPP

Modified: mlpack/trunk/src/mlpack/methods/hmm/discreteHMM.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/discreteHMM.cpp	2011-11-12 19:16:25 UTC (rev 10256)
+++ mlpack/trunk/src/mlpack/methods/hmm/discreteHMM.cpp	2011-11-12 19:17:10 UTC (rev 10257)
@@ -326,7 +326,7 @@
   for (size_t i = 0; i < M; i++)
     bs(i, L - 1) = 1.0;
 
-  for (size_t t = L - 2; t >= 0; t--) {
+  for (size_t t = L - 2; t + 1 > 0; t--) {
     size_t e = (size_t) seq[t + 1];
     for (size_t i = 0; i < M; i++) {
       for (size_t j = 0; j < M; j++)
@@ -422,7 +422,7 @@
     }
 
   states[L - 1] = bestPtr;
-  for (size_t t = L - 2; t >= 0; t--)
+  for (size_t t = L - 2; t + 1 > 0; t--)
     states[t] = w((size_t) states[t + 1], t + 1);
 
   return bestVal;

Added: mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp	2011-11-12 19:17:10 UTC (rev 10257)
@@ -0,0 +1,182 @@
+/**
+ * @file hmm.hpp
+ * @author Ryan Curtin
+ * @author Tran Quoc Long
+ *
+ * Definition of HMM class.
+ */
+#ifndef __MLPACK_METHODS_HMM_HMM_HPP
+#define __MLPACK_METHODS_HMM_HMM_HPP
+
+#include <mlpack/core.h>
+
+#include <vector>
+
+namespace mlpack {
+namespace hmm {
+
+/**
+ * A class that represents a Hidden Markov Model.
+ *
+ * The model is described by a transition matrix and an emission probability
+ * matrix.  Observations are currently treated as discrete symbols: each
+ * element of a data sequence is cast to an integer and used as a row index
+ * into the emission matrix.
+ *
+ * @tparam Distribution Emission distribution type.  It is not used yet;
+ *     emission handling will eventually need to depend on it.
+ */
+template<typename Distribution>
+class HMM
+{
+ private:
+  //! Transition probability matrix; T(i, j) is the probability of moving to
+  //! state i from state j.
+  arma::mat transition;
+
+  //! Emission probability matrix; E(i, j) is the probability of emission i
+  //! while in state j.
+  arma::mat emission;
+
+ public:
+  /**
+   * Create the Hidden Markov Model with the given number of hidden states and
+   * the given number of emission states.  Both the transition and emission
+   * matrices are initialized to uniform distributions.
+   *
+   * @param states Number of hidden states.
+   * @param emissions Number of possible emissions.
+   */
+  HMM(const size_t states, const size_t emissions);
+
+  /**
+   * Create the Hidden Markov Model with the given transition matrix and the
+   * given emission probability matrix.
+   *
+   * The transition matrix should be such that T(i, j) is the probability of
+   * transition to state i from state j.  The columns of the matrix should sum
+   * to 1.
+   *
+   * The emission matrix should be such that E(i, j) is the probability of
+   * emission i while in state j.  The columns of the matrix should sum to 1.
+   *
+   * @param transition Transition matrix.
+   * @param emission Emission probability matrix.
+   */
+  HMM(const arma::mat& transition, const arma::mat& emission);
+
+  /**
+   * Estimate the transition and emission matrices from observation sequences
+   * alone (Baum-Welch / EM), starting from the current matrices.
+   *
+   * @param data_seq Observation sequences to train on.
+   */
+  void EstimateModel(const std::vector<arma::vec>& data_seq);
+
+  /**
+   * Estimate the transition and emission matrices from observation sequences
+   * whose hidden state sequences are known.
+   *
+   * @param data_seq Observation sequences to train on.
+   * @param state_seq Hidden state sequences; state_seq[i] should label
+   *     data_seq[i] element by element.
+   */
+  void EstimateModel(const std::vector<arma::vec>& data_seq,
+                     const std::vector<arma::vec>& state_seq);
+
+  /**
+   * Generate a random data sequence of the given length.  The data sequence is
+   * stored in the data_sequence parameter, and the state sequence is stored in
+   * the state_sequence parameter.
+   *
+   * @note No definition is provided in hmm_impl.hpp yet, so calling this will
+   *     fail to link.
+   *
+   * @param length Length of random sequence to generate.
+   * @param data_sequence Vector to store data in.
+   * @param state_sequence Vector to store states in.
+   */
+  void GenerateSequence(const size_t length,
+                        arma::vec& data_sequence,
+                        arma::vec& state_sequence) const;
+
+  /**
+   * Estimate the probabilities of each hidden state at each time step for the
+   * given data sequence, using the scaled forward-backward algorithm.
+   *
+   * @param data_seq Observation sequence.
+   * @param state_prob_mat Output; element (i, t) is the probability of being
+   *     in state i at time t given the whole sequence.
+   * @param forward_prob_mat Output; scaled forward probabilities.
+   * @param backward_prob_mat Output; scaled backward probabilities.
+   * @param scale_vec Output; the scaling factor used at each time step.
+   * @return Log-likelihood of the data sequence.
+   */
+  double Estimate(const arma::vec& data_seq,
+                  arma::mat& state_prob_mat,
+                  arma::mat& forward_prob_mat,
+                  arma::mat& backward_prob_mat,
+                  arma::vec& scale_vec) const;
+
+  /**
+   * Compute the log-likelihood of a sequence.
+   *
+   * @note No definition is provided in hmm_impl.hpp yet, so calling this will
+   *     fail to link.
+   *
+   * @param data_seq Observation sequence.
+   * @return Log-likelihood of the sequence.
+   */
+  double LogLikelihood(const arma::vec& data_seq) const;
+
+  /**
+   * Compute the most probable hidden state sequence for a given data sequence
+   * (the Viterbi path).
+   *
+   * @param data_seq Observation sequence.
+   * @param state_seq Output; the most probable state at each time step.
+   * @return Log-likelihood of the most probable state sequence.
+   */
+  double Viterbi(const arma::vec& data_seq, arma::Col<size_t>& state_seq) const;
+
+  //! Return the transition matrix.
+  const arma::mat& Transition() const { return transition; }
+
+  //! Return a modifiable transition matrix reference.
+  arma::mat& Transition() { return transition; }
+
+  //! Return the emission probability matrix.
+  const arma::mat& Emission() const { return emission; }
+
+  //! Return a modifiable emission probability matrix reference.
+  arma::mat& Emission() { return emission; }
+
+ private:
+  /**
+   * Run the scaled forward algorithm on the given sequence.
+   *
+   * @param data_seq Observation sequence.
+   * @param scales Output; the scaling factor used at each time step.
+   * @param fs Output; scaled forward probabilities (states x time).
+   */
+  void Forward(const arma::vec& data_seq, arma::vec& scales, arma::mat& fs)
+      const;
+
+  /**
+   * Run the scaled backward algorithm on the given sequence.
+   *
+   * @param data_seq Observation sequence.
+   * @param scales Scaling factors computed by Forward().
+   * @param bs Output; scaled backward probabilities (states x time).
+   */
+  void Backward(const arma::vec& data_seq, const arma::vec& scales,
+                arma::mat& bs) const;
+};
+
+} // namespace hmm
+} // namespace mlpack
+
+// Include implementation.
+#include "hmm_impl.hpp"
+
+#endif // __MLPACK_METHODS_HMM_HMM_HPP

Added: mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp	2011-11-12 19:17:10 UTC (rev 10257)
@@ -0,0 +1,359 @@
+/**
+ * @file hmm_impl.hpp
+ * @author Ryan Curtin
+ * @author Tran Quoc Long
+ *
+ * Implementation of HMM class.
+ */
+#ifndef __MLPACK_METHODS_HMM_HMM_IMPL_HPP
+#define __MLPACK_METHODS_HMM_HMM_IMPL_HPP
+
+// In case this file is included directly.
+#include "hmm.hpp"
+
+#include <cmath>
+#include <vector>
+
+namespace mlpack {
+namespace hmm {
+
+/**
+ * Create an HMM with uniform transition and emission probabilities.
+ */
+template<typename Distribution>
+HMM<Distribution>::HMM(const size_t states, const size_t emissions) :
+    transition(arma::ones<arma::mat>(states, states) / (double) states),
+    emission(arma::ones<arma::mat>(emissions, states) / (double) emissions)
+{ /* nothing to do */ }
+
+/**
+ * Create an HMM from the given transition and emission matrices.
+ */
+template<typename Distribution>
+HMM<Distribution>::HMM(const arma::mat& transition, const arma::mat& emission) :
+    transition(transition),
+    emission(emission)
+{ /* nothing to do */ }
+
+/**
+ * Estimate the model when only the observations are known, using the
+ * Baum-Welch algorithm (EM for HMM estimation).  The current transition and
+ * emission matrices are used as the initial guess.  This follows the
+ * procedure outlined in Elliot, Aggoun, and Moore's book "Hidden Markov
+ * Models: Estimation and Control", pp. 36-40.
+ */
+template<typename Distribution>
+void HMM<Distribution>::EstimateModel(const std::vector<arma::vec>& data_seq)
+{
+  double loglik = 0;
+  double oldLoglik = 0;
+
+  // Maximum number of iterations.
+  const size_t iterations = 1000;
+
+  for (size_t iter = 0; iter < iterations; iter++)
+  {
+    // Expected transition and emission counts, accumulated over all sequences.
+    arma::mat newTransition(transition.n_rows, transition.n_cols);
+    arma::mat newEmission(emission.n_rows, emission.n_cols);
+    newTransition.zeros();
+    newEmission.zeros();
+
+    // Reset log likelihood.
+    loglik = 0;
+
+    for (size_t seq = 0; seq < data_seq.size(); seq++)
+    {
+      const arma::vec& obs = data_seq[seq];
+
+      // An empty sequence carries no information; skip it.
+      if (obs.n_elem == 0)
+        continue;
+
+      arma::mat stateProb;
+      arma::mat forward;
+      arma::mat backward;
+      arma::vec scales;
+
+      // E-step: run forward-backward and add this sequence's log-likelihood.
+      loglik += Estimate(obs, stateProb, forward, backward, scales);
+
+      // Accumulate expected counts (T_ij = P(to state i | from state j)):
+      //   xi_t(i, j) = f(j, t) T_ij E_i(obs[t + 1]) b(i, t + 1) / c_{t + 1}
+      //   gamma_t(j) = f(j, t) b(j, t)
+      // The multiplication by the old T_ij is postponed until the M-step.
+      for (size_t t = 0; t < obs.n_elem; t++)
+      {
+        if (t + 1 < obs.n_elem)
+        {
+          const size_t next = (size_t) obs[t + 1];
+          newTransition += (backward.col(t + 1) %
+              trans(emission.row(next))) * trans(forward.col(t)) /
+              scales[t + 1];
+        }
+
+        newEmission.row((size_t) obs[t]) += trans(stateProb.col(t));
+      }
+    }
+
+    // M-step.  Finish the postponed multiplication by the old transition
+    // probabilities, then normalize each column into a distribution.  A
+    // column with no expected mass (a state that is never visited) keeps its
+    // previous values rather than becoming 0 / 0 = NaN, which would poison
+    // every later iteration.
+    newTransition %= transition;
+    for (size_t j = 0; j < transition.n_cols; j++)
+    {
+      const double total = accu(newTransition.col(j));
+      if (total > 0)
+        transition.col(j) = newTransition.col(j) / total;
+    }
+
+    for (size_t j = 0; j < emission.n_cols; j++)
+    {
+      const double total = accu(newEmission.col(j));
+      if (total > 0)
+        emission.col(j) = newEmission.col(j) / total;
+    }
+
+    Log::Debug << "Iteration " << iter << ": log-likelihood " << loglik
+        << std::endl;
+
+    if (std::abs(oldLoglik - loglik) < 1e-5)
+    {
+      Log::Debug << "Converged after " << (iter + 1) << " iterations."
+          << std::endl;
+      break;
+    }
+
+    oldLoglik = loglik;
+  }
+}
+
+/**
+ * Estimate the model when both the observations and the hidden states are
+ * known.  This is the maximum-likelihood estimate: transitions between
+ * consecutive labeled states and emissions from each labeled state are
+ * counted, and each column is normalized into a distribution.
+ */
+template<typename Distribution>
+void HMM<Distribution>::EstimateModel(const std::vector<arma::vec>& data_seq,
+                                      const std::vector<arma::vec>& state_seq)
+{
+  if (data_seq.size() != state_seq.size())
+  {
+    Log::Fatal << "HMM::EstimateModel(): " << data_seq.size() << " data "
+        << "sequences but " << state_seq.size() << " state sequences given."
+        << std::endl;
+    return;
+  }
+
+  arma::mat transitionCounts(transition.n_rows, transition.n_cols);
+  arma::mat emissionCounts(emission.n_rows, emission.n_cols);
+  transitionCounts.zeros();
+  emissionCounts.zeros();
+
+  for (size_t seq = 0; seq < data_seq.size(); seq++)
+  {
+    const arma::vec& obs = data_seq[seq];
+    const arma::vec& states = state_seq[seq];
+
+    if (obs.n_elem != states.n_elem)
+    {
+      Log::Fatal << "HMM::EstimateModel(): data sequence " << seq << " has "
+          << obs.n_elem << " elements but its state sequence has "
+          << states.n_elem << "." << std::endl;
+      return;
+    }
+
+    for (size_t t = 0; t < obs.n_elem; t++)
+    {
+      const size_t state = (size_t) states[t];
+      emissionCounts((size_t) obs[t], state) += 1;
+
+      // The first state has no labeled predecessor, so only transitions
+      // between consecutive labeled states are counted.
+      if (t > 0)
+        transitionCounts(state, (size_t) states[t - 1]) += 1;
+    }
+  }
+
+  // Normalize each column into a distribution.  States that never appear
+  // (as a predecessor, or at all) keep their previous values rather than
+  // becoming 0 / 0 = NaN.
+  for (size_t j = 0; j < transition.n_cols; j++)
+  {
+    const double total = accu(transitionCounts.col(j));
+    if (total > 0)
+      transition.col(j) = transitionCounts.col(j) / total;
+  }
+
+  for (size_t j = 0; j < emission.n_cols; j++)
+  {
+    const double total = accu(emissionCounts.col(j));
+    if (total > 0)
+      emission.col(j) = emissionCounts.col(j) / total;
+  }
+}
+
+/**
+ * Estimate the probabilities of each hidden state at each time step for each
+ * given data observation.  Returns the log-likelihood of the sequence.
+ */
+template<typename Distribution>
+double HMM<Distribution>::Estimate(const arma::vec& data_seq,
+                                   arma::mat& state_prob_mat,
+                                   arma::mat& forward_prob_mat,
+                                   arma::mat& backward_prob_mat,
+                                   arma::vec& scale_vec) const
+{
+  // First run the forward-backward algorithm.
+  Forward(data_seq, scale_vec, forward_prob_mat);
+  Backward(data_seq, scale_vec, backward_prob_mat);
+
+  // With the scaling used by Forward() and Backward(), the element-wise
+  // product of the forward and backward matrices is exactly the posterior
+  // probability of each state at each time step.
+  state_prob_mat = forward_prob_mat % backward_prob_mat;
+
+  // The log-likelihood is the sum of the log scaling factors.
+  return accu(log(scale_vec));
+}
+
+/**
+ * Compute the most probable hidden state sequence for the given observation
+ * sequence with the Viterbi algorithm.  Returns the log-likelihood of that
+ * state sequence together with the observations.
+ */
+template<typename Distribution>
+double HMM<Distribution>::Viterbi(const arma::vec& dataSeq,
+                                  arma::Col<size_t>& stateSeq) const
+{
+  const size_t numStates = transition.n_rows;
+  const size_t length = dataSeq.n_elem;
+  stateSeq.set_size(length);
+
+  // The empty sequence has probability 1.
+  if (length == 0)
+    return 0.0;
+
+  // Work in log space: a product of many probabilities underflows to zero for
+  // long sequences, after which every path would look equally likely.
+  // log(0) = -inf is handled correctly by the comparisons below.
+  const arma::mat logTransition = arma::log(transition);
+  const arma::mat logEmission = arma::log(emission);
+
+  // logProb(j, t) is the log-probability of the best path ending in state j
+  // at time t; backPointer[t * numStates + j] is the state before it.
+  arma::mat logProb(numStates, length);
+  std::vector<size_t> backPointer(numStates * length, 0);
+
+  // As in Forward(), the state before t = 0 is assumed to be state 0.
+  const size_t firstObs = (size_t) dataSeq[0];
+  for (size_t j = 0; j < numStates; j++)
+    logProb(j, 0) = logTransition(j, 0) + logEmission(firstObs, j);
+
+  for (size_t t = 1; t < length; t++)
+  {
+    const size_t obs = (size_t) dataSeq[t];
+    for (size_t j = 0; j < numStates; j++)
+    {
+      // Find the most probable predecessor of state j.
+      size_t bestPrev = 0;
+      double bestScore = logProb(0, t - 1) + logTransition(j, 0);
+      for (size_t k = 1; k < numStates; k++)
+      {
+        const double score = logProb(k, t - 1) + logTransition(j, k);
+        if (score > bestScore)
+        {
+          bestScore = score;
+          bestPrev = k;
+        }
+      }
+
+      logProb(j, t) = bestScore + logEmission(obs, j);
+      backPointer[t * numStates + j] = bestPrev;
+    }
+  }
+
+  // Choose the most probable final state, then follow the back-pointers.
+  size_t best = 0;
+  for (size_t j = 1; j < numStates; j++)
+    if (logProb(j, length - 1) > logProb(best, length - 1))
+      best = j;
+
+  stateSeq[length - 1] = best;
+  for (size_t t = length - 1; t > 0; t--)
+    stateSeq[t - 1] = backPointer[t * numStates + stateSeq[t]];
+
+  return logProb(best, length - 1);
+}
+
+/**
+ * Run the scaled forward algorithm.  Column t of forwardProbabilities holds
+ * P(X_t | o_{0:t}) for every state, and scales[t] is the normalizer used for
+ * that column.
+ */
+template<typename Distribution>
+void HMM<Distribution>::Forward(const arma::vec& dataSeq,
+                                arma::vec& scales,
+                                arma::mat& forwardProbabilities) const
+{
+  forwardProbabilities.zeros(transition.n_rows, dataSeq.n_elem);
+  scales.zeros(dataSeq.n_elem);
+
+  if (dataSeq.n_elem == 0)
+    return;
+
+  // Starting state (at t = -1) is assumed to be state 0.  This is what MATLAB
+  // does in their hmmdecode() function, so we will emulate that behavior.
+  forwardProbabilities.col(0) = transition.col(0) %
+      trans(emission.row((size_t) dataSeq[0]));
+  scales[0] = accu(forwardProbabilities.col(0));
+  forwardProbabilities.col(0) /= scales[0];
+
+  // Each later column is the previous one propagated through the transition
+  // matrix, weighted by each state's probability of emitting the current
+  // observation, and then normalized.
+  for (size_t t = 1; t < dataSeq.n_elem; t++)
+  {
+    forwardProbabilities.col(t) = (transition * forwardProbabilities.col(t - 1))
+        % trans(emission.row((size_t) dataSeq[t]));
+
+    scales[t] = accu(forwardProbabilities.col(t));
+    forwardProbabilities.col(t) /= scales[t];
+  }
+}
+
+/**
+ * Run the scaled backward algorithm.  Column t of backwardProbabilities holds
+ * P(o_{t+1:T} | X_t) for every state, divided by the forward scaling factors
+ * of the later time steps so that it stays in a representable range.
+ */
+template<typename Distribution>
+void HMM<Distribution>::Backward(const arma::vec& dataSeq,
+                                 const arma::vec& scales,
+                                 arma::mat& backwardProbabilities) const
+{
+  backwardProbabilities.zeros(transition.n_rows, dataSeq.n_elem);
+
+  if (dataSeq.n_elem == 0)
+    return;
+
+  // The last column is 1 for every state.
+  backwardProbabilities.col(dataSeq.n_elem - 1).fill(1);
+
+  // Step backwards: b(j, t - 1) = sum_i T(i, j) E_i(o_t) b(i, t) / scales[t].
+  for (size_t t = dataSeq.n_elem - 1; t > 0; t--)
+  {
+    // This will eventually need to depend on the distribution.
+    const size_t obs = (size_t) dataSeq[t];
+    backwardProbabilities.col(t - 1) = trans(transition) *
+        (backwardProbabilities.col(t) % trans(emission.row(obs))) / scales[t];
+  }
+}
+
+} // namespace hmm
+} // namespace mlpack
+
+#endif // __MLPACK_METHODS_HMM_HMM_IMPL_HPP




More information about the mlpack-svn mailing list