[mlpack-git] master: Better handling of NaNs. (2744628)
gitdub at mlpack.org
gitdub at mlpack.org
Sun Jun 5 11:08:09 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1f562a1aba7ae55475afcc95659511c2b7f694e5...5b8fdce471328f722fcd8c0f22a6d995ce22c98b
>---------------------------------------------------------------
commit 274462840d9a62193ed796de75aa0d1a64550236
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon May 16 14:58:05 2016 -0400
Better handling of NaNs.
>---------------------------------------------------------------
274462840d9a62193ed796de75aa0d1a64550236
src/mlpack/methods/hmm/hmm_impl.hpp | 17 +++++++++++++----
1 file changed, 13 insertions(+), 4 deletions(-)
diff --git a/src/mlpack/methods/hmm/hmm_impl.hpp b/src/mlpack/methods/hmm/hmm_impl.hpp
index b567f0f..2bfc4cb 100644
--- a/src/mlpack/methods/hmm/hmm_impl.hpp
+++ b/src/mlpack/methods/hmm/hmm_impl.hpp
@@ -173,7 +173,13 @@ void HMM<Distribution>::Train(const std::vector<arma::mat>& dataSeq)
// Now we normalize the transition matrix.
for (size_t i = 0; i < transition.n_cols; i++)
- transition.col(i) /= accu(transition.col(i));
+ {
+ const double sum = accu(transition.col(i));
+ if (sum > 0.0)
+ transition.col(i) /= sum;
+ else
+ transition.col(i).fill(1.0 / (double) transition.n_rows);
+ }
// Now estimate emission probabilities.
for (size_t state = 0; state < transition.n_cols; state++)
@@ -513,7 +519,8 @@ void HMM<Distribution>::Forward(const arma::mat& dataSeq,
// Then normalize the column.
scales[0] = accu(forwardProb.col(0));
- forwardProb.col(0) /= scales[0];
+ if (scales[0] > 0.0)
+ forwardProb.col(0) /= scales[0];
// Now compute the probabilities for each successive observation.
for (size_t t = 1; t < dataSeq.n_cols; t++)
@@ -530,7 +537,8 @@ void HMM<Distribution>::Forward(const arma::mat& dataSeq,
// Normalize probability.
scales[t] = accu(forwardProb.col(t));
- forwardProb.col(t) /= scales[t];
+ if (scales[t] > 0.0)
+ forwardProb.col(t) /= scales[t];
}
}
@@ -560,7 +568,8 @@ void HMM<Distribution>::Backward(const arma::mat& dataSeq,
* emission[state].Probability(dataSeq.unsafe_col(t + 1));
// Normalize by the weights from the forward algorithm.
- backwardProb(j, t) /= scales[t + 1];
+ if (scales[t + 1] > 0.0)
+ backwardProb(j, t) /= scales[t + 1];
}
}
}
More information about the mlpack-git mailing list