[mlpack-svn] r16199 - mlpack/trunk/src/mlpack/methods/hmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Feb 1 18:44:45 EST 2014


Author: michaelfox99
Date: Sat Feb  1 18:44:45 2014
New Revision: 16199

Log:
fixed ticket #316 (hmm::Predict() in hmm_impl.hpp)

Modified:
   mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp	Sat Feb  1 18:44:45 2014
@@ -4,6 +4,21 @@
  * @author Tran Quoc Long
  *
  * Implementation of HMM class.
+ *
+ * This file is part of MLPACK 1.0.7.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK.  If not, see <http://www.gnu.org/licenses/>.
  */
 #ifndef __MLPACK_METHODS_HMM_HMM_IMPL_HPP
 #define __MLPACK_METHODS_HMM_HMM_IMPL_HPP
@@ -308,7 +323,7 @@
   // Set vectors to the right size.
   stateSequence.set_size(length);
   dataSequence.set_size(dimensionality, length);
-
+    
   // Set start state (default is 0).
   stateSequence[0] = startState;
 
@@ -358,6 +373,7 @@
   // calculate the log-likelihood at the end of it all.
   stateSeq.set_size(dataSeq.n_cols);
   arma::mat logStateProb(transition.n_rows, dataSeq.n_cols);
+  arma::mat stateSeqBack(transition.n_rows, dataSeq.n_cols);
 
   // Store the logs of the transposed transition matrix.  This is because we
   // will be using the rows of the transition matrix.
@@ -368,31 +384,35 @@
   // came to be j from another state.
   logStateProb.col(0).zeros();
   for (size_t state = 0; state < transition.n_rows; state++)
+  {
     logStateProb[state] = log(transition(state, 0) *
         emission[state].Probability(dataSeq.unsafe_col(0)));
+    stateSeqBack[state] = state;
+  }
 
   // Store the best first state.
   arma::uword index;
   logStateProb.unsafe_col(0).max(index);
-  stateSeq[0] = index;
-
   for (size_t t = 1; t < dataSeq.n_cols; t++)
   {
     // Assemble the state probability for this element.
-    // Given that we are in state j, we state with the highest probability of
-    // being the previous state.
+    // Given that we are in state j, we use state with the highest probability
+    // of being the previous state.
     for (size_t j = 0; j < transition.n_rows; j++)
     {
       arma::vec prob = logStateProb.col(t - 1) + logTrans.col(j);
       logStateProb(j, t) = prob.max() +
           log(emission[j].Probability(dataSeq.unsafe_col(t)));
+      prob.max(index);
+      stateSeqBack(j, t) = index;
     }
-
-    // Store the best state.
-    logStateProb.unsafe_col(t).max(index);
-    stateSeq[t] = index;
   }
-
+  // Backtrack to find most probable state sequence
+  logStateProb.unsafe_col(dataSeq.n_cols-1).max(index);
+  stateSeq[dataSeq.n_cols-1] = index;
+  for (size_t t = 2; t <= dataSeq.n_cols; t++)
+    stateSeq[dataSeq.n_cols-t] = stateSeqBack(stateSeq[dataSeq.n_cols-t+1], dataSeq.n_cols-t+1);
+    
   return logStateProb(stateSeq(dataSeq.n_cols - 1), dataSeq.n_cols - 1);
 }
 
@@ -430,10 +450,9 @@
     forwardProb(state, 0) = transition(state, 0) *
         emission[state].Probability(dataSeq.unsafe_col(0));
 
-  // Then normalize the column, but only if the scale is not 0.
+  // Then normalize the column.
   scales[0] = accu(forwardProb.col(0));
-  if (scales[0] != 0.0)
-    forwardProb.col(0) /= scales[0];
+  forwardProb.col(0) /= scales[0];
 
   // Now compute the probabilities for each successive observation.
   for (size_t t = 1; t < dataSeq.n_cols; t++)
@@ -448,10 +467,9 @@
           emission[j].Probability(dataSeq.unsafe_col(t));
     }
 
-    // Normalize probability, but only if the scale is not 0.
+    // Normalize probability.
     scales[t] = accu(forwardProb.col(t));
-    if (scales[t] != 0.0)
-      forwardProb.col(t) /= scales[t];
+    forwardProb.col(t) /= scales[t];
   }
 }
 
@@ -480,10 +498,8 @@
         backwardProb(j, t) += transition(state, j) * backwardProb(state, t + 1)
             * emission[state].Probability(dataSeq.unsafe_col(t + 1));
 
-      // Normalize by the weights from the forward algorithm, if the scale is
-      // not 0.
-      if (scales[t + 1] != 0.0)
-        backwardProb(j, t) /= scales[t + 1];
+      // Normalize by the weights from the forward algorithm.
+      backwardProb(j, t) /= scales[t + 1];
     }
   }
 }



More information about the mlpack-svn mailing list