[mlpack-svn] r10265 - mlpack/trunk/src/mlpack/methods/hmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Nov 14 00:18:33 EST 2011
Author: rcurtin
Date: 2011-11-14 00:18:33 -0500 (Mon, 14 Nov 2011)
New Revision: 10265
Modified:
mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp
mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
Log:
Implement Generate() functionality.
Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp 2011-11-14 04:21:38 UTC (rev 10264)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp 2011-11-14 05:18:33 UTC (rev 10265)
@@ -73,19 +73,6 @@
const std::vector<arma::Col<size_t> >& stateSeq);
/**
- * Generate a random data sequence of the given length. The data sequence is
- * stored in the data_sequence parameter, and the state sequence is stored in
- * the state_sequence parameter.
- *
- * @param length Length of random sequence to generate.
- * @param data_sequence Vector to store data in.
- * @param state_sequence Vector to store states in.
- */
- void GenerateSequence(const size_t length,
- arma::vec& data_sequence,
- arma::vec& state_sequence) const;
-
- /**
* Estimate the probabilities of each hidden state at each time step for each
* given data observation.
*/
@@ -96,6 +83,21 @@
arma::vec& scale_vec) const;
/**
+ * Generate a random data sequence of the given length. The data sequence is
+ * stored in the dataSequence parameter, and the state sequence is stored in
+ * the stateSequence parameter.
+ *
+ * @param length Length of random sequence to generate.
+ * @param dataSequence Vector to store data in.
+ * @param stateSequence Vector to store states in.
+ * @param startState Hidden state to start sequence in (default 0).
+ */
+ void Generate(const size_t length,
+ arma::vec& dataSequence,
+ arma::Col<size_t>& stateSequence,
+ const size_t startState = 0) const;
+
+ /**
* Compute the log-likelihood of a sequence.
*/
double LogLikelihood(const arma::vec& data_seq) const;
Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp 2011-11-14 04:21:38 UTC (rev 10264)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp 2011-11-14 05:18:33 UTC (rev 10265)
@@ -201,6 +201,77 @@
}
/**
+ * Generate a random data sequence of a given length. The data sequence is
+ * stored in the dataSequence parameter, and the state sequence is stored in
+ * the stateSequence parameter.
+ */
+template<typename Distribution>
+void HMM<Distribution>::Generate(const size_t length,
+                                 arma::vec& dataSequence,
+                                 arma::Col<size_t>& stateSequence,
+                                 const size_t startState) const
+{
+  // Contract:
+  //  - length:        number of time steps to generate; 0 yields two empty
+  //                   vectors.
+  //  - dataSequence:  resized to length; element t receives the index of the
+  //                   emission drawn at step t.
+  //  - stateSequence: resized to length; element 0 is startState, element t
+  //                   (t > 0) is drawn from column stateSequence[t - 1] of
+  //                   the transition matrix.
+  //  - startState:    used directly as a column index into the emission and
+  //                   transition matrices, so it must be a valid state index
+  //                   (it is not range-checked here).
+  //
+  // Randomness comes from rand(); one value is drawn for the first step and
+  // two (state, then emission) for every later step.
+
+  // Set vectors to the right size.  set_size() does not initialize the
+  // elements, so every element has to be written by the loop below.
+  stateSequence.set_size(length);
+  dataSequence.set_size(length);
+
+  // A single loop handles every time step, so a length of 0 writes nothing.
+  for (size_t t = 0; t < length; t++)
+  {
+    if (t == 0)
+    {
+      // The first hidden state is supplied by the caller (default is 0).
+      stateSequence[0] = startState;
+    }
+    else
+    {
+      // Choose the hidden state: find where a uniform random value sits in
+      // the cumulative distribution of transitions out of the previous state.
+      const double randValue =
+          static_cast<double>(rand()) / static_cast<double>(RAND_MAX);
+
+      // Only the first n_rows - 1 candidates are compared; the last one is
+      // the fallback.  This selects the same state as comparing all of them
+      // whenever the column sums to at least randValue, and still assigns a
+      // state when rounding makes the column sum fall short of it.
+      double probSum = 0;
+      size_t st = 0;
+      for (; st + 1 < transition.n_rows; st++)
+      {
+        probSum += transition(st, stateSequence[t - 1]);
+        if (randValue <= probSum)
+          break;
+      }
+
+      stateSequence[t] = st;
+    }
+
+    // Choose the emission for the hidden state of this time step, using the
+    // same cumulative walk (and the same last-row fallback) over the
+    // corresponding column of the emission matrix.
+    const double randValue =
+        static_cast<double>(rand()) / static_cast<double>(RAND_MAX);
+
+    double probSum = 0;
+    size_t em = 0;
+    for (; em + 1 < emission.n_rows; em++)
+    {
+      probSum += emission(em, stateSequence[t]);
+      if (randValue <= probSum)
+        break;
+    }
+
+    dataSequence[t] = static_cast<double>(em);
+  }
+}
+
+/**
* Compute the most probable hidden state sequence for the given observation
* using the Viterbi algorithm. Returns the log-likelihood of the most likely
* sequence.
More information about the mlpack-svn
mailing list