[mlpack-svn] r10273 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Nov 14 12:17:41 EST 2011
Author: rcurtin
Date: 2011-11-14 12:17:41 -0500 (Mon, 14 Nov 2011)
New Revision: 10273
Modified:
mlpack/trunk/src/mlpack/tests/hmm_test.cpp
Log:
Add one more test for Generate(). The tolerance is a bit higher than I would
like, but that's okay.
Modified: mlpack/trunk/src/mlpack/tests/hmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/hmm_test.cpp 2011-11-14 16:19:14 UTC (rev 10272)
+++ mlpack/trunk/src/mlpack/tests/hmm_test.cpp 2011-11-14 17:17:41 UTC (rev 10273)
@@ -426,4 +426,56 @@
BOOST_REQUIRE_CLOSE(stateProb[1], 0.50, 2.0);
}
+/***
+ * More complex test for Generate(): sample many sequences from a random
+ * discrete HMM, retrain on them, and check the parameters are recovered.
+ */
+BOOST_AUTO_TEST_CASE(DiscreteHMMGenerateTest)
+{
+  // Log the clock-based seed (--log_level=message) so failures reproduce.
+  const unsigned int seed = static_cast<unsigned int>(time(NULL));
+  srand(seed);
+  BOOST_TEST_MESSAGE("DiscreteHMMGenerateTest: random seed " << seed);
+
+  // 6 emissions, 4 states. Random transition and emission probabilities.
+  arma::mat transition(4, 4);
+  arma::mat emission(6, 4);
+  transition.randu();
+  emission.randu();
+
+  // Normalize each column so that it sums to 1.
+  for (size_t col = 0; col < 4; col++)
+  {
+    transition.col(col) /= accu(transition.col(col));
+    emission.col(col) /= accu(emission.col(col));
+  }
+
+  HMM<int> hmm(transition, emission);
+
+  // Generate many labeled sequences, each from a random starting state.
+  const size_t numSeq = 400;
+  const size_t numObs = 3000;
+  std::vector<arma::vec> sequences(numSeq);
+  std::vector<arma::Col<size_t> > states(numSeq);
+  for (size_t i = 0; i < numSeq; i++)
+    hmm.Generate(numObs, sequences[i], states[i], rand() % 4);
+
+  // Train a fresh model on the generated sequences and their states.
+  HMM<int> hmm2(4, 6);
+  hmm2.Train(sequences, states);
+
+  // The trained parameters should match the generating ones. An absolute
+  // tolerance is used because randu() entries can be arbitrarily close to 0,
+  // where a relative check (BOOST_REQUIRE_CLOSE) fails on sampling noise.
+  for (size_t row = 0; row < 4; row++)
+    for (size_t col = 0; col < 4; col++)
+      BOOST_REQUIRE_SMALL(hmm.Transition()(row, col) -
+          hmm2.Transition()(row, col), 0.02);
+
+  for (size_t row = 0; row < 6; row++)
+    for (size_t col = 0; col < 4; col++)
+      BOOST_REQUIRE_SMALL(hmm.Emission()(row, col) -
+          hmm2.Emission()(row, col), 0.02);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn mailing list