[mlpack-svn] r10337 - mlpack/trunk/src/mlpack/methods/hmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Nov 21 10:48:22 EST 2011


Author: rcurtin
Date: 2011-11-21 10:48:21 -0500 (Mon, 21 Nov 2011)
New Revision: 10337

Modified:
   mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
Log:
#162: Use log-likelihoods instead of direct probabilities, because direct
probabilities get too small (they underflow on long sequences).


Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp	2011-11-21 04:54:58 UTC (rev 10336)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp	2011-11-21 15:48:21 UTC (rev 10337)
@@ -283,19 +283,23 @@
   // don't use log-likelihoods to save that little bit of time, but we'll
   // calculate the log-likelihood at the end of it all.
   stateSeq.resize(dataSeq.size());
-  arma::mat stateProb(transition.n_rows, dataSeq.size());
+  arma::mat logStateProb(transition.n_rows, dataSeq.size());
 
+  // Store the logs of the transposed transition matrix.  This is because we
+  // will be using the rows of the transition matrix.
+  arma::mat logTrans(log(trans(transition)));
+
   // The calculation of the first state is slightly different; the probability
   // of the first state being state j is the maximum probability that the state
   // came to be j from another state.
-  stateProb.col(0).zeros();
+  logStateProb.col(0).zeros();
   for (size_t state = 0; state < transition.n_rows; state++)
-    stateProb[state] = transition(state, 0) *
-        emission[state].Probability(dataSeq[0]);
+    logStateProb[state] = log(transition(state, 0) *
+        emission[state].Probability(dataSeq[0]));
 
   // Store the best first state.
   arma::u32 index;
-  stateProb.unsafe_col(0).max(index);
+  logStateProb.unsafe_col(0).max(index);
   stateSeq[0] = index;
 
   for (size_t t = 1; t < dataSeq.size(); t++)
@@ -305,16 +309,17 @@
     // being the previous state.
     for (size_t j = 0; j < transition.n_rows; j++)
     {
-      arma::vec prob = stateProb.col(t - 1) % trans(transition.row(j));
-      stateProb(j, t) = prob.max() * emission[j].Probability(dataSeq[t]);
+      arma::vec prob = logStateProb.col(t - 1) + logTrans.col(j);
+      logStateProb(j, t) = prob.max() +
+          log(emission[j].Probability(dataSeq[t]));
     }
 
     // Store the best state.
-    stateProb.unsafe_col(t).max(index);
+    logStateProb.unsafe_col(t).max(index);
     stateSeq[t] = index;
   }
 
-  return log(stateProb(stateSeq[dataSeq.size() - 1], dataSeq.size() - 1));
+  return logStateProb(stateSeq[dataSeq.size() - 1], dataSeq.size() - 1);
 }
 
 /**




More information about the mlpack-svn mailing list