[mlpack-svn] r16245 - in mlpack/trunk/src/mlpack: methods/hmm tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sun Feb 9 22:50:54 EST 2014
Author: michaelfox99
Date: Sun Feb 9 22:50:53 2014
New Revision: 16245
Log:
fix for HMMTest/SimpleDiscreteHMMTestViterbi
Modified:
mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
mlpack/trunk/src/mlpack/tests/hmm_test.cpp
Modified: mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/hmm/hmm_impl.hpp Sun Feb 9 22:50:53 2014
@@ -370,14 +370,13 @@
logStateProb.col(0).zeros();
for (size_t state = 0; state < transition.n_rows; state++)
{
- logStateProb[state] = log(transition(state, 0) *
+ logStateProb[state] = log(transition.unsafe_col(state).max() *
emission[state].Probability(dataSeq.unsafe_col(0)));
stateSeqBack[state] = state;
}
// Store the best first state.
arma::uword index;
- logStateProb.unsafe_col(0).max(index);
for (size_t t = 1; t < dataSeq.n_cols; t++)
{
// Assemble the state probability for this element.
@@ -386,10 +385,9 @@
for (size_t j = 0; j < transition.n_rows; j++)
{
arma::vec prob = logStateProb.col(t - 1) + logTrans.col(j);
- logStateProb(j, t) = prob.max() +
+ logStateProb(j, t) = prob.max(index) +
log(emission[j].Probability(dataSeq.unsafe_col(t)));
- prob.max(index);
- stateSeqBack(j, t) = index;
+ stateSeqBack(j, t) = index;
}
}
// Backtrack to find most probable state sequence
Modified: mlpack/trunk/src/mlpack/tests/hmm_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/hmm_test.cpp (original)
+++ mlpack/trunk/src/mlpack/tests/hmm_test.cpp Sun Feb 9 22:50:53 2014
@@ -2,6 +2,21 @@
* @file hmm_test.cpp
*
* Test file for HMMs.
+ *
+ * This file is part of MLPACK 1.0.7.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
*/
#include <mlpack/core.hpp>
#include <mlpack/methods/hmm/hmm.hpp>
@@ -35,8 +50,8 @@
// [0.1 0.8]] no umbrella
arma::mat transition("0.7 0.3; 0.3 0.7");
std::vector<DiscreteDistribution> emission(2);
- emission[0] = DiscreteDistribution("0.9 0.2");
- emission[1] = DiscreteDistribution("0.1 0.8");
+ emission[0] = DiscreteDistribution("0.9 0.1");
+ emission[1] = DiscreteDistribution("0.2 0.8");
HMM<DiscreteDistribution> hmm(transition, emission);
@@ -46,7 +61,7 @@
arma::mat observation = "0 0 1 0 0";
arma::Col<size_t> states;
hmm.Predict(observation, states);
-
+
// Check each state.
BOOST_REQUIRE_EQUAL(states[0], 0); // Rain.
BOOST_REQUIRE_EQUAL(states[1], 0); // Rain.
@@ -273,15 +288,14 @@
BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 1), 0.5, 2.5);
BOOST_REQUIRE_CLOSE(hmm.Transition()(1, 1), 0.5, 2.5);
- // Widened to 3% tolerance.
- BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("0"), 0.4, 3.0);
- BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("1"), 0.6, 3.0);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Probability("2"), 3.0);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Probability("3"), 3.0);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Probability("0"), 3.0);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Probability("1"), 3.0);
- BOOST_REQUIRE_CLOSE(hmm.Emission()[1].Probability("2"), 0.2, 3.0);
- BOOST_REQUIRE_CLOSE(hmm.Emission()[1].Probability("3"), 0.8, 3.0);
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("0"), 0.4, 2.5);
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("1"), 0.6, 2.5);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Probability("2"), 2.5);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Probability("3"), 2.5);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Probability("0"), 2.5);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Probability("1"), 2.5);
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[1].Probability("2"), 0.2, 2.5);
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[1].Probability("3"), 0.8, 2.5);
}
BOOST_AUTO_TEST_CASE(DiscreteHMMLabeledTrainTest)
@@ -351,11 +365,11 @@
// We can't use % tolerance here because percent error increases as the actual
// value gets very small. So, instead, we just ensure that every value is no
- // more than 0.015 away from the actual value.
+ // more than 0.009 away from the actual value.
for (size_t row = 0; row < hmm.Transition().n_rows; row++)
for (size_t col = 0; col < hmm.Transition().n_cols; col++)
BOOST_REQUIRE_SMALL(hmm.Transition()(row, col) - transition(row, col),
- 0.015);
+ 0.009);
for (size_t col = 0; col < hmm.Emission().size(); col++)
{
@@ -365,7 +379,7 @@
arma::vec obs(1);
obs[0] = row;
BOOST_REQUIRE_SMALL(hmm.Emission()[col].Probability(obs) -
- emission[col].Probability(obs), 0.015);
+ emission[col].Probability(obs), 0.009);
}
}
}
@@ -709,7 +723,7 @@
for (size_t row = 0; row < 3; row++)
for (size_t col = 0; col < 3; col++)
BOOST_REQUIRE_SMALL(hmm.Transition()(row, col) - hmm2.Transition()(row,
- col), 0.032);
+ col), 0.03);
// Check that each Gaussian is the same.
for (size_t em = 0; em < 3; em++)
More information about the mlpack-svn
mailing list