[mlpack-git] master, mlpack-1.0.x: fixed ticket #316 (hmm::Predict() in hmm_impl.hpp) (df1f46e)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:41:40 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit df1f46e6f9f747aeeee859a7d9a0f2f8da18d784
Author: michaelfox99 <michaelfox99 at gmail.com>
Date:   Sat Feb 1 23:44:45 2014 +0000

    fixed ticket #316 (hmm::Predict() in hmm_impl.hpp)


>---------------------------------------------------------------

df1f46e6f9f747aeeee859a7d9a0f2f8da18d784
 src/mlpack/methods/hmm/hmm_impl.hpp | 52 ++++++++++++++++++++++++-------------
 1 file changed, 34 insertions(+), 18 deletions(-)

diff --git a/src/mlpack/methods/hmm/hmm_impl.hpp b/src/mlpack/methods/hmm/hmm_impl.hpp
index c2ccd4b..877b58f 100644
--- a/src/mlpack/methods/hmm/hmm_impl.hpp
+++ b/src/mlpack/methods/hmm/hmm_impl.hpp
@@ -4,6 +4,21 @@
  * @author Tran Quoc Long
  *
  * Implementation of HMM class.
+ *
+ * This file is part of MLPACK 1.0.7.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK.  If not, see <http://www.gnu.org/licenses/>.
  */
 #ifndef __MLPACK_METHODS_HMM_HMM_IMPL_HPP
 #define __MLPACK_METHODS_HMM_HMM_IMPL_HPP
@@ -358,6 +373,7 @@ double HMM<Distribution>::Predict(const arma::mat& dataSeq,
   // calculate the log-likelihood at the end of it all.
   stateSeq.set_size(dataSeq.n_cols);
   arma::mat logStateProb(transition.n_rows, dataSeq.n_cols);
+  arma::mat stateSeqBack(transition.n_rows, dataSeq.n_cols);
 
   // Store the logs of the transposed transition matrix.  This is because we
   // will be using the rows of the transition matrix.
@@ -368,30 +384,34 @@ double HMM<Distribution>::Predict(const arma::mat& dataSeq,
   // came to be j from another state.
   logStateProb.col(0).zeros();
   for (size_t state = 0; state < transition.n_rows; state++)
+  {
     logStateProb[state] = log(transition(state, 0) *
         emission[state].Probability(dataSeq.unsafe_col(0)));
+    stateSeqBack[state] = state;
+  }
 
   // Store the best first state.
   arma::uword index;
   logStateProb.unsafe_col(0).max(index);
-  stateSeq[0] = index;
-
   for (size_t t = 1; t < dataSeq.n_cols; t++)
   {
     // Assemble the state probability for this element.
-    // Given that we are in state j, we state with the highest probability of
-    // being the previous state.
+    // Given that we are in state j, we use state with the highest probability
+    // of being the previous state.
     for (size_t j = 0; j < transition.n_rows; j++)
     {
       arma::vec prob = logStateProb.col(t - 1) + logTrans.col(j);
       logStateProb(j, t) = prob.max() +
           log(emission[j].Probability(dataSeq.unsafe_col(t)));
+      prob.max(index);
+      stateSeqBack(j, t) = index;
     }
-
-    // Store the best state.
-    logStateProb.unsafe_col(t).max(index);
-    stateSeq[t] = index;
   }
+  // Backtrack to find most probable state sequence
+  logStateProb.unsafe_col(dataSeq.n_cols-1).max(index);
+  stateSeq[dataSeq.n_cols-1] = index;
+  for (size_t t = 2; t <= dataSeq.n_cols; t++)
+    stateSeq[dataSeq.n_cols-t] = stateSeqBack(stateSeq[dataSeq.n_cols-t+1], dataSeq.n_cols-t+1);
     
   return logStateProb(stateSeq(dataSeq.n_cols - 1), dataSeq.n_cols - 1);
 }
@@ -430,10 +450,9 @@ void HMM<Distribution>::Forward(const arma::mat& dataSeq,
     forwardProb(state, 0) = transition(state, 0) *
         emission[state].Probability(dataSeq.unsafe_col(0));
 
-  // Then normalize the column, but only if the scale is not 0.
+  // Then normalize the column.
   scales[0] = accu(forwardProb.col(0));
-  if (scales[0] != 0.0)
-    forwardProb.col(0) /= scales[0];
+  forwardProb.col(0) /= scales[0];
 
   // Now compute the probabilities for each successive observation.
   for (size_t t = 1; t < dataSeq.n_cols; t++)
@@ -448,10 +467,9 @@ void HMM<Distribution>::Forward(const arma::mat& dataSeq,
           emission[j].Probability(dataSeq.unsafe_col(t));
     }
 
-    // Normalize probability, but only if the scale is not 0.
+    // Normalize probability.
     scales[t] = accu(forwardProb.col(t));
-    if (scales[t] != 0.0)
-      forwardProb.col(t) /= scales[t];
+    forwardProb.col(t) /= scales[t];
   }
 }
 
@@ -480,10 +498,8 @@ void HMM<Distribution>::Backward(const arma::mat& dataSeq,
         backwardProb(j, t) += transition(state, j) * backwardProb(state, t + 1)
             * emission[state].Probability(dataSeq.unsafe_col(t + 1));
 
-      // Normalize by the weights from the forward algorithm, if the scale is
-      // not 0.
-      if (scales[t + 1] != 0.0)
-        backwardProb(j, t) /= scales[t + 1];
+      // Normalize by the weights from the forward algorithm.
+      backwardProb(j, t) /= scales[t + 1];
     }
   }
 }



More information about the mlpack-git mailing list