[mlpack-svn] r10533 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Dec 3 19:08:58 EST 2011
Author: rcurtin
Date: 2011-12-03 19:08:58 -0500 (Sat, 03 Dec 2011)
New Revision: 10533
Modified:
mlpack/trunk/src/mlpack/tests/gmm_test.cpp
Log:
First simple test for GMM::Estimate(const arma::mat&, const arma::vec&).
Modified: mlpack/trunk/src/mlpack/tests/gmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/gmm_test.cpp 2011-12-04 00:08:47 UTC (rev 10532)
+++ mlpack/trunk/src/mlpack/tests/gmm_test.cpp 2011-12-04 00:08:58 UTC (rev 10533)
@@ -8,6 +8,7 @@
#include <mlpack/methods/gmm/gmm.hpp>
#include <mlpack/methods/gmm/phi.hpp>
+#include <mlpack/methods/hmm/distributions/gaussian_distribution.hpp>
#include <boost/test/unit_test.hpp>
@@ -145,7 +146,7 @@
// Now, train the model.
GMM gmm(1, 2);
- gmm.ExpectationMaximization(data);
+ gmm.Estimate(data);
arma::vec actual_mean = arma::mean(data, 1);
arma::mat actual_covar = ccov(data, 1 /* biased estimator */);
@@ -235,7 +236,7 @@
// Now train the model.
GMM gmm(gaussians, dims);
- gmm.ExpectationMaximization(data);
+ gmm.Estimate(data);
arma::uvec sort_ref = sort_index(weights);
arma::uvec sort_try = sort_index(gmm.Weights());
@@ -260,4 +261,41 @@
}
}
+/**
+ * Train a single-Gaussian mixture using the overload of Estimate() that also
+ * takes a probability for each observation.
+ */
+BOOST_AUTO_TEST_CASE(GMMEMTrainSingleGaussianWithProbability)
+{
+ srand(time(NULL));
+
+ // Generate observations from a known two-dimensional Gaussian distribution.
+ distribution::GaussianDistribution d("0.5 1.0", "1.0 0.3; 0.3 1.0");
+
+ // 20000 observations, each with a random probability.
+ arma::mat observations(2, 20000);
+ for (size_t i = 0; i < 20000; i++)
+ observations.col(i) = d.Random();
+ arma::vec probabilities;
+ probabilities.randu(20000); // One random probability per observation.
+
+ // Now train a model with one Gaussian in two dimensions.
+ GMM g(1, 2);
+
+ g.Estimate(observations, probabilities);
+
+ // Check that it is trained correctly. 5% tolerance because of random error
+ // present in observations.
+ BOOST_REQUIRE_CLOSE(g.Means()[0][0], 0.5, 5.0);
+ BOOST_REQUIRE_CLOSE(g.Means()[0][1], 1.0, 5.0);
+
+ // 7% tolerance on the diagonal entries, 10% on the smaller off-diagonal ones.
+ BOOST_REQUIRE_CLOSE(g.Covariances()[0](0, 0), 1.0, 7.0);
+ BOOST_REQUIRE_CLOSE(g.Covariances()[0](0, 1), 0.3, 10.0);
+ BOOST_REQUIRE_CLOSE(g.Covariances()[0](1, 0), 0.3, 10.0);
+ BOOST_REQUIRE_CLOSE(g.Covariances()[0](1, 1), 1.0, 7.0);
+
+ BOOST_REQUIRE_CLOSE(g.Weights()[0], 1.0, 1e-5); // Only one component.
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn
mailing list