[mlpack-svn] r10262 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Nov 12 19:30:10 EST 2011
Author: rcurtin
Date: 2011-11-12 19:30:10 -0500 (Sat, 12 Nov 2011)
New Revision: 10262
Modified:
mlpack/trunk/src/mlpack/tests/hmm_test.cpp
Log:
Add test for supervised transition and emission matrix estimation.
Modified: mlpack/trunk/src/mlpack/tests/hmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/hmm_test.cpp 2011-11-13 00:29:57 UTC (rev 10261)
+++ mlpack/trunk/src/mlpack/tests/hmm_test.cpp 2011-11-13 00:30:10 UTC (rev 10262)
@@ -40,7 +40,7 @@
arma::vec observation("0 0 1 0 0");
arma::Col<size_t> states;
- hmm.Viterbi(observation, states);
+ hmm.Predict(observation, states);
// Check each state.
BOOST_REQUIRE_EQUAL(states[0], 0); // Rain.
@@ -73,7 +73,7 @@
arma::vec observation("2 2 1 0 1 3 2 0 0");
arma::Col<size_t> states;
- hmm.Viterbi(observation, states);
+ hmm.Predict(observation, states);
// Most probable path is HHHLLLLLL.
BOOST_REQUIRE_EQUAL(states[0], 1);
@@ -151,7 +151,7 @@
observations.push_back("0 0 0 0 0 0 0 0 0 0 0 0");
observations.push_back("0 0 0 0 0 0 0 0 0 0");
- hmm.EstimateModel(observations);
+ hmm.Train(observations);
BOOST_REQUIRE_CLOSE(hmm.Emission()(0, 0), 1.0, 1e-5);
BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 0), 1.0, 1e-5);
@@ -189,7 +189,7 @@
observations.push_back("0 0 1 1 0 0 0 0 1 1 1 1");
observations.push_back("1 1 1 0 0 0 1 1 1 0 0 0");
- hmm.EstimateModel(observations);
+ hmm.Train(observations);
BOOST_REQUIRE_CLOSE(hmm.Emission()(0, 0), 0.5, 1e-5);
BOOST_REQUIRE_CLOSE(hmm.Emission()(1, 0), 0.5, 1e-5);
@@ -272,7 +272,7 @@
out.col(i) = observations[i];
data::Save("out.csv", out);
- hmm.EstimateModel(observations);
+ hmm.Train(observations);
// Only require 0.75% tolerance, because this is a little fuzzier.
BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 0), 0.5, 0.75);
@@ -290,4 +290,100 @@
BOOST_REQUIRE_CLOSE(hmm.Emission()(3, 1), 0.8, 0.75);
}
+/**
+ * Map a uniform draw onto an index of the discrete distribution stored in one
+ * column of a column-stochastic matrix, by walking the cumulative sum.
+ *
+ * @param probs Matrix whose columns each sum to (approximately) 1.
+ * @param col Column of probs holding the distribution to sample from.
+ * @param draw Uniform random number in [0, 1).
+ * @return Row index selected by the draw.
+ */
+static size_t SampleFromColumn(const arma::mat& probs,
+                               const size_t col,
+                               const double draw)
+{
+  double sumProb = 0;
+  for (size_t row = 0; row + 1 < probs.n_rows; row++)
+  {
+    sumProb += probs(row, col);
+    if (sumProb > draw)
+      return row;
+  }
+
+  // The last row is the fallback.  After normalization a column may sum to
+  // slightly less than 1, so the cumulative sum is not guaranteed to exceed
+  // the draw; without a fallback the caller's element would be left unset.
+  return probs.n_rows - 1;
+}
+
+/**
+ * Train a discrete HMM on labeled (observation, state) sequences sampled from
+ * a random model and make sure the estimated transition and emission matrices
+ * are close to the matrices the data was generated from.
+ */
+BOOST_AUTO_TEST_CASE(DiscreteHMMLabeledTrainTest)
+{
+  // Generate a random Markov model with 3 hidden states and 6 observations.
+  arma::mat transition;
+  arma::mat emission;
+
+  transition.randu(3, 3);
+  emission.randu(6, 3);
+
+  // Normalize so they are correct transition and emission matrices.
+  for (size_t col = 0; col < 3; col++)
+  {
+    transition.col(col) /= accu(transition.col(col));
+    emission.col(col) /= accu(emission.col(col));
+  }
+
+  // Now generate sequences.
+  size_t obsNum = 250;
+  size_t obsLen = 800;
+
+  // Dividing by RAND_MAX + 1 keeps every uniform draw strictly below 1; with a
+  // plain RAND_MAX divisor a draw of exactly 1.0 is possible and can never be
+  // exceeded by a cumulative probability.
+  const double randScale = static_cast<double>(RAND_MAX) + 1.0;
+
+  std::vector<arma::vec> observations(obsNum);
+  std::vector<arma::Col<size_t> > states(obsNum);
+
+  for (size_t n = 0; n < obsNum; n++)
+  {
+    // set_size() does not initialize the elements, so every entry must be
+    // assigned below.
+    observations[n].set_size(obsLen);
+    states[n].set_size(obsLen);
+
+    // Random starting state.
+    states[n][0] = rand() % 3;
+
+    // Random starting observation.
+    const double firstObs = rand() / randScale;
+    observations[n][0] = SampleFromColumn(emission, states[n][0], firstObs);
+
+    // Now the rest of the observations.
+    for (size_t t = 1; t < obsLen; t++)
+    {
+      // Choose random numbers for state transition and for emission transition.
+      const double obs = rand() / randScale;
+      const double state = rand() / randScale;
+
+      // Decide next state.
+      states[n][t] = SampleFromColumn(transition, states[n][t - 1], state);
+
+      // Decide observation.
+      observations[n][t] = SampleFromColumn(emission, states[n][t], obs);
+    }
+  }
+
+  // Now that our data is generated, we give the HMM the labeled data to train
+  // on.
+  HMM<int> hmm(3, 6);
+
+  hmm.Train(observations, states);
+
+  // We can't use % tolerance here because percent error increases as the actual
+  // value gets very small.  So, instead, we just ensure that every value is no
+  // more than 0.004 away from the actual value.
+  for (size_t row = 0; row < hmm.Transition().n_rows; row++)
+    for (size_t col = 0; col < hmm.Transition().n_cols; col++)
+      BOOST_REQUIRE_SMALL(hmm.Transition()(row, col) - transition(row, col),
+          0.004);
+
+  for (size_t row = 0; row < hmm.Emission().n_rows; row++)
+    for (size_t col = 0; col < hmm.Emission().n_cols; col++)
+      BOOST_REQUIRE_SMALL(hmm.Emission()(row, col) - emission(row, col), 0.004);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn
mailing list