[mlpack-svn] r12319 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Apr 11 15:35:23 EDT 2012


Author: rcurtin
Date: 2012-04-11 15:35:23 -0400 (Wed, 11 Apr 2012)
New Revision: 12319

Modified:
   mlpack/trunk/src/mlpack/tests/hmm_test.cpp
Log:
The fabled GMM-HMM training test.  It is the hardest of all tests: only the
strongest training implementations survive the double sucker punch of GMM
followed by GMM.  A more rigorous drill sergeant might require tighter
tolerances, but unfortunately perfection can't be expected out of the EM
algorithm every time.


Modified: mlpack/trunk/src/mlpack/tests/hmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/hmm_test.cpp	2012-04-11 19:15:13 UTC (rev 12318)
+++ mlpack/trunk/src/mlpack/tests/hmm_test.cpp	2012-04-11 19:35:23 UTC (rev 12319)
@@ -738,7 +738,7 @@
 BOOST_AUTO_TEST_CASE(GMMHMMPredictTest)
 {
   // We will use two GMMs; one with two components and one with three.
-  std::vector<GMM</*EMFit<>*/ > > gmms(2);
+  std::vector<GMM<> > gmms(2);
   gmms[0] = GMM<>(2, 2);
   gmms[0].Weights() = arma::vec("0.75 0.25");
 
@@ -800,7 +800,140 @@
  * Test that GMM-based HMMs can train on models correctly using labeled training
  * data.
  */
+BOOST_AUTO_TEST_CASE(GMMHMMLabeledTrainingTest)
+{
+  srand(time(NULL)); // Time-based seed: state sequences differ per run.
 
+  // We will use two GMMs (one per HMM state), each with two components.
+  std::vector<GMM<> > gmms(2, GMM<>(2, 2));
+  gmms[0].Weights() = arma::vec("0.3 0.7");
 
+  // N([4.25 3.10], [1.00 0.20; 0.20 0.89])
+  gmms[0].Means()[0] = arma::vec("4.25 3.10");
+  gmms[0].Covariances()[0] = arma::mat("1.00 0.20; 0.20 0.89");
+
+  // N([7.10 5.01], [1.00 0.00; 0.00 1.01])
+  gmms[0].Means()[1] = arma::vec("7.10 5.01");
+  gmms[0].Covariances()[1] = arma::mat("1.00 0.00; 0.00 1.01");
+
+  gmms[1].Weights() = arma::vec("0.20 0.80");
+
+  gmms[1].Means()[0] = arma::vec("-3.00 -6.12");
+  gmms[1].Covariances()[0] = arma::mat("1.00 0.00; 0.00 1.00");
+
+  gmms[1].Means()[1] = arma::vec("-4.25 -2.12");
+  gmms[1].Covariances()[1] = arma::mat("1.50 0.60; 0.60 1.20");
+
+  // Transition matrix; column j holds the next-state probabilities from j.
+  arma::mat transMat("0.40 0.60;"
+                     "0.60 0.40");
+
+  // Make 5 labeled sequences of 2500 observations, each starting in state 0.
+  std::vector<arma::mat> observations(5, arma::mat(2, 2500));
+  std::vector<arma::Col<size_t> > states(5, arma::Col<size_t>(2500));
+  for (size_t obs = 0; obs < 5; obs++)
+  {
+    states[obs][0] = 0;
+    observations[obs].col(0) = gmms[0].Random();
+
+    for (size_t i = 1; i < 2500; i++)
+    {
+      double randValue = (double) rand() / (double) RAND_MAX;
+
+      if (randValue <= transMat(0, states[obs][i - 1]))
+        states[obs][i] = 0;
+      else
+        states[obs][i] = 1;
+
+      observations[obs].col(i) = gmms[states[obs][i]].Random();
+    }
+  }
+
+  // Set up the HMM (two states, GMM<>(2, 2) emissions) for training.
+  HMM<GMM<> > hmm(2, GMM<>(2, 2));
+
+  // Train the HMM.
+  hmm.Train(observations, states);
+
+  // Check the results.  Use absolute tolerances instead of percentages.
+  BOOST_REQUIRE_SMALL(hmm.Transition()(0, 0) - transMat(0, 0), 0.02);
+  BOOST_REQUIRE_SMALL(hmm.Transition()(0, 1) - transMat(0, 1), 0.02);
+  BOOST_REQUIRE_SMALL(hmm.Transition()(1, 0) - transMat(1, 0), 0.02);
+  BOOST_REQUIRE_SMALL(hmm.Transition()(1, 1) - transMat(1, 1), 0.02);
+
+  // Now the emission probabilities (the GMMs).  Components are matched to
+  // the true ones by sorting on weight (ascending, as the true weights are).
+  arma::uvec sortedIndices = sort_index(hmm.Emission()[0].Weights());
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Weights()[sortedIndices[0]] -
+      gmms[0].Weights()[0], 0.08);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Weights()[sortedIndices[1]] -
+      gmms[0].Weights()[1], 0.08);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[0]][0] -
+      gmms[0].Means()[0][0], 0.15);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[0]][1] -
+      gmms[0].Means()[0][1], 0.15);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[1]][0] -
+      gmms[0].Means()[1][0], 0.15);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[1]][1] -
+      gmms[0].Means()[1][1], 0.15);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](0, 0) -
+      gmms[0].Covariances()[0](0, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](0, 1) -
+      gmms[0].Covariances()[0](0, 1), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](1, 0) -
+      gmms[0].Covariances()[0](1, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](1, 1) -
+      gmms[0].Covariances()[0](1, 1), 0.3);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](0, 0) -
+      gmms[0].Covariances()[1](0, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](0, 1) -
+      gmms[0].Covariances()[1](0, 1), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](1, 0) -
+      gmms[0].Covariances()[1](1, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](1, 1) -
+      gmms[0].Covariances()[1](1, 1), 0.3);
+
+  // Sort the second state's GMM components by weight the same way.
+  sortedIndices = sort_index(hmm.Emission()[1].Weights());
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Weights()[sortedIndices[0]] -
+      gmms[1].Weights()[0], 0.08);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Weights()[sortedIndices[1]] -
+      gmms[1].Weights()[1], 0.08);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[0]][0] -
+      gmms[1].Means()[0][0], 0.15);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[0]][1] -
+      gmms[1].Means()[0][1], 0.15);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[1]][0] -
+      gmms[1].Means()[1][0], 0.15);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[1]][1] -
+      gmms[1].Means()[1][1], 0.15);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](0, 0) -
+      gmms[1].Covariances()[0](0, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](0, 1) -
+      gmms[1].Covariances()[0](0, 1), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](1, 0) -
+      gmms[1].Covariances()[0](1, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](1, 1) -
+      gmms[1].Covariances()[0](1, 1), 0.3);
+
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](0, 0) -
+      gmms[1].Covariances()[1](0, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](0, 1) -
+      gmms[1].Covariances()[1](0, 1), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](1, 0) -
+      gmms[1].Covariances()[1](1, 0), 0.3);
+  BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](1, 1) -
+      gmms[1].Covariances()[1](1, 1), 0.3);
+}
+
 BOOST_AUTO_TEST_SUITE_END();
 




More information about the mlpack-svn mailing list