[mlpack-svn] r10576 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Dec 6 03:59:24 EST 2011
Author: rcurtin
Date: 2011-12-06 03:59:23 -0500 (Tue, 06 Dec 2011)
New Revision: 10576
Modified:
mlpack/trunk/src/mlpack/tests/gmm_test.cpp
Log:
Test GMM::Random().
Modified: mlpack/trunk/src/mlpack/tests/gmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/gmm_test.cpp 2011-12-06 08:58:59 UTC (rev 10575)
+++ mlpack/trunk/src/mlpack/tests/gmm_test.cpp 2011-12-06 08:59:23 UTC (rev 10576)
@@ -396,4 +396,72 @@
d3.Covariance()(row, col)), 0.50); // Big tolerance! Lots of noise.
}
+/**
+ * Make sure generating observations randomly works.  We draw a large number of
+ * observations from a known two-component GMM with GMM::Random(), re-train a
+ * fresh GMM on those observations with GMM::Estimate(), and check that the
+ * recovered parameters are close to the true ones.
+ *
+ * Tolerances are deliberately loose because the estimate is built from a
+ * finite random sample.  The RNG seed is time-based, so it is reported with
+ * BOOST_TEST_MESSAGE to make any failing run reproducible.
+ */
+BOOST_AUTO_TEST_CASE(GMMRandomTest)
+{
+  // Seed from the clock, but log the seed so a failing run can be replayed.
+  const unsigned int seed = static_cast<unsigned int>(time(NULL));
+  BOOST_TEST_MESSAGE("GMMRandomTest: random seed is " << seed);
+  srand(seed);
+
+  // Simple GMM distribution.
+  GMM gmm(2, 2);
+  gmm.Weights() = arma::vec("0.40 0.60");
+
+  // N([2.25 3.10], [1.00 0.60; 0.60 0.89])
+  gmm.Means()[0] = arma::vec("2.25 3.10");
+  gmm.Covariances()[0] = arma::mat("1.00 0.60; 0.60 0.89");
+
+  // N([4.10 1.01], [1.00 0.70; 0.70 1.01])
+  gmm.Means()[1] = arma::vec("4.10 1.01");
+  gmm.Covariances()[1] = arma::mat("1.00 0.70; 0.70 1.01");
+
+  // Now generate a bunch of observations (one per column).
+  const size_t numObservations = 4000;
+  arma::mat observations(2, numObservations);
+  for (size_t i = 0; i < numObservations; i++)
+    observations.col(i) = gmm.Random();
+
+  // A new one which we'll train.
+  GMM gmm2(2, 2);
+  gmm2.Estimate(observations);
+
+  // The trained components may come out in any order, so pair components up
+  // by sorting BOTH models' weights in ascending order.  The true weights are
+  // distinct (0.40 vs. 0.60), so this pairing is unambiguous and does not rely
+  // on the reference weights happening to be listed in ascending order.
+  const arma::uvec trueIndices = arma::sort_index(gmm.Weights());
+  const arma::uvec sortedIndices = arma::sort_index(gmm2.Weights());
+
+  // Tolerances (percent) are kind of big because the estimate comes from only
+  // numObservations random points; off-diagonal covariance entries are the
+  // noisiest, so they get the loosest tolerance.
+  for (size_t i = 0; i < 2; i++)
+  {
+    const size_t t = trueIndices[i];   // Component index in the true model.
+    const size_t e = sortedIndices[i]; // Matching index in the trained model.
+
+    BOOST_REQUIRE_CLOSE(gmm.Weights()[t], gmm2.Weights()[e], 7.0);
+
+    for (size_t d = 0; d < 2; d++)
+      BOOST_REQUIRE_CLOSE(gmm.Means()[t][d], gmm2.Means()[e][d], 6.5);
+
+    for (size_t row = 0; row < 2; row++)
+      for (size_t col = 0; col < 2; col++)
+        BOOST_REQUIRE_CLOSE(gmm.Covariances()[t](row, col),
+            gmm2.Covariances()[e](row, col), (row == col) ? 13.0 : 22.0);
+  }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn mailing list