[mlpack-svn] r10533 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Dec 3 19:08:58 EST 2011


Author: rcurtin
Date: 2011-12-03 19:08:58 -0500 (Sat, 03 Dec 2011)
New Revision: 10533

Modified:
   mlpack/trunk/src/mlpack/tests/gmm_test.cpp
Log:
First simple test for GMM::Estimate(const arma::mat&, const arma::vec&).


Modified: mlpack/trunk/src/mlpack/tests/gmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/gmm_test.cpp	2011-12-04 00:08:47 UTC (rev 10532)
+++ mlpack/trunk/src/mlpack/tests/gmm_test.cpp	2011-12-04 00:08:58 UTC (rev 10533)
@@ -8,6 +8,7 @@
 
 #include <mlpack/methods/gmm/gmm.hpp>
 #include <mlpack/methods/gmm/phi.hpp>
+#include <mlpack/methods/hmm/distributions/gaussian_distribution.hpp>
 
 #include <boost/test/unit_test.hpp>
 
@@ -145,7 +146,7 @@
 
     // Now, train the model.
     GMM gmm(1, 2);
-    gmm.ExpectationMaximization(data);
+    gmm.Estimate(data);
 
     arma::vec actual_mean = arma::mean(data, 1);
     arma::mat actual_covar = ccov(data, 1 /* biased estimator */);
@@ -235,7 +236,7 @@
 
   // Now train the model.
   GMM gmm(gaussians, dims);
-  gmm.ExpectationMaximization(data);
+  gmm.Estimate(data);
 
   arma::uvec sort_ref = sort_index(weights);
   arma::uvec sort_try = sort_index(gmm.Weights());
@@ -260,4 +261,41 @@
   }
 }
 
+/**
+ * Train a single-Gaussian mixture using the overload of Estimate() that takes
+ * a probability for each observation, then check the fitted parameters.
+ */
+BOOST_AUTO_TEST_CASE(GMMEMTrainSingleGaussianWithProbability)
+{
+  srand(time(NULL));
+
+  // Generate observations from a Gaussian distribution.
+  distribution::GaussianDistribution d("0.5 1.0", "1.0 0.3; 0.3 1.0");
+
+  // 20000 observations, each with a random probability drawn by randu().
+  const size_t numObservations = 20000;
+  arma::mat observations(2, numObservations);
+  for (size_t i = 0; i < numObservations; i++)
+    observations.col(i) = d.Random();
+  arma::vec probabilities;
+  probabilities.randu(numObservations); // Random probabilities.
+
+  // Now train the model.
+  GMM g(1, 2);
+  g.Estimate(observations, probabilities);
+
+  // Check that it is trained correctly.  5% tolerance because of random error
+  // present in observations.
+  BOOST_REQUIRE_CLOSE(g.Means()[0][0], 0.5, 5.0);
+  BOOST_REQUIRE_CLOSE(g.Means()[0][1], 1.0, 5.0);
+
+  // 7% tolerance on the large numbers, 10% on the smaller numbers.
+  BOOST_REQUIRE_CLOSE(g.Covariances()[0](0, 0), 1.0, 7.0);
+  BOOST_REQUIRE_CLOSE(g.Covariances()[0](0, 1), 0.3, 10.0);
+  BOOST_REQUIRE_CLOSE(g.Covariances()[0](1, 0), 0.3, 10.0);
+  BOOST_REQUIRE_CLOSE(g.Covariances()[0](1, 1), 1.0, 7.0);
+  // With a single component, all of the mixture weight belongs to it.
+  BOOST_REQUIRE_CLOSE(g.Weights()[0], 1.0, 1e-5);
+}
+
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-svn mailing list