[mlpack-git] master: Manually calculate starting state probability. (0c945f3)
gitdub at mlpack.org
gitdub at mlpack.org
Fri Apr 15 08:58:17 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/2e33c5c42cfba6a31b733260c8e92086000e4d2c...0c945f3d36d92cc49dd2fde536530b40d3129d0a
>---------------------------------------------------------------
commit 0c945f3d36d92cc49dd2fde536530b40d3129d0a
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Apr 15 08:57:59 2016 -0400
Manually calculate starting state probability.
>---------------------------------------------------------------
0c945f3d36d92cc49dd2fde536530b40d3129d0a
src/mlpack/tests/hmm_test.cpp | 14 ++++++++++++--
1 file changed, 12 insertions(+), 2 deletions(-)
diff --git a/src/mlpack/tests/hmm_test.cpp b/src/mlpack/tests/hmm_test.cpp
index 65ecf26..762090f 100644
--- a/src/mlpack/tests/hmm_test.cpp
+++ b/src/mlpack/tests/hmm_test.cpp
@@ -236,6 +236,7 @@ BOOST_AUTO_TEST_CASE(SimpleBaumWelchDiscreteHMM_2)
std::vector<arma::mat> observations;
size_t obsNum = 250; // Number of observations.
size_t obsLen = 500; // Number of elements in each observation.
+ size_t stateZeroStarts = 0; // Number of times we start in state 0.
for (size_t i = 0; i < obsNum; i++)
{
arma::mat observation(1, obsLen);
@@ -249,9 +250,15 @@ BOOST_AUTO_TEST_CASE(SimpleBaumWelchDiscreteHMM_2)
double r = math::Random();
if (r <= 0.5)
+ {
+ if (obs == 0)
+ ++stateZeroStarts;
state = 0;
+ }
else
+ {
state = 1;
+ }
// Now set the observation.
r = math::Random();
@@ -281,9 +288,12 @@ BOOST_AUTO_TEST_CASE(SimpleBaumWelchDiscreteHMM_2)
hmm.Train(observations);
+ // Calculate true probability of class 0 at the start.
+ double prob = double(stateZeroStarts) / observations.size();
+
// Only require 2.5% tolerance, because this is a little fuzzier.
- BOOST_REQUIRE_CLOSE(hmm.Initial()[0], 0.5, 2.5);
- BOOST_REQUIRE_CLOSE(hmm.Initial()[1], 0.5, 2.5);
+ BOOST_REQUIRE_CLOSE(hmm.Initial()[0], prob, 2.5);
+ BOOST_REQUIRE_CLOSE(hmm.Initial()[1], 1.0 - prob, 2.5);
BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 0), 0.5, 2.5);
BOOST_REQUIRE_CLOSE(hmm.Transition()(1, 0), 0.5, 2.5);
More information about the mlpack-git
mailing list