[mlpack-svn] r10262 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Nov 12 19:30:10 EST 2011


Author: rcurtin
Date: 2011-11-12 19:30:10 -0500 (Sat, 12 Nov 2011)
New Revision: 10262

Modified:
   mlpack/trunk/src/mlpack/tests/hmm_test.cpp
Log:
Add test for supervised transition and emission matrix estimation.


Modified: mlpack/trunk/src/mlpack/tests/hmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/hmm_test.cpp	2011-11-13 00:29:57 UTC (rev 10261)
+++ mlpack/trunk/src/mlpack/tests/hmm_test.cpp	2011-11-13 00:30:10 UTC (rev 10262)
@@ -40,7 +40,7 @@
   arma::vec observation("0 0 1 0 0");
 
   arma::Col<size_t> states;
-  hmm.Viterbi(observation, states);
+  hmm.Predict(observation, states);
 
   // Check each state.
   BOOST_REQUIRE_EQUAL(states[0], 0); // Rain.
@@ -73,7 +73,7 @@
   arma::vec observation("2 2 1 0 1 3 2 0 0");
 
   arma::Col<size_t> states;
-  hmm.Viterbi(observation, states);
+  hmm.Predict(observation, states);
 
   // Most probable path is HHHLLLLLL.
   BOOST_REQUIRE_EQUAL(states[0], 1);
@@ -151,7 +151,7 @@
   observations.push_back("0 0 0 0 0 0 0 0 0 0 0 0");
   observations.push_back("0 0 0 0 0 0 0 0 0 0");
 
-  hmm.EstimateModel(observations);
+  hmm.Train(observations);
 
   BOOST_REQUIRE_CLOSE(hmm.Emission()(0, 0), 1.0, 1e-5);
   BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 0), 1.0, 1e-5);
@@ -189,7 +189,7 @@
   observations.push_back("0 0 1 1 0 0 0 0 1 1 1 1");
   observations.push_back("1 1 1 0 0 0 1 1 1 0 0 0");
 
-  hmm.EstimateModel(observations);
+  hmm.Train(observations);
 
   BOOST_REQUIRE_CLOSE(hmm.Emission()(0, 0), 0.5, 1e-5);
   BOOST_REQUIRE_CLOSE(hmm.Emission()(1, 0), 0.5, 1e-5);
@@ -272,7 +272,7 @@
     out.col(i) = observations[i];
   data::Save("out.csv", out);
 
-  hmm.EstimateModel(observations);
+  hmm.Train(observations);
 
   // Only require 0.75% tolerance, because this is a little fuzzier.
   BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 0), 0.5, 0.75);
@@ -290,4 +290,100 @@
   BOOST_REQUIRE_CLOSE(hmm.Emission()(3, 1), 0.8, 0.75);
 }
 
+/**
+ * Draw a random index i with probability probabilities(i, col), using rand().
+ * The column is assumed to already be normalized to sum to 1 (the caller below
+ * normalizes every column before sampling).
+ *
+ * @param probabilities Matrix whose columns are discrete distributions.
+ * @param col Column to sample from.
+ * @return Sampled row index, always in [0, probabilities.n_rows).
+ */
+static size_t SampleFromColumn(const arma::mat& probabilities, const size_t col)
+{
+  // Divide by RAND_MAX + 1 so the draw lies in [0, 1).  Dividing by RAND_MAX
+  // alone can yield exactly 1.0, which no cumulative sum is guaranteed to
+  // exceed, and the caller would then be left with an unset entry.
+  const double draw = (double) rand() / ((double) RAND_MAX + 1.0);
+
+  double cumulative = 0;
+  for (size_t i = 0; i < probabilities.n_rows; i++)
+  {
+    cumulative += probabilities(i, col);
+    if (cumulative > draw)
+      return i;
+  }
+
+  // Floating-point rounding can leave the cumulative sum marginally below the
+  // draw; fall back to the last index instead of returning nothing useful.
+  return probabilities.n_rows - 1;
+}
+
+BOOST_AUTO_TEST_CASE(DiscreteHMMLabeledTrainTest)
+{
+  // Generate a random Markov model with 3 hidden states and 6 observations.
+  arma::mat transition;
+  arma::mat emission;
+
+  transition.randu(3, 3);
+  emission.randu(6, 3);
+
+  // Normalize so they are correct transition and emission matrices (each
+  // column is the distribution conditioned on the column's state).
+  for (size_t col = 0; col < 3; col++)
+  {
+    transition.col(col) /= accu(transition.col(col));
+    emission.col(col) /= accu(emission.col(col));
+  }
+
+  // Now generate sequences.
+  size_t obsNum = 250;
+  size_t obsLen = 800;
+
+  std::vector<arma::vec> observations(obsNum);
+  std::vector<arma::Col<size_t> > states(obsNum);
+
+  for (size_t n = 0; n < obsNum; n++)
+  {
+    // set_size() does not initialize; every entry is assigned below.
+    observations[n].set_size(obsLen);
+    states[n].set_size(obsLen);
+
+    // Random starting state, and a starting observation emitted from it.
+    states[n][0] = rand() % 3;
+    observations[n][0] = SampleFromColumn(emission, states[n][0]);
+
+    // Each subsequent state is drawn from the previous state's transition
+    // column, and each observation from its own state's emission column.
+    for (size_t t = 1; t < obsLen; t++)
+    {
+      states[n][t] = SampleFromColumn(transition, states[n][t - 1]);
+      observations[n][t] = SampleFromColumn(emission, states[n][t]);
+    }
+  }
+
+  // Now that our data is generated, we give the HMM the labeled data to train
+  // on.
+  HMM<int> hmm(3, 6);
+
+  hmm.Train(observations, states);
+
+  // We can't use % tolerance here because percent error increases as the actual
+  // value gets very small.  So, instead, we just ensure that every value is no
+  // more than 0.004 away from the actual value.
+  for (size_t row = 0; row < hmm.Transition().n_rows; row++)
+    for (size_t col = 0; col < hmm.Transition().n_cols; col++)
+      BOOST_REQUIRE_SMALL(hmm.Transition()(row, col) - transition(row, col),
+          0.004);
+
+  for (size_t row = 0; row < hmm.Emission().n_rows; row++)
+    for (size_t col = 0; col < hmm.Emission().n_cols; col++)
+      BOOST_REQUIRE_SMALL(hmm.Emission()(row, col) - emission(row, col), 0.004);
+}
+
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-svn mailing list