[mlpack-svn] r10576 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Dec 6 03:59:24 EST 2011


Author: rcurtin
Date: 2011-12-06 03:59:23 -0500 (Tue, 06 Dec 2011)
New Revision: 10576

Modified:
   mlpack/trunk/src/mlpack/tests/gmm_test.cpp
Log:
Test GMM::Random().


Modified: mlpack/trunk/src/mlpack/tests/gmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/gmm_test.cpp	2011-12-06 08:58:59 UTC (rev 10575)
+++ mlpack/trunk/src/mlpack/tests/gmm_test.cpp	2011-12-06 08:59:23 UTC (rev 10576)
@@ -396,4 +396,72 @@
           d3.Covariance()(row, col)), 0.50); // Big tolerance!  Lots of noise.
 }
 
+/**
+ * Make sure generating observations randomly works.  We draw 4000 points from
+ * a known two-component GMM with Random(), fit a fresh GMM to them with
+ * Estimate(), and require the fitted parameters to be close to the originals.
+ */
+BOOST_AUTO_TEST_CASE(GMMRandomTest)
+{
+  srand(time(NULL)); // Clock-based seed (assumes Random() uses rand()).
+
+  // Reference model: two Gaussians in two dimensions, weights 0.40 and 0.60.
+  GMM gmm(2, 2);
+  gmm.Weights() = arma::vec("0.40 0.60");
+
+  // N([2.25 3.10], [1.00 0.60; 0.60 0.89])
+  gmm.Means()[0] = arma::vec("2.25 3.10");
+  gmm.Covariances()[0] = arma::mat("1.00 0.60; 0.60 0.89");
+
+  // N([4.10 1.01], [1.00 0.70; 0.70 1.01])
+  gmm.Means()[1] = arma::vec("4.10 1.01");
+  gmm.Covariances()[1] = arma::mat("1.00 0.70; 0.70 1.01");
+
+  // Draw 4000 observations from the model, storing one per column.
+  arma::mat observations(2, 4000);
+  for (size_t i = 0; i < 4000; i++)
+    observations.col(i) = gmm.Random();
+
+  // A fresh GMM of the same size, fitted to the sampled observations.
+  GMM gmm2(2, 2);
+  gmm2.Estimate(observations);
+
+  // Now check the results.  Index the fitted components by ascending weight so
+  // that they line up with gmm's components (weights 0.40, then 0.60).
+  arma::uvec sortedIndices = sort_index(gmm2.Weights());
+
+  // Now check that the parameters are the same.  Tolerances (in percent) are
+  // kind of big because we only used 4000 observations.
+  BOOST_REQUIRE_CLOSE(gmm.Weights()[0], gmm2.Weights()[sortedIndices[0]], 7.0);
+  BOOST_REQUIRE_CLOSE(gmm.Weights()[1], gmm2.Weights()[sortedIndices[1]], 7.0);
+
+  BOOST_REQUIRE_CLOSE(gmm.Means()[0][0], gmm2.Means()[sortedIndices[0]][0],
+      6.5);
+  BOOST_REQUIRE_CLOSE(gmm.Means()[0][1], gmm2.Means()[sortedIndices[0]][1],
+      6.5);
+
+  BOOST_REQUIRE_CLOSE(gmm.Covariances()[0](0, 0),
+      gmm2.Covariances()[sortedIndices[0]](0, 0), 13.0);
+  BOOST_REQUIRE_CLOSE(gmm.Covariances()[0](0, 1),
+      gmm2.Covariances()[sortedIndices[0]](0, 1), 22.0);
+  BOOST_REQUIRE_CLOSE(gmm.Covariances()[0](1, 0),
+      gmm2.Covariances()[sortedIndices[0]](1, 0), 22.0);
+  BOOST_REQUIRE_CLOSE(gmm.Covariances()[0](1, 1),
+      gmm2.Covariances()[sortedIndices[0]](1, 1), 13.0);
+
+  BOOST_REQUIRE_CLOSE(gmm.Means()[1][0], gmm2.Means()[sortedIndices[1]][0],
+      6.5);
+  BOOST_REQUIRE_CLOSE(gmm.Means()[1][1], gmm2.Means()[sortedIndices[1]][1],
+      6.5);
+
+  BOOST_REQUIRE_CLOSE(gmm.Covariances()[1](0, 0),
+      gmm2.Covariances()[sortedIndices[1]](0, 0), 13.0);
+  BOOST_REQUIRE_CLOSE(gmm.Covariances()[1](0, 1),
+      gmm2.Covariances()[sortedIndices[1]](0, 1), 22.0);
+  BOOST_REQUIRE_CLOSE(gmm.Covariances()[1](1, 0),
+      gmm2.Covariances()[sortedIndices[1]](1, 0), 22.0);
+  BOOST_REQUIRE_CLOSE(gmm.Covariances()[1](1, 1),
+      gmm2.Covariances()[sortedIndices[1]](1, 1), 13.0);
+}
+
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-svn mailing list