[mlpack-svn] r10265 - mlpack/trunk/src/mlpack/methods/hmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Nov 14 00:18:33 EST 2011


Author: rcurtin
Date: 2011-11-14 00:18:33 -0500 (Mon, 14 Nov 2011)
New Revision: 10265

Modified:
   mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp
   mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
Log:
Implement Generate() functionality.


Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp	2011-11-14 04:21:38 UTC (rev 10264)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp	2011-11-14 05:18:33 UTC (rev 10265)
@@ -73,19 +73,6 @@
              const std::vector<arma::Col<size_t> >& stateSeq);
 
   /**
-   * Generate a random data sequence of the given length.  The data sequence is
-   * stored in the data_sequence parameter, and the state sequence is stored in
-   * the state_sequence parameter.
-   *
-   * @param length Length of random sequence to generate.
-   * @param data_sequence Vector to store data in.
-   * @param state_sequence Vector to store states in.
-   */
-  void GenerateSequence(const size_t length,
-                        arma::vec& data_sequence,
-                        arma::vec& state_sequence) const;
-
-  /**
    * Estimate the probabilities of each hidden state at each time step for each
    * given data observation.
    */
@@ -96,6 +83,21 @@
                   arma::vec& scale_vec) const;
 
   /**
+   * Generate a random data sequence of the given length.  The data sequence is
+   * stored in the dataSequence parameter, and the state sequence is stored in
+   * the stateSequence parameter.  Both vectors are resized to the given
+   * length.  Random values are drawn with rand(), so seed with srand() if a
+   * reproducible sequence is needed.
+   *
+   * @param length Length of random sequence to generate.
+   * @param dataSequence Vector to store data (emission indices) in.
+   * @param stateSequence Vector to store hidden states in.
+   * @param startState Hidden state to start sequence in (default 0); must be a
+   *     valid state index.
+   */
+  void Generate(const size_t length,
+                arma::vec& dataSequence,
+                arma::Col<size_t>& stateSequence,
+                const size_t startState = 0) const;
+
+  /**
    * Compute the log-likelihood of a sequence.
    */
   double LogLikelihood(const arma::vec& data_seq) const;

Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp	2011-11-14 04:21:38 UTC (rev 10264)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp	2011-11-14 05:18:33 UTC (rev 10265)
@@ -201,6 +201,77 @@
 }
 
 /**
+ * Generate a random data sequence of a given length.  The data sequence is
+ * stored in the dataSequence parameter, and the state sequence is stored in
+ * the stateSequence parameter.  Both are resized to the given length; if the
+ * length is 0 they are simply left empty.
+ *
+ * The sequence starts in startState.  At each later time step the hidden state
+ * is drawn from the transition column of the previous state, and at every time
+ * step the emission is drawn from the emission column of the current state.
+ * Random values come from rand(), so call srand() first for reproducibility.
+ *
+ * @param length Length of random sequence to generate.
+ * @param dataSequence Vector to store emission indices in.
+ * @param stateSequence Vector to store hidden states in.
+ * @param startState Hidden state to start the sequence in; must be a valid
+ *     state index (it is used as a column of the emission matrix).
+ */
+template<typename Distribution>
+void HMM<Distribution>::Generate(const size_t length,
+                                 arma::vec& dataSequence,
+                                 arma::Col<size_t>& stateSequence,
+                                 const size_t startState) const
+{
+  // Set vectors to the right size.
+  stateSequence.set_size(length);
+  dataSequence.set_size(length);
+
+  // An empty sequence has no slot for the start state; writing element 0 of an
+  // empty vector would be out of bounds.
+  if (length == 0)
+    return;
+
+  for (size_t t = 0; t < length; t++)
+  {
+    // Choose the hidden state.  The first one is given; later ones are found by
+    // locating a uniform random value in the cumulative transition
+    // distribution of the previous state.
+    if (t == 0)
+    {
+      stateSequence[t] = startState;
+    }
+    else
+    {
+      const double stateRand = (double) rand() / (double) RAND_MAX;
+
+      // rand() may return RAND_MAX (stateRand == 1.0) while rounding leaves the
+      // cumulative sum slightly below 1, in which case no state matches.
+      // Default to the last state rather than leaving the element
+      // uninitialized, since it is used as a column index on the next step.
+      stateSequence[t] = transition.n_rows - 1;
+      double probSum = 0;
+      for (size_t st = 0; st < transition.n_rows; st++)
+      {
+        probSum += transition(st, stateSequence[t - 1]);
+        if (stateRand <= probSum)
+        {
+          stateSequence[t] = st;
+          break;
+        }
+      }
+    }
+
+    // Now choose the emission for the current state the same way, with the
+    // same fallback to the last emission if no bucket matches.
+    const double emissionRand = (double) rand() / (double) RAND_MAX;
+
+    dataSequence[t] = emission.n_rows - 1;
+    double probSum = 0;
+    for (size_t em = 0; em < emission.n_rows; em++)
+    {
+      probSum += emission(em, stateSequence[t]);
+      if (emissionRand <= probSum)
+      {
+        dataSequence[t] = em;
+        break;
+      }
+    }
+  }
+}
+
+/**
  * Compute the most probable hidden state sequence for the given observation
  * using the Viterbi algorithm. Returns the log-likelihood of the most likely
  * sequence.




More information about the mlpack-svn mailing list