[mlpack-svn] r10273 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Nov 14 12:17:41 EST 2011


Author: rcurtin
Date: 2011-11-14 12:17:41 -0500 (Mon, 14 Nov 2011)
New Revision: 10273

Modified:
   mlpack/trunk/src/mlpack/tests/hmm_test.cpp
Log:
Add one more test for Generate().  The tolerance is a bit higher than I would
like, but that's okay.


Modified: mlpack/trunk/src/mlpack/tests/hmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/hmm_test.cpp	2011-11-14 16:19:14 UTC (rev 10272)
+++ mlpack/trunk/src/mlpack/tests/hmm_test.cpp	2011-11-14 17:17:41 UTC (rev 10273)
@@ -426,4 +426,56 @@
   BOOST_REQUIRE_CLOSE(stateProb[1], 0.50, 2.0);
 }
 
+/***
+ * Generate() round trip: sample from a random HMM, retrain, compare matrices.
+ */
+BOOST_AUTO_TEST_CASE(DiscreteHMMGenerateTest)
+{
+  srand(time(NULL)); // Seeds rand() from the clock; runs are not repeatable.
+  // 4 states, 6 emission symbols; both matrices have one column per state.
+  arma::mat transition(4, 4);
+  arma::mat emission(6, 4);
+
+  transition.randu();
+  emission.randu();
+
+  // Scale every column to sum to 1 so each is a probability distribution.
+  for (size_t col = 0; col < 4; col++)
+  {
+    transition.col(col) /= accu(transition.col(col));
+    emission.col(col) /= accu(emission.col(col));
+  }
+
+  // Reference HMM that the sequences below are generated from.
+  HMM<int> hmm(transition, emission);
+
+  // Generate numSeq sequences of numObs observations plus their state vectors.
+  int numSeq = 400;
+  int numObs = 3000;
+  std::vector<arma::vec> sequences(numSeq);
+  std::vector<arma::Col<size_t> > states(numSeq);
+  for (int i = 0; i < numSeq; i++)
+  {
+    // Pick the starting state from {0, 1, 2, 3} using rand().
+    size_t startState = rand() % 4;
+
+    hmm.Generate(numObs, sequences[i], states[i], startState);
+  }
+
+  // Train a second HMM, hmm2(4, 6), on the generated observations and states.
+  HMM<int> hmm2(4, 6);
+  hmm2.Train(sequences, states);
+
+  // Trained matrices must match the originals elementwise; tolerance is 8%.
+  for (size_t row = 0; row < 4; row++)
+    for (size_t col = 0; col < 4; col++)
+      BOOST_REQUIRE_CLOSE(hmm.Transition()(row, col),
+          hmm2.Transition()(row, col), 8.0);
+
+  for (size_t row = 0; row < 6; row++)
+    for (size_t col = 0; col < 4; col++)
+      BOOST_REQUIRE_CLOSE(hmm.Emission()(row, col), hmm2.Emission()(row, col),
+          8.0);
+}
+
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-svn mailing list