[mlpack-git] master: Better handling of NaNs. (2744628)

gitdub at mlpack.org gitdub at mlpack.org
Sun Jun 5 11:08:09 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1f562a1aba7ae55475afcc95659511c2b7f694e5...5b8fdce471328f722fcd8c0f22a6d995ce22c98b

>---------------------------------------------------------------

commit 274462840d9a62193ed796de75aa0d1a64550236
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon May 16 14:58:05 2016 -0400

    Better handling of NaNs.


>---------------------------------------------------------------

274462840d9a62193ed796de75aa0d1a64550236
 src/mlpack/methods/hmm/hmm_impl.hpp | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/mlpack/methods/hmm/hmm_impl.hpp b/src/mlpack/methods/hmm/hmm_impl.hpp
index b567f0f..2bfc4cb 100644
--- a/src/mlpack/methods/hmm/hmm_impl.hpp
+++ b/src/mlpack/methods/hmm/hmm_impl.hpp
@@ -173,7 +173,13 @@ void HMM<Distribution>::Train(const std::vector<arma::mat>& dataSeq)
 
     // Now we normalize the transition matrix.
     for (size_t i = 0; i < transition.n_cols; i++)
-      transition.col(i) /= accu(transition.col(i));
+    {
+      const double sum = accu(transition.col(i));
+      if (sum > 0.0)
+        transition.col(i) /= sum;
+      else
+        transition.col(i).fill(1.0 / (double) transition.n_rows);
+    }
 
     // Now estimate emission probabilities.
     for (size_t state = 0; state < transition.n_cols; state++)
@@ -513,7 +519,8 @@ void HMM<Distribution>::Forward(const arma::mat& dataSeq,
 
   // Then normalize the column.
   scales[0] = accu(forwardProb.col(0));
-  forwardProb.col(0) /= scales[0];
+  if (scales[0] > 0.0)
+    forwardProb.col(0) /= scales[0];
 
   // Now compute the probabilities for each successive observation.
   for (size_t t = 1; t < dataSeq.n_cols; t++)
@@ -530,7 +537,8 @@ void HMM<Distribution>::Forward(const arma::mat& dataSeq,
 
     // Normalize probability.
     scales[t] = accu(forwardProb.col(t));
-    forwardProb.col(t) /= scales[t];
+    if (scales[t] > 0.0)
+      forwardProb.col(t) /= scales[t];
   }
 }
 
@@ -560,7 +568,8 @@ void HMM<Distribution>::Backward(const arma::mat& dataSeq,
             * emission[state].Probability(dataSeq.unsafe_col(t + 1));
 
       // Normalize by the weights from the forward algorithm.
-      backwardProb(j, t) /= scales[t + 1];
+      if (scales[t + 1] > 0.0)
+        backwardProb(j, t) /= scales[t + 1];
     }
   }
 }




More information about the mlpack-git mailing list