[mlpack-svn] r10578 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Dec 6 04:06:08 EST 2011


Author: rcurtin
Date: 2011-12-06 04:06:07 -0500 (Tue, 06 Dec 2011)
New Revision: 10578

Modified:
   mlpack/trunk/src/mlpack/tests/hmm_test.cpp
Log:
Test HMMs with GMMs; a simple Predict() test.  Still need Estimate(),
Estimate(...), LogLikelihood(), and maybe some other tests.  I am reasonably
confident the whole thing works correctly.


Modified: mlpack/trunk/src/mlpack/tests/hmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/hmm_test.cpp	2011-12-06 09:01:22 UTC (rev 10577)
+++ mlpack/trunk/src/mlpack/tests/hmm_test.cpp	2011-12-06 09:06:07 UTC (rev 10578)
@@ -7,11 +7,14 @@
 #include <mlpack/methods/hmm/hmm.hpp>
 #include <mlpack/methods/hmm/distributions/discrete_distribution.hpp>
 #include <mlpack/methods/hmm/distributions/gaussian_distribution.hpp>
+#include <mlpack/methods/gmm/gmm.hpp>
+
 #include <boost/test/unit_test.hpp>
 
 using namespace mlpack;
 using namespace mlpack::hmm;
 using namespace mlpack::distribution;
+using namespace mlpack::gmm;
 
 BOOST_AUTO_TEST_SUITE(HMMTest);
 
@@ -719,5 +722,73 @@
   }
 }
 
+/**
+ * Test that HMMs work with Gaussian mixture models.  We build a two-state model
+ * by hand, sample a hidden-state sequence and matching observations from it,
+ * and require that Predict() recovers every hidden state.
+ */
+BOOST_AUTO_TEST_CASE(GMMHMMPredictTest)
+{
+  const size_t sequenceLength = 1000; // Number of sampled time steps.
+  srand(time(NULL));
+
+  // We will use two GMMs; one with two components and one with three.
+  std::vector<GMM> gmms(2);
+  gmms[0] = GMM(2, 2);
+  gmms[0].Weights() = arma::vec("0.75 0.25");
+
+  // N([4.25 3.10], [1.00 0.20; 0.20 0.89])
+  gmms[0].Means()[0] = arma::vec("4.25 3.10");
+  gmms[0].Covariances()[0] = arma::mat("1.00 0.20; 0.20 0.89");
+
+  // N([7.10 5.01], [1.00 0.00; 0.00 1.01])
+  gmms[0].Means()[1] = arma::vec("7.10 5.01");
+  gmms[0].Covariances()[1] = arma::mat("1.00 0.00; 0.00 1.01");
+
+  gmms[1] = GMM(3, 2);
+  gmms[1].Weights() = arma::vec("0.4 0.2 0.4");
+
+  // N([-3.00 -6.12], [1.00 0.00; 0.00 1.00])
+  gmms[1].Means()[0] = arma::vec("-3.00 -6.12");
+  gmms[1].Covariances()[0] = arma::mat("1.00 0.00; 0.00 1.00");
+
+  // N([-4.25 -7.12], [1.50 0.60; 0.60 1.20])
+  gmms[1].Means()[1] = arma::vec("-4.25 -7.12");
+  gmms[1].Covariances()[1] = arma::mat("1.50 0.60; 0.60 1.20");
+
+  // N([-6.15 -2.00], [1.00 0.80; 0.80 1.00])
+  gmms[1].Means()[2] = arma::vec("-6.15 -2.00");
+  gmms[1].Covariances()[2] = arma::mat("1.00 0.80; 0.80 1.00");
+
+  // Transition matrix, read below as trans(i, j) = P(next = i | current = j).
+  arma::mat trans("0.30 0.50;"
+                  "0.70 0.50");
+
+  // Now build the model.
+  HMM<GMM> hmm(trans, gmms);
+
+  // Sample a state sequence (starting in state 0) and its observations.
+  arma::mat observations(2, sequenceLength);
+  arma::Col<size_t> states(sequenceLength);
+  states[0] = 0;
+  observations.col(0) = gmms[0].Random();
+
+  for (size_t i = 1; i < sequenceLength; i++)
+  {
+    const double randValue = static_cast<double>(rand()) / RAND_MAX;
+    states[i] = (randValue <= trans(0, states[i - 1])) ? 0 : 1;
+    observations.col(i) = gmms[states[i]].Random();
+  }
+
+  // Run the prediction; it must label every observation.
+  arma::Col<size_t> predictions;
+  hmm.Predict(observations, predictions);
+  BOOST_REQUIRE_EQUAL(predictions.n_elem, sequenceLength);
+
+  // The two states' emissions are far apart, so every prediction must match.
+  for (size_t i = 0; i < sequenceLength; i++)
+    BOOST_REQUIRE_EQUAL(predictions[i], states[i]);
+}
+
 BOOST_AUTO_TEST_SUITE_END();
 




More information about the mlpack-svn mailing list