[mlpack-svn] r16199 - mlpack/trunk/src/mlpack/methods/hmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Feb 1 18:44:45 EST 2014
Author: michaelfox99
Date: Sat Feb 1 18:44:45 2014
New Revision: 16199
Log:
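Fixed ticket #316 (hmm::Predict() in hmm_impl.hpp): the most probable state sequence is now recovered by storing the best predecessor of each state at every time step and backtracking from the best final state, rather than taking the per-step maximum. Also removes the zero-scale checks when normalizing the forward and backward probabilities.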
Modified:
mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp Sat Feb 1 18:44:45 2014
@@ -4,6 +4,21 @@
* @author Tran Quoc Long
*
* Implementation of HMM class.
+ *
+ * This file is part of MLPACK 1.0.7.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef __MLPACK_METHODS_HMM_HMM_IMPL_HPP
#define __MLPACK_METHODS_HMM_HMM_IMPL_HPP
@@ -308,7 +323,7 @@
// Set vectors to the right size.
stateSequence.set_size(length);
dataSequence.set_size(dimensionality, length);
-
+
// Set start state (default is 0).
stateSequence[0] = startState;
@@ -358,6 +373,7 @@
// calculate the log-likelihood at the end of it all.
stateSeq.set_size(dataSeq.n_cols);
arma::mat logStateProb(transition.n_rows, dataSeq.n_cols);
+ arma::mat stateSeqBack(transition.n_rows, dataSeq.n_cols);
// Store the logs of the transposed transition matrix. This is because we
// will be using the rows of the transition matrix.
@@ -368,31 +384,35 @@
// came to be j from another state.
logStateProb.col(0).zeros();
for (size_t state = 0; state < transition.n_rows; state++)
+ {
logStateProb[state] = log(transition(state, 0) *
emission[state].Probability(dataSeq.unsafe_col(0)));
+ stateSeqBack[state] = state;
+ }
// Store the best first state.
arma::uword index;
logStateProb.unsafe_col(0).max(index);
- stateSeq[0] = index;
-
for (size_t t = 1; t < dataSeq.n_cols; t++)
{
// Assemble the state probability for this element.
- // Given that we are in state j, we state with the highest probability of
- // being the previous state.
+ // Given that we are in state j, we use state with the highest probability
+ // of being the previous state.
for (size_t j = 0; j < transition.n_rows; j++)
{
arma::vec prob = logStateProb.col(t - 1) + logTrans.col(j);
logStateProb(j, t) = prob.max() +
log(emission[j].Probability(dataSeq.unsafe_col(t)));
+ prob.max(index);
+ stateSeqBack(j, t) = index;
}
-
- // Store the best state.
- logStateProb.unsafe_col(t).max(index);
- stateSeq[t] = index;
}
-
+ // Backtrack to find most probable state sequence
+ logStateProb.unsafe_col(dataSeq.n_cols-1).max(index);
+ stateSeq[dataSeq.n_cols-1] = index;
+ for (size_t t = 2; t <= dataSeq.n_cols; t++)
+ stateSeq[dataSeq.n_cols-t] = stateSeqBack(stateSeq[dataSeq.n_cols-t+1], dataSeq.n_cols-t+1);
+
return logStateProb(stateSeq(dataSeq.n_cols - 1), dataSeq.n_cols - 1);
}
@@ -430,10 +450,9 @@
forwardProb(state, 0) = transition(state, 0) *
emission[state].Probability(dataSeq.unsafe_col(0));
- // Then normalize the column, but only if the scale is not 0.
+ // Then normalize the column.
scales[0] = accu(forwardProb.col(0));
- if (scales[0] != 0.0)
- forwardProb.col(0) /= scales[0];
+ forwardProb.col(0) /= scales[0];
// Now compute the probabilities for each successive observation.
for (size_t t = 1; t < dataSeq.n_cols; t++)
@@ -448,10 +467,9 @@
emission[j].Probability(dataSeq.unsafe_col(t));
}
- // Normalize probability, but only if the scale is not 0.
+ // Normalize probability.
scales[t] = accu(forwardProb.col(t));
- if (scales[t] != 0.0)
- forwardProb.col(t) /= scales[t];
+ forwardProb.col(t) /= scales[t];
}
}
@@ -480,10 +498,8 @@
backwardProb(j, t) += transition(state, j) * backwardProb(state, t + 1)
* emission[state].Probability(dataSeq.unsafe_col(t + 1));
- // Normalize by the weights from the forward algorithm, if the scale is
- // not 0.
- if (scales[t + 1] != 0.0)
- backwardProb(j, t) /= scales[t + 1];
+ // Normalize by the weights from the forward algorithm.
+ backwardProb(j, t) /= scales[t + 1];
}
}
}
More information about the mlpack-svn mailing list