[mlpack-svn] r12319 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Apr 11 15:35:23 EDT 2012
Author: rcurtin
Date: 2012-04-11 15:35:23 -0400 (Wed, 11 Apr 2012)
New Revision: 12319
Modified:
mlpack/trunk/src/mlpack/tests/hmm_test.cpp
Log:
The fabled GMM-HMM training test. The hardest of all tests, only the strongest
training implementations survive the double-sucker-punch of GMM followed by GMM.
A more rigorous drill sergeant might require tighter tolerances, but
unfortunately perfection can't be expected out of the EM algorithm every time.
Modified: mlpack/trunk/src/mlpack/tests/hmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/hmm_test.cpp 2012-04-11 19:15:13 UTC (rev 12318)
+++ mlpack/trunk/src/mlpack/tests/hmm_test.cpp 2012-04-11 19:35:23 UTC (rev 12319)
@@ -738,7 +738,7 @@
BOOST_AUTO_TEST_CASE(GMMHMMPredictTest)
{
// We will use two GMMs; one with two components and one with three.
- std::vector<GMM</*EMFit<>*/ > > gmms(2);
+ std::vector<GMM<> > gmms(2);
gmms[0] = GMM<>(2, 2);
gmms[0].Weights() = arma::vec("0.75 0.25");
@@ -800,7 +800,140 @@
* Test that GMM-based HMMs can train on models correctly using labeled training
* data.
*/
+BOOST_AUTO_TEST_CASE(GMMHMMLabeledTrainingTest)
+{
+  srand(time(NULL)); // Seeded from the clock: each run draws different states.
+  // We will use two 2-D GMMs, each with two components.
+  std::vector<GMM<> > gmms(2, GMM<>(2, 2));
+  gmms[0].Weights() = arma::vec("0.3 0.7");
+  // N([4.25 3.10], [1.00 0.20; 0.20 0.89])
+  gmms[0].Means()[0] = arma::vec("4.25 3.10");
+  gmms[0].Covariances()[0] = arma::mat("1.00 0.20; 0.20 0.89");
+
+  // N([7.10 5.01], [1.00 0.00; 0.00 1.01])
+  gmms[0].Means()[1] = arma::vec("7.10 5.01");
+  gmms[0].Covariances()[1] = arma::mat("1.00 0.00; 0.00 1.01");
+
+  gmms[1].Weights() = arma::vec("0.20 0.80");
+  // N([-3.00 -6.12], [1.00 0.00; 0.00 1.00])
+  gmms[1].Means()[0] = arma::vec("-3.00 -6.12");
+  gmms[1].Covariances()[0] = arma::mat("1.00 0.00; 0.00 1.00");
+
+  // N([-4.25 -2.12], [1.50 0.60; 0.60 1.20])
+  gmms[1].Means()[1] = arma::vec("-4.25 -2.12");
+  gmms[1].Covariances()[1] = arma::mat("1.50 0.60; 0.60 1.20");
+
+  // Transition matrix: entry (i, j) is P(next state = i | current state = j).
+  arma::mat transMat("0.40 0.60; 0.60 0.40");
+
+  // Generate 5 labeled sequences of 2500 observations each, starting in state 0.
+  std::vector<arma::mat> observations(5, arma::mat(2, 2500));
+  std::vector<arma::Col<size_t> > states(5, arma::Col<size_t>(2500));
+  for (size_t obs = 0; obs < 5; obs++)
+  {
+    states[obs][0] = 0;
+    observations[obs].col(0) = gmms[0].Random();
+
+    for (size_t i = 1; i < 2500; i++)
+    {
+      double randValue = (double) rand() / (double) RAND_MAX;
+
+      if (randValue <= transMat(0, states[obs][i - 1]))
+        states[obs][i] = 0;
+      else
+        states[obs][i] = 1;
+
+      observations[obs].col(i) = gmms[states[obs][i]].Random();
+    }
+  }
+
+  // Set up the HMM for training: two states, each emitting via a GMM<>(2, 2).
+  HMM<GMM<> > hmm(2, GMM<>(2, 2));
+
+  // Train the HMM using the true state labels (labeled training).
+  hmm.Train(observations, states);
+
+  // Check the results. Use absolute tolerances instead of percentages.
+  BOOST_REQUIRE_SMALL(hmm.Transition()(0, 0) - transMat(0, 0), 0.02);
+  BOOST_REQUIRE_SMALL(hmm.Transition()(0, 1) - transMat(0, 1), 0.02);
+  BOOST_REQUIRE_SMALL(hmm.Transition()(1, 0) - transMat(1, 0), 0.02);
+  BOOST_REQUIRE_SMALL(hmm.Transition()(1, 1) - transMat(1, 1), 0.02);
+
+  // Now the emission distributions (the GMMs). Trained components may come out
+  // in any order, so sort by weight: the true weights are listed ascending.
+  arma::uvec sortedIndices = sort_index(hmm.Emission()[0].Weights());
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Weights()[sortedIndices[0]] -
+      gmms[0].Weights()[0], 0.08);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Weights()[sortedIndices[1]] -
+      gmms[0].Weights()[1], 0.08);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[0]][0] -
+      gmms[0].Means()[0][0], 0.15);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[0]][1] -
+      gmms[0].Means()[0][1], 0.15);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[1]][0] -
+      gmms[0].Means()[1][0], 0.15);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[1]][1] -
+      gmms[0].Means()[1][1], 0.15);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](0, 0) -
+      gmms[0].Covariances()[0](0, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](0, 1) -
+      gmms[0].Covariances()[0](0, 1), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](1, 0) -
+      gmms[0].Covariances()[0](1, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](1, 1) -
+      gmms[0].Covariances()[0](1, 1), 0.3);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](0, 0) -
+      gmms[0].Covariances()[1](0, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](0, 1) -
+      gmms[0].Covariances()[1](0, 1), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](1, 0) -
+      gmms[0].Covariances()[1](1, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](1, 1) -
+      gmms[0].Covariances()[1](1, 1), 0.3);
+
+  // Sort the second GMM's components by weight in the same way.
+  sortedIndices = sort_index(hmm.Emission()[1].Weights());
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Weights()[sortedIndices[0]] -
+      gmms[1].Weights()[0], 0.08);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Weights()[sortedIndices[1]] -
+      gmms[1].Weights()[1], 0.08);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[0]][0] -
+      gmms[1].Means()[0][0], 0.15);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[0]][1] -
+      gmms[1].Means()[0][1], 0.15);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[1]][0] -
+      gmms[1].Means()[1][0], 0.15);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[1]][1] -
+      gmms[1].Means()[1][1], 0.15);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](0, 0) -
+      gmms[1].Covariances()[0](0, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](0, 1) -
+      gmms[1].Covariances()[0](0, 1), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](1, 0) -
+      gmms[1].Covariances()[0](1, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](1, 1) -
+      gmms[1].Covariances()[0](1, 1), 0.3);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](0, 0) -
+      gmms[1].Covariances()[1](0, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](0, 1) -
+      gmms[1].Covariances()[1](0, 1), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](1, 0) -
+      gmms[1].Covariances()[1](1, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](1, 1) -
+      gmms[1].Covariances()[1](1, 1), 0.3);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn
mailing list