[mlpack-svn] r10349 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Nov 22 14:13:20 EST 2011
Author: rcurtin
Date: 2011-11-22 14:13:20 -0500 (Tue, 22 Nov 2011)
New Revision: 10349
Modified:
mlpack/trunk/src/mlpack/tests/hmm_test.cpp
Log:
Test Generate() with a Gaussian HMM.
Modified: mlpack/trunk/src/mlpack/tests/hmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/hmm_test.cpp 2011-11-22 16:09:45 UTC (rev 10348)
+++ mlpack/trunk/src/mlpack/tests/hmm_test.cpp 2011-11-22 19:13:20 UTC (rev 10349)
@@ -707,4 +707,56 @@
}
}
+/**
+ * Generate a 10000-step sequence from a known three-state Gaussian HMM, fit a
+ * second HMM to it using the true state labels, and check the fit matches.
+ */
+BOOST_AUTO_TEST_CASE(GaussianHMMGenerateTest)
+{
+ // Three states, 2-D Gaussian emissions; each transition column sums to 1.
+ HMM<GaussianDistribution> hmm(3, GaussianDistribution(2));
+ hmm.Transition() = arma::mat("0.4 0.6 0.8; 0.2 0.2 0.1; 0.4 0.2 0.1");
+ hmm.Emission()[0] = GaussianDistribution("0.0 0.0", "1.0 0.0; 0.0 1.0");
+ hmm.Emission()[1] = GaussianDistribution("2.0 2.0", "1.0 0.5; 0.5 1.2");
+ hmm.Emission()[2] = GaussianDistribution("-2.0 1.0", "2.0 0.1; 0.1 1.0");
+
+ // One 10000-step sequence of observations plus its hidden state labels.
+ std::vector<std::vector<arma::vec> > observations(1);
+ std::vector<std::vector<size_t> > states(1);
+
+ // Start in state 1; the choice of start state is arbitrary.
+ hmm.Generate(10000, observations[0], states[0], 1);
+
+ HMM<GaussianDistribution> hmm2(3, GaussianDistribution(2));
+
+ // Fit hmm2 to the generated observations, passing the true state labels.
+ hmm2.Train(observations, states);
+
+ // Every transition probability must match the true one to within 0.03.
+ for (size_t row = 0; row < 3; row++)
+ for (size_t col = 0; col < 3; col++)
+ BOOST_REQUIRE_SMALL(hmm.Transition()(row, col) - hmm2.Transition()(row,
+ col), 0.03);
+
+ // Each estimated emission Gaussian must match the corresponding true one.
+ for (size_t em = 0; em < 3; em++)
+ {
+ // Both mean components must agree to within 0.07.
+ BOOST_REQUIRE_SMALL(hmm.Emission()[em].Mean()(0) -
+ hmm2.Emission()[em].Mean()(0), 0.07);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[em].Mean()(1) -
+ hmm2.Emission()[em].Mean()(1), 0.07);
+
+ // All four covariance entries must agree to within 0.1.
+ BOOST_REQUIRE_SMALL(hmm.Emission()[em].Covariance()(0, 0) -
+ hmm2.Emission()[em].Covariance()(0, 0), 0.1);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[em].Covariance()(0, 1) -
+ hmm2.Emission()[em].Covariance()(0, 1), 0.1);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[em].Covariance()(1, 0) -
+ hmm2.Emission()[em].Covariance()(1, 0), 0.1);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[em].Covariance()(1, 1) -
+ hmm2.Emission()[em].Covariance()(1, 1), 0.1);
+ }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn
mailing list