[mlpack-git] master: Add another test to mean shift, which should cluster 4 Gaussians. (a96b462)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jun 17 17:02:30 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ee7c82dba945db7c5469485c61d626eb0a4629b0...98c0c483a3547c8f49cdfe38670a603bd29036a0

>---------------------------------------------------------------

commit a96b4629cb0fce4b3ab02af6b72f04d6c80c5011
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Jun 17 16:53:52 2015 -0400

    Add another test to mean shift, which should cluster 4 Gaussians.


>---------------------------------------------------------------

a96b4629cb0fce4b3ab02af6b72f04d6c80c5011
 src/mlpack/tests/mean_shift_test.cpp | 58 ++++++++++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/src/mlpack/tests/mean_shift_test.cpp b/src/mlpack/tests/mean_shift_test.cpp
index f5c071a..e777e15 100644
--- a/src/mlpack/tests/mean_shift_test.cpp
+++ b/src/mlpack/tests/mean_shift_test.cpp
@@ -12,6 +12,7 @@
 
 using namespace mlpack;
 using namespace mlpack::meanshift;
+using namespace mlpack::distribution;
 
 BOOST_AUTO_TEST_SUITE(MeanShiftTest);
 
@@ -85,4 +86,61 @@ BOOST_AUTO_TEST_CASE(MeanShiftSimpleTest) {
 
 }
 
+// Draw 1000 samples from each of four 3-d Gaussians and require that mean
+// shift returns exactly four centroids, each near a different Gaussian's mean.
+BOOST_AUTO_TEST_CASE(GaussianClustering)
+{
+  math::RandomSeed(std::time(NULL)); // NOTE(review): time-based seed.
+  GaussianDistribution g1("0.0 0.0 0.0", arma::eye<arma::mat>(3, 3));
+  GaussianDistribution g2("5.0 5.0 5.0", 2 * arma::eye<arma::mat>(3, 3));
+  GaussianDistribution g3("-3.0 3.0 -1.0", arma::eye<arma::mat>(3, 3));
+  GaussianDistribution g4("6.0 -2.0 -2.0", 3 * arma::eye<arma::mat>(3, 3));
+
+  arma::mat dataset(3, 4000); // One sample per column, 1000 per Gaussian.
+  for (size_t i = 0; i < 1000; ++i)
+    dataset.col(i) = g1.Random();
+  for (size_t i = 1000; i < 2000; ++i)
+    dataset.col(i) = g2.Random();
+  for (size_t i = 2000; i < 3000; ++i)
+    dataset.col(i) = g3.Random();
+  for (size_t i = 3000; i < 4000; ++i)
+    dataset.col(i) = g4.Random();
+
+  // Run mean shift on the generated dataset with a pre-set radius of 2.9.
+  MeanShift<> meanShift(2.9);
+
+  arma::Col<size_t> assignments;
+  arma::mat centroids;
+  meanShift.Cluster(dataset, assignments, centroids);
+
+  BOOST_REQUIRE_EQUAL(centroids.n_cols, 4); // One centroid per Gaussian.
+  BOOST_REQUIRE_EQUAL(centroids.n_rows, 3); // Same dimension as the data.
+
+  std::cout << centroids.t(); // NOTE(review): looks like leftover debug output.
+
+  // For every centroid, find the nearest Gaussian mean and record its index.
+  arma::vec centroidDistances(4); // Distances from one centroid to the 4 means.
+  arma::uvec minIndices(4); // Index of the nearest mean, per centroid.
+  for (size_t i = 0; i < 4; ++i)
+  {
+    centroidDistances(0) = metric::EuclideanDistance::Evaluate(g1.Mean(),
+        centroids.col(i));
+    centroidDistances(1) = metric::EuclideanDistance::Evaluate(g2.Mean(),
+        centroids.col(i));
+    centroidDistances(2) = metric::EuclideanDistance::Evaluate(g3.Mean(),
+        centroids.col(i));
+    centroidDistances(3) = metric::EuclideanDistance::Evaluate(g4.Mean(),
+        centroids.col(i));
+
+    // The nearest mean must be close; min() stores its index in minIndices[i].
+    const double minVal = centroidDistances.min(minIndices[i]);
+    BOOST_REQUIRE_SMALL(minVal, 0.65); // A fairly loose tolerance.
+  }
+
+  // No two centroids may share the same nearest Gaussian.
+  for (size_t i = 0; i < 4; ++i)
+    for (size_t j = i + 1; j < 4; ++j)
+      BOOST_REQUIRE_NE(minIndices[i], minIndices[j]);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list