[mlpack-git] master: Add another test to mean shift, which should cluster 4 Gaussians. (a96b462)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jun 17 17:02:30 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ee7c82dba945db7c5469485c61d626eb0a4629b0...98c0c483a3547c8f49cdfe38670a603bd29036a0
>---------------------------------------------------------------
commit a96b4629cb0fce4b3ab02af6b72f04d6c80c5011
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Jun 17 16:53:52 2015 -0400
Add another test to mean shift, which should cluster 4 Gaussians.
>---------------------------------------------------------------
a96b4629cb0fce4b3ab02af6b72f04d6c80c5011
src/mlpack/tests/mean_shift_test.cpp | 58 ++++++++++++++++++++++++++++++++++++
1 file changed, 58 insertions(+)
diff --git a/src/mlpack/tests/mean_shift_test.cpp b/src/mlpack/tests/mean_shift_test.cpp
index f5c071a..e777e15 100644
--- a/src/mlpack/tests/mean_shift_test.cpp
+++ b/src/mlpack/tests/mean_shift_test.cpp
@@ -12,6 +12,7 @@
using namespace mlpack;
using namespace mlpack::meanshift;
+using namespace mlpack::distribution;
BOOST_AUTO_TEST_SUITE(MeanShiftTest);
@@ -85,4 +86,61 @@ BOOST_AUTO_TEST_CASE(MeanShiftSimpleTest) {
}
+// Generate samples from four 3-dimensional Gaussians (1000 points each), and
+// make sure mean shift recovers four centroids, one near each Gaussian's mean.
+BOOST_AUTO_TEST_CASE(GaussianClustering)
+{
+  // Seeded from the wall clock, so the sampled dataset can differ on every run.
+  math::RandomSeed(std::time(NULL));
+  GaussianDistribution g1("0.0 0.0 0.0", arma::eye<arma::mat>(3, 3));
+  GaussianDistribution g2("5.0 5.0 5.0", 2 * arma::eye<arma::mat>(3, 3));
+  GaussianDistribution g3("-3.0 3.0 -1.0", arma::eye<arma::mat>(3, 3));
+  GaussianDistribution g4("6.0 -2.0 -2.0", 3 * arma::eye<arma::mat>(3, 3));
+
+  arma::mat dataset(3, 4000);
+  for (size_t i = 0; i < 1000; ++i)
+    dataset.col(i) = g1.Random();
+  for (size_t i = 1000; i < 2000; ++i)
+    dataset.col(i) = g2.Random();
+  for (size_t i = 2000; i < 3000; ++i)
+    dataset.col(i) = g3.Random();
+  for (size_t i = 3000; i < 4000; ++i)
+    dataset.col(i) = g4.Random();
+
+  // Now that the dataset is generated, run mean shift.  Pre-set radius.
+  MeanShift<> meanShift(2.9);
+
+  arma::Col<size_t> assignments;
+  arma::mat centroids;
+  meanShift.Cluster(dataset, assignments, centroids);
+
+  // REQUIRE (not CHECK): the loops below index columns 0..3 of the centroids.
+  BOOST_REQUIRE_EQUAL(centroids.n_cols, 4);
+  BOOST_REQUIRE_EQUAL(centroids.n_rows, 3);
+
+  // Check that each centroid is close to only one mean.
+  arma::vec centroidDistances(4);
+  arma::uvec minIndices(4);
+  for (size_t i = 0; i < 4; ++i)
+  {
+    centroidDistances(0) = metric::EuclideanDistance::Evaluate(g1.Mean(),
+        centroids.col(i));
+    centroidDistances(1) = metric::EuclideanDistance::Evaluate(g2.Mean(),
+        centroids.col(i));
+    centroidDistances(2) = metric::EuclideanDistance::Evaluate(g3.Mean(),
+        centroids.col(i));
+    centroidDistances(3) = metric::EuclideanDistance::Evaluate(g4.Mean(),
+        centroids.col(i));
+
+    // Are we near the mean of a Gaussian?  min() also records which one it is.
+    const double minVal = centroidDistances.min(minIndices[i]);
+    BOOST_REQUIRE_SMALL(minVal, 0.65); // A decent amount of tolerance...
+  }
+
+  // Ensure each centroid corresponds to a different Gaussian.
+  for (size_t i = 0; i < 4; ++i)
+    for (size_t j = i + 1; j < 4; ++j)
+      BOOST_REQUIRE_NE(minIndices[i], minIndices[j]);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list