[mlpack-git] master, mlpack-1.0.x: fix for HMMTest/SimpleDiscreteHMMTestViterbi (dd58e5a)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:42:46 EST 2015
Repository : https://github.com/mlpack/mlpack
On branches: master,mlpack-1.0.x
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit dd58e5ab89240eff57174e417045b16728dc4b4c
Author: michaelfox99 <michaelfox99 at gmail.com>
Date: Mon Feb 10 03:50:53 2014 +0000
fix for HMMTest/SimpleDiscreteHMMTestViterbi
>---------------------------------------------------------------
dd58e5ab89240eff57174e417045b16728dc4b4c
src/mlpack/methods/hmm/hmm_impl.hpp | 8 +++----
src/mlpack/tests/hmm_test.cpp | 44 ++++++++++++++++++++++++-------------
2 files changed, 32 insertions(+), 20 deletions(-)
diff --git a/src/mlpack/methods/hmm/hmm_impl.hpp b/src/mlpack/methods/hmm/hmm_impl.hpp
index b45e503..5ee9bb2 100644
--- a/src/mlpack/methods/hmm/hmm_impl.hpp
+++ b/src/mlpack/methods/hmm/hmm_impl.hpp
@@ -370,14 +370,13 @@ double HMM<Distribution>::Predict(const arma::mat& dataSeq,
logStateProb.col(0).zeros();
for (size_t state = 0; state < transition.n_rows; state++)
{
- logStateProb[state] = log(transition(state, 0) *
+ logStateProb[state] = log(transition.unsafe_col(state).max() *
emission[state].Probability(dataSeq.unsafe_col(0)));
stateSeqBack[state] = state;
}
// Store the best first state.
arma::uword index;
- logStateProb.unsafe_col(0).max(index);
for (size_t t = 1; t < dataSeq.n_cols; t++)
{
// Assemble the state probability for this element.
@@ -386,10 +385,9 @@ double HMM<Distribution>::Predict(const arma::mat& dataSeq,
for (size_t j = 0; j < transition.n_rows; j++)
{
arma::vec prob = logStateProb.col(t - 1) + logTrans.col(j);
- logStateProb(j, t) = prob.max() +
+ logStateProb(j, t) = prob.max(index) +
log(emission[j].Probability(dataSeq.unsafe_col(t)));
- prob.max(index);
- stateSeqBack(j, t) = index;
+ stateSeqBack(j, t) = index;
}
}
// Backtrack to find most probable state sequence
diff --git a/src/mlpack/tests/hmm_test.cpp b/src/mlpack/tests/hmm_test.cpp
index eea6894..878c7ab 100644
--- a/src/mlpack/tests/hmm_test.cpp
+++ b/src/mlpack/tests/hmm_test.cpp
@@ -2,6 +2,21 @@
* @file hmm_test.cpp
*
* Test file for HMMs.
+ *
+ * This file is part of MLPACK 1.0.7.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
*/
#include <mlpack/core.hpp>
#include <mlpack/methods/hmm/hmm.hpp>
@@ -35,8 +50,8 @@ BOOST_AUTO_TEST_CASE(SimpleDiscreteHMMTestViterbi)
// [0.1 0.8]] no umbrella
arma::mat transition("0.7 0.3; 0.3 0.7");
std::vector<DiscreteDistribution> emission(2);
- emission[0] = DiscreteDistribution("0.9 0.2");
- emission[1] = DiscreteDistribution("0.1 0.8");
+ emission[0] = DiscreteDistribution("0.9 0.1");
+ emission[1] = DiscreteDistribution("0.2 0.8");
HMM<DiscreteDistribution> hmm(transition, emission);
@@ -273,15 +288,14 @@ BOOST_AUTO_TEST_CASE(SimpleBaumWelchDiscreteHMM_2)
BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 1), 0.5, 2.5);
BOOST_REQUIRE_CLOSE(hmm.Transition()(1, 1), 0.5, 2.5);
- // Widened to 3% tolerance.
- BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("0"), 0.4, 3.0);
- BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("1"), 0.6, 3.0);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Probability("2"), 3.0);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Probability("3"), 3.0);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Probability("0"), 3.0);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Probability("1"), 3.0);
- BOOST_REQUIRE_CLOSE(hmm.Emission()[1].Probability("2"), 0.2, 3.0);
- BOOST_REQUIRE_CLOSE(hmm.Emission()[1].Probability("3"), 0.8, 3.0);
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("0"), 0.4, 2.5);
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("1"), 0.6, 2.5);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Probability("2"), 2.5);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Probability("3"), 2.5);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Probability("0"), 2.5);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Probability("1"), 2.5);
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[1].Probability("2"), 0.2, 2.5);
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[1].Probability("3"), 0.8, 2.5);
}
BOOST_AUTO_TEST_CASE(DiscreteHMMLabeledTrainTest)
@@ -351,11 +365,11 @@ BOOST_AUTO_TEST_CASE(DiscreteHMMLabeledTrainTest)
// We can't use % tolerance here because percent error increases as the actual
// value gets very small. So, instead, we just ensure that every value is no
- // more than 0.015 away from the actual value.
+ // more than 0.009 away from the actual value.
for (size_t row = 0; row < hmm.Transition().n_rows; row++)
for (size_t col = 0; col < hmm.Transition().n_cols; col++)
BOOST_REQUIRE_SMALL(hmm.Transition()(row, col) - transition(row, col),
- 0.015);
+ 0.009);
for (size_t col = 0; col < hmm.Emission().size(); col++)
{
@@ -365,7 +379,7 @@ BOOST_AUTO_TEST_CASE(DiscreteHMMLabeledTrainTest)
arma::vec obs(1);
obs[0] = row;
BOOST_REQUIRE_SMALL(hmm.Emission()[col].Probability(obs) -
- emission[col].Probability(obs), 0.015);
+ emission[col].Probability(obs), 0.009);
}
}
}
@@ -709,7 +723,7 @@ BOOST_AUTO_TEST_CASE(GaussianHMMGenerateTest)
for (size_t row = 0; row < 3; row++)
for (size_t col = 0; col < 3; col++)
BOOST_REQUIRE_SMALL(hmm.Transition()(row, col) - hmm2.Transition()(row,
- col), 0.032);
+ col), 0.03);
// Check that each Gaussian is the same.
for (size_t em = 0; em < 3; em++)
More information about the mlpack-git
mailing list