[mlpack-svn] r10578 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Dec 6 04:06:08 EST 2011
Author: rcurtin
Date: 2011-12-06 04:06:07 -0500 (Tue, 06 Dec 2011)
New Revision: 10578
Modified:
mlpack/trunk/src/mlpack/tests/hmm_test.cpp
Log:
Test HMMs with GMM emission distributions, starting with a simple Predict() test.
Tests for Estimate(), Estimate(...), LogLikelihood(), and possibly others are
still needed. I am reasonably confident the whole thing works correctly.
Modified: mlpack/trunk/src/mlpack/tests/hmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/hmm_test.cpp 2011-12-06 09:01:22 UTC (rev 10577)
+++ mlpack/trunk/src/mlpack/tests/hmm_test.cpp 2011-12-06 09:06:07 UTC (rev 10578)
@@ -7,11 +7,14 @@
#include <mlpack/methods/hmm/hmm.hpp>
#include <mlpack/methods/hmm/distributions/discrete_distribution.hpp>
#include <mlpack/methods/hmm/distributions/gaussian_distribution.hpp>
+#include <mlpack/methods/gmm/gmm.hpp>
+
#include <boost/test/unit_test.hpp>
using namespace mlpack;
using namespace mlpack::hmm;
using namespace mlpack::distribution;
+using namespace mlpack::gmm;
BOOST_AUTO_TEST_SUITE(HMMTest);
@@ -719,5 +722,73 @@
}
}
+/**
+ * Test that HMMs work with Gaussian mixture models. We'll try putting in a
+ * simple model by hand and making sure that prediction of observation sequences
+ * works correctly.
+ */
+BOOST_AUTO_TEST_CASE(GMMHMMPredictTest)
+{
+  // NOTE(review): the seed changes on every run, so a failure of this test
+  // cannot be replayed directly; consider a fixed seed when debugging.
+  srand(time(NULL));
+
+  // We will use two GMMs, one per hidden state: the first has two components
+  // and the second has three.  Every component mean of state 0 lies in the
+  // positive quadrant and every component mean of state 1 lies in the negative
+  // quadrant, so the states are well separated and prediction is expected to
+  // recover the generating state sequence exactly.
+  std::vector<GMM> gmms(2);
+  gmms[0] = GMM(2, 2);
+  gmms[0].Weights() = arma::vec("0.75 0.25");
+
+  // N([4.25 3.10], [1.00 0.20; 0.20 0.89])
+  gmms[0].Means()[0] = arma::vec("4.25 3.10");
+  gmms[0].Covariances()[0] = arma::mat("1.00 0.20; 0.20 0.89");
+
+  // N([7.10 5.01], [1.00 0.00; 0.00 1.01])
+  gmms[0].Means()[1] = arma::vec("7.10 5.01");
+  gmms[0].Covariances()[1] = arma::mat("1.00 0.00; 0.00 1.01");
+
+  gmms[1] = GMM(3, 2);
+  gmms[1].Weights() = arma::vec("0.4 0.2 0.4");
+
+  // N([-3.00 -6.12], [1.00 0.00; 0.00 1.00])
+  gmms[1].Means()[0] = arma::vec("-3.00 -6.12");
+  gmms[1].Covariances()[0] = arma::mat("1.00 0.00; 0.00 1.00");
+
+  // N([-4.25 -7.12], [1.50 0.60; 0.60 1.20])
+  gmms[1].Means()[1] = arma::vec("-4.25 -7.12");
+  gmms[1].Covariances()[1] = arma::mat("1.50 0.60; 0.60 1.20");
+
+  // N([-6.15 -2.00], [1.00 0.80; 0.80 1.00])
+  gmms[1].Means()[2] = arma::vec("-6.15 -2.00");
+  gmms[1].Covariances()[2] = arma::mat("1.00 0.80; 0.80 1.00");
+
+  // Transition matrix.  The sampling loop below reads column j as the
+  // distribution of the next state given current state j (each column sums
+  // to 1); presumably this matches the convention HMM expects.
+  arma::mat trans("0.30 0.50;"
+                  "0.70 0.50");
+
+  // Now build the model.
+  HMM<GMM> hmm(trans, gmms);
+
+  // Sample a sequence of hidden states and matching observations from the
+  // model, always starting in state 0.
+  const size_t numObservations = 1000;
+  arma::mat observations(2, numObservations);
+  arma::Col<size_t> states(numObservations);
+  states[0] = 0;
+  observations.col(0) = gmms[0].Random();
+
+  for (size_t i = 1; i < numObservations; i++)
+  {
+    const double randValue = static_cast<double>(rand()) /
+        static_cast<double>(RAND_MAX);
+
+    // trans(0, s) is the probability of moving from state s into state 0.
+    if (randValue <= trans(0, states[i - 1]))
+      states[i] = 0;
+    else
+      states[i] = 1;
+
+    observations.col(i) = gmms[states[i]].Random();
+  }
+
+  // Run the prediction.
+  arma::Col<size_t> predictions;
+  hmm.Predict(observations, predictions);
+
+  // Every predicted state must match the state that generated the observation.
+  for (size_t i = 0; i < numObservations; i++)
+    BOOST_REQUIRE_EQUAL(predictions[i], states[i]);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn mailing list