[mlpack-svn] r10337 - mlpack/trunk/src/mlpack/methods/hmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Nov 21 10:48:22 EST 2011
Author: rcurtin
Date: 2011-11-21 10:48:21 -0500 (Mon, 21 Nov 2011)
New Revision: 10337
Modified:
mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
Log:
#162: use log-likelihoods instead of direct probabilities, because products of
direct probabilities become too small and underflow to zero for long sequences.
Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp 2011-11-21 04:54:58 UTC (rev 10336)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp 2011-11-21 15:48:21 UTC (rev 10337)
@@ -283,19 +283,23 @@
// don't use log-likelihoods to save that little bit of time, but we'll
// calculate the log-likelihood at the end of it all.
stateSeq.resize(dataSeq.size());
- arma::mat stateProb(transition.n_rows, dataSeq.size());
+ arma::mat logStateProb(transition.n_rows, dataSeq.size());
+ // Store the logs of the transposed transition matrix. This is because we
+ // will be using the rows of the transition matrix.
+ arma::mat logTrans(log(trans(transition)));
+
// The calculation of the first state is slightly different; the probability
// of the first state being state j is the maximum probability that the state
// came to be j from another state.
- stateProb.col(0).zeros();
+ logStateProb.col(0).zeros();
for (size_t state = 0; state < transition.n_rows; state++)
- stateProb[state] = transition(state, 0) *
- emission[state].Probability(dataSeq[0]);
+ logStateProb[state] = log(transition(state, 0) *
+ emission[state].Probability(dataSeq[0]));
// Store the best first state.
arma::u32 index;
- stateProb.unsafe_col(0).max(index);
+ logStateProb.unsafe_col(0).max(index);
stateSeq[0] = index;
for (size_t t = 1; t < dataSeq.size(); t++)
@@ -305,16 +309,17 @@
// being the previous state.
for (size_t j = 0; j < transition.n_rows; j++)
{
- arma::vec prob = stateProb.col(t - 1) % trans(transition.row(j));
- stateProb(j, t) = prob.max() * emission[j].Probability(dataSeq[t]);
+ arma::vec prob = logStateProb.col(t - 1) + logTrans.col(j);
+ logStateProb(j, t) = prob.max() +
+ log(emission[j].Probability(dataSeq[t]));
}
// Store the best state.
- stateProb.unsafe_col(t).max(index);
+ logStateProb.unsafe_col(t).max(index);
stateSeq[t] = index;
}
- return log(stateProb(stateSeq[dataSeq.size() - 1], dataSeq.size() - 1));
+ return logStateProb(stateSeq[dataSeq.size() - 1], dataSeq.size() - 1);
}
/**
More information about the mlpack-svn
mailing list