[mlpack-git] master: Manually calculate starting state probability. (0c945f3)

gitdub at mlpack.org gitdub at mlpack.org
Fri Apr 15 08:58:17 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/2e33c5c42cfba6a31b733260c8e92086000e4d2c...0c945f3d36d92cc49dd2fde536530b40d3129d0a

>---------------------------------------------------------------

commit 0c945f3d36d92cc49dd2fde536530b40d3129d0a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Apr 15 08:57:59 2016 -0400

    Manually calculate starting state probability.


>---------------------------------------------------------------

0c945f3d36d92cc49dd2fde536530b40d3129d0a
 src/mlpack/tests/hmm_test.cpp | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/tests/hmm_test.cpp b/src/mlpack/tests/hmm_test.cpp
index 65ecf26..762090f 100644
--- a/src/mlpack/tests/hmm_test.cpp
+++ b/src/mlpack/tests/hmm_test.cpp
@@ -236,6 +236,7 @@ BOOST_AUTO_TEST_CASE(SimpleBaumWelchDiscreteHMM_2)
   std::vector<arma::mat> observations;
   size_t obsNum = 250; // Number of observations.
   size_t obsLen = 500; // Number of elements in each observation.
+  size_t stateZeroStarts = 0; // Number of times we start in state 0.
   for (size_t i = 0; i < obsNum; i++)
   {
     arma::mat observation(1, obsLen);
@@ -249,9 +250,15 @@ BOOST_AUTO_TEST_CASE(SimpleBaumWelchDiscreteHMM_2)
       double r = math::Random();
 
       if (r <= 0.5)
+      {
+        if (obs == 0)
+          ++stateZeroStarts;
         state = 0;
+      }
       else
+      {
         state = 1;
+      }
 
       // Now set the observation.
       r = math::Random();
@@ -281,9 +288,12 @@ BOOST_AUTO_TEST_CASE(SimpleBaumWelchDiscreteHMM_2)
 
   hmm.Train(observations);
 
+  // Calculate true probability of class 0 at the start.
+  double prob = double(stateZeroStarts) / observations.size();
+
   // Only require 2.5% tolerance, because this is a little fuzzier.
-  BOOST_REQUIRE_CLOSE(hmm.Initial()[0], 0.5, 2.5);
-  BOOST_REQUIRE_CLOSE(hmm.Initial()[1], 0.5, 2.5);
+  BOOST_REQUIRE_CLOSE(hmm.Initial()[0], prob, 2.5);
+  BOOST_REQUIRE_CLOSE(hmm.Initial()[1], 1.0 - prob, 2.5);
 
   BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 0), 0.5, 2.5);
   BOOST_REQUIRE_CLOSE(hmm.Transition()(1, 0), 0.5, 2.5);




More information about the mlpack-git mailing list