[mlpack-git] master: Fix changed API. (31d0a65)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:55:28 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit 31d0a6537453a5f866d19ac81550f1a1b02d95d7
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Jul 30 16:58:23 2014 +0000
Fix changed API.
>---------------------------------------------------------------
31d0a6537453a5f866d19ac81550f1a1b02d95d7
src/mlpack/tests/kmeans_test.cpp | 34 +++++++++++++++++++++++++++++-----
1 file changed, 29 insertions(+), 5 deletions(-)
diff --git a/src/mlpack/tests/kmeans_test.cpp b/src/mlpack/tests/kmeans_test.cpp
index a79630a..0853a0d 100644
--- a/src/mlpack/tests/kmeans_test.cpp
+++ b/src/mlpack/tests/kmeans_test.cpp
@@ -137,8 +137,9 @@ BOOST_AUTO_TEST_CASE(AllowEmptyClusterTest)
arma::Col<size_t> countsOld = counts;
// Make sure the method doesn't modify any points.
+ metric::LMetric<2, true> metric;
BOOST_REQUIRE_EQUAL(AllowEmptyClusters::EmptyCluster(kMeansData, 2, centroids,
- counts, assignments), 0);
+ counts, metric), 0);
// Make sure no assignments were changed.
for (size_t i = 0; i < assignments.n_elem; i++)
@@ -164,16 +165,39 @@ BOOST_AUTO_TEST_CASE(MaxVarianceNewClusterTest)
arma::mat centroids(2, 3);
centroids.col(0) = (1.0 / 3.0) * (data.col(0) + data.col(1) + data.col(2));
centroids.col(1) = 0.5 * (data.col(3) + data.col(4));
- centroids(0, 2) = 0;
- centroids(1, 2) = 0;
+ centroids(0, 2) = DBL_MAX;
+ centroids(1, 2) = DBL_MAX;
arma::Col<size_t> counts("3 2 0");
+ metric::LMetric<2, true> metric;
+
// This should only change one point.
BOOST_REQUIRE_EQUAL(MaxVarianceNewCluster::EmptyCluster(data, 2, centroids,
- counts, assignments), 1);
+ counts, metric), 1);
+
+ // Add the variance of each point's distance away from the cluster. I think
+ // this is the sensible thing to do.
+ for (size_t i = 0; i < data.n_cols; ++i)
+ {
+ // Find the closest centroid to this point.
+ double minDistance = std::numeric_limits<double>::infinity();
+ size_t closestCluster = centroids.n_cols; // Invalid value.
+
+ for (size_t j = 0; j < centroids.n_cols; j++)
+ {
+ const double distance = metric.Evaluate(data.col(i), centroids.col(j));
+
+ if (distance < minDistance)
+ {
+ minDistance = distance;
+ closestCluster = j;
+ }
+ }
+
+ assignments[i] = closestCluster;
+ }
- // Ensure that the cluster assignments are right.
BOOST_REQUIRE_EQUAL(assignments[0], 0);
BOOST_REQUIRE_EQUAL(assignments[1], 0);
BOOST_REQUIRE_EQUAL(assignments[2], 2);
More information about the mlpack-git
mailing list