[mlpack-svn] r14906 - mlpack/trunk/src/mlpack/methods/hmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Apr 15 17:16:37 EDT 2013
Author: rcurtin
Date: 2013-04-15 17:16:37 -0400 (Mon, 15 Apr 2013)
New Revision: 14906
Modified:
mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp
mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
Log:
Allow parameterization of the tolerance of the Baum-Welch algorithm.
Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp 2013-04-15 21:06:58 UTC (rev 14905)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm.hpp 2013-04-15 21:16:37 UTC (rev 14906)
@@ -84,11 +84,17 @@
* observations is taken from the emissions variable, so it is important that
* the given default emission distribution is set with the correct
* dimensionality. Alternately, set the dimensionality with Dimensionality().
+ * Optionally, the tolerance for convergence of the Baum-Welch algorithm can
+ * be set.
*
* @param states Number of states.
* @param emissions Default distribution for emissions.
+ * @param tolerance Tolerance for convergence of the training algorithm
+ * (Baum-Welch).
*/
- HMM(const size_t states, const Distribution emissions);
+ HMM(const size_t states,
+ const Distribution emissions,
+ const double tolerance = 1e-5);
/**
* Create the Hidden Markov Model with the given transition matrix and the
@@ -103,10 +109,17 @@
* The emission matrix should be such that E(i, j) is the probability of
* emission i while in state j. The columns of the matrix should sum to 1.
*
+ * Optionally, the tolerance for convergence of the Baum-Welch algorithm can
+ * be set.
+ *
* @param transition Transition matrix.
* @param emission Emission distributions.
+ * @param tolerance Tolerance for convergence of the training algorithm
+ * (Baum-Welch).
*/
- HMM(const arma::mat& transition, const std::vector<Distribution>& emission);
+ HMM(const arma::mat& transition,
+ const std::vector<Distribution>& emission,
+ const double tolerance = 1e-5);
/**
* Train the model using the Baum-Welch algorithm, with only the given
@@ -123,6 +136,11 @@
* with labeled data first, and then continue to train the model using this
* overload of Train() with unlabeled data.
*
+ * The tolerance of the Baum-Welch algorithm can be set either in the
+ * constructor or with the Tolerance() method. When the absolute change in
+ * log-likelihood of the model between iterations is less than the tolerance,
+ * the Baum-Welch algorithm terminates.
+ *
* @note
* Train() can be called multiple times with different sequences; each time it
* is called, it uses the current parameters of the HMM as a starting point
@@ -247,6 +265,11 @@
//! Set the dimensionality of observations.
size_t& Dimensionality() { return dimensionality; }
+ //! Get the tolerance of the Baum-Welch algorithm.
+ double Tolerance() const { return tolerance; }
+ //! Modify the tolerance of the Baum-Welch algorithm.
+ double& Tolerance() { return tolerance; }
+
private:
// Helper functions.
@@ -287,6 +310,9 @@
//! Dimensionality of observations.
size_t dimensionality;
+
+ //! Tolerance of Baum-Welch algorithm.
+ double tolerance;
};
}; // namespace hmm
Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp 2013-04-15 21:06:58 UTC (rev 14905)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp 2013-04-15 21:16:37 UTC (rev 14906)
@@ -19,10 +19,13 @@
* given number of emission states.
*/
template<typename Distribution>
-HMM<Distribution>::HMM(const size_t states, const Distribution emissions) :
+HMM<Distribution>::HMM(const size_t states,
+ const Distribution emissions,
+ const double tolerance) :
transition(arma::ones<arma::mat>(states, states) / (double) states),
emission(states, /* default distribution */ emissions),
- dimensionality(emissions.Dimensionality())
+ dimensionality(emissions.Dimensionality()),
+ tolerance(tolerance)
{ /* nothing to do */ }
/**
@@ -31,9 +34,11 @@
*/
template<typename Distribution>
HMM<Distribution>::HMM(const arma::mat& transition,
- const std::vector<Distribution>& emission) :
+ const std::vector<Distribution>& emission,
+ const double tolerance) :
transition(transition),
- emission(emission)
+ emission(emission),
+ tolerance(tolerance)
{
// Set the dimensionality, if we can.
if (emission.size() > 0)
@@ -160,7 +165,7 @@
Log::Debug << "Iteration " << iter << ": log-likelihood " << loglik
<< std::endl;
- if (fabs(oldLoglik - loglik) < 1e-5)
+ if (std::abs(oldLoglik - loglik) < tolerance)
{
Log::Debug << "Converged after " << iter << " iterations." << std::endl;
break;
More information about the mlpack-svn mailing list