[mlpack-svn] r12066 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Mar 27 12:35:52 EDT 2012


Author: rcurtin
Date: 2012-03-27 12:35:52 -0400 (Tue, 27 Mar 2012)
New Revision: 12066

Modified:
   mlpack/trunk/src/mlpack/tests/gmm_test.cpp
Log:
Test the new Classify() method.


Modified: mlpack/trunk/src/mlpack/tests/gmm_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/gmm_test.cpp	2012-03-27 16:35:37 UTC (rev 12065)
+++ mlpack/trunk/src/mlpack/tests/gmm_test.cpp	2012-03-27 16:35:52 UTC (rev 12066)
@@ -513,4 +513,55 @@
   BOOST_REQUIRE_CLOSE(gmm.Covariances()[1](1, 1),
       gmm2.Covariances()[sortedIndices[1]](1, 1), 13.0);
 }
+
+/**
+ * Test that GMM::Classify() assigns each observation a component label.
+ */
+BOOST_AUTO_TEST_CASE(GMMClassifyTest)
+{
+  // First create a two-dimensional GMM with three components.
+  GMM gmm(3, 2);
+  gmm.Means()[0] = "0 0";
+  gmm.Means()[1] = "1 3";
+  gmm.Means()[2] = "-2 -2";
+  gmm.Covariances()[0] = "1 0; 0 1";
+  gmm.Covariances()[1] = "3 2; 2 3";
+  gmm.Covariances()[2] = "2.2 1.4; 1.4 5.1";
+  gmm.Weights() = "0.6 0.25 0.15";
+
+  arma::mat observations = arma::trans(arma::mat(
+    " 0  0;"
+    " 0  1;"
+    " 0  2;"
+    " 1 -2;"
+    " 2 -2;"
+    "-2  0;"
+    " 5  5;"
+    "-2 -2;"
+    " 3  3;"
+    "25 25;"
+    "-1 -1;"
+    "-3 -3;"
+    "-5  1"));
+
+  arma::Col<size_t> classes;
+  gmm.Classify(observations, classes);
+  BOOST_REQUIRE_EQUAL(classes.n_elem, observations.n_cols);
+
+  // Hand-computed labels; the size check above keeps operator[] in bounds.
+  BOOST_REQUIRE_EQUAL(classes[ 0], 0u);
+  BOOST_REQUIRE_EQUAL(classes[ 1], 0u);
+  BOOST_REQUIRE_EQUAL(classes[ 2], 1u);
+  BOOST_REQUIRE_EQUAL(classes[ 3], 0u);
+  BOOST_REQUIRE_EQUAL(classes[ 4], 0u);
+  BOOST_REQUIRE_EQUAL(classes[ 5], 0u);
+  BOOST_REQUIRE_EQUAL(classes[ 6], 1u);
+  BOOST_REQUIRE_EQUAL(classes[ 7], 2u);
+  BOOST_REQUIRE_EQUAL(classes[ 8], 1u);
+  BOOST_REQUIRE_EQUAL(classes[ 9], 1u);
+  BOOST_REQUIRE_EQUAL(classes[10], 0u);
+  BOOST_REQUIRE_EQUAL(classes[11], 2u);
+  BOOST_REQUIRE_EQUAL(classes[12], 2u);
+}
+
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-svn mailing list