[mlpack-svn] r10338 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Nov 21 10:48:50 EST 2011


Author: rcurtin
Date: 2011-11-21 10:48:49 -0500 (Mon, 21 Nov 2011)
New Revision: 10338

Modified:
   mlpack/trunk/src/mlpack/tests/hmm_test.cpp
Log:
Add a first test for a Gaussian HMM.


Modified: mlpack/trunk/src/mlpack/tests/hmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/hmm_test.cpp	2011-11-21 15:48:21 UTC (rev 10337)
+++ mlpack/trunk/src/mlpack/tests/hmm_test.cpp	2011-11-21 15:48:49 UTC (rev 10338)
@@ -6,6 +6,7 @@
 #include <mlpack/core.h>
 #include <mlpack/methods/hmm/hmm.hpp>
 #include <mlpack/methods/hmm/distributions/discrete_distribution.hpp>
+#include <mlpack/methods/hmm/distributions/gaussian_distribution.hpp>
 #include <boost/test/unit_test.hpp>
 
 using namespace mlpack;
@@ -524,4 +525,61 @@
   BOOST_REQUIRE_CLOSE(hmm.LogLikelihood(seq), -24.51556128368, 1e-5);
 }
 
+/**
+ * A simple test to make sure HMMs with Gaussian output distributions work.
+ */
+BOOST_AUTO_TEST_CASE(GaussianHMMSimpleTest)
+{
+  // We'll have two Gaussians, far away from each other, one corresponding to
+  // each state.
+  //  E(0) ~ N([ 5.0  5.0], eye(2)).
+  //  E(1) ~ N([-5.0 -5.0], eye(2)).
+  // The transition matrix is simple:
+  //  T = [[0.75 0.25]
+  //       [0.25 0.75]]
+  GaussianDistribution g1("5.0 5.0", "1.0 0.0; 0.0 1.0");
+  GaussianDistribution g2("-5.0 -5.0", "1.0 0.0; 0.0 1.0");
+  arma::mat transition("0.75 0.25; 0.25 0.75");
+
+  std::vector<GaussianDistribution> emission;
+  emission.push_back(g1);
+  emission.push_back(g2);
+  HMM<GaussianDistribution> hmm(transition, emission);
+
+  // Generate one labeled sequence by sampling the model by hand.
+  const size_t seqLength = 1000;
+  std::vector<arma::vec> observations(seqLength);
+  std::vector<size_t> classes(seqLength);
+
+  // The sequence starts in state 0.
+  classes[0] = 0;
+  observations[0] = g1.Random();
+  for (size_t i = 1; i < seqLength; i++)
+  {
+    // Switch state with probability 0.25, matching the off-diagonal of T.
+    double randValue = static_cast<double>(rand()) / RAND_MAX;
+    if (randValue > 0.75)
+      classes[i] = (classes[i - 1] + 1) % 2;
+    else
+      classes[i] = classes[i - 1];
+
+    if (classes[i] == 0)
+      observations[i] = g1.Random();
+    else
+      observations[i] = g2.Random();
+  }
+
+  // Now predict the sequence.  stateProb is not checked here; the Estimate()
+  // call only makes sure it runs on Gaussian emissions.
+  std::vector<size_t> predictedClasses;
+  arma::mat stateProb;
+  hmm.Predict(observations, predictedClasses);
+  hmm.Estimate(observations, stateProb);
+
+  // Check that every state was predicted, and that each prediction is right.
+  BOOST_REQUIRE_EQUAL(predictedClasses.size(), seqLength);
+  for (size_t i = 0; i < seqLength; i++)
+    BOOST_REQUIRE_EQUAL(predictedClasses[i], classes[i]);
+}
+
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-svn mailing list