[mlpack-git] master: only consider points within radius (6e1e8a9)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:28 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4
>---------------------------------------------------------------
commit 6e1e8a91127408f7238e7e25112e6926e164f394
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date: Sat Jan 31 20:38:34 2015 +0800
only consider points within radius
>---------------------------------------------------------------
6e1e8a91127408f7238e7e25112e6926e164f394
src/mlpack/core/kernels/gaussian_kernel.hpp | 4 ++
src/mlpack/methods/mean_shift/mean_shift.hpp | 55 ++++++----------
src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 79 +++++++++++++----------
src/mlpack/methods/mean_shift/mean_shift_main.cpp | 16 ++---
src/mlpack/tests/mean_shift_test.cpp | 19 ------
5 files changed, 74 insertions(+), 99 deletions(-)
diff --git a/src/mlpack/core/kernels/gaussian_kernel.hpp b/src/mlpack/core/kernels/gaussian_kernel.hpp
index 9e52638..22afb5e 100644
--- a/src/mlpack/core/kernels/gaussian_kernel.hpp
+++ b/src/mlpack/core/kernels/gaussian_kernel.hpp
@@ -75,6 +75,10 @@ class GaussianKernel
return exp(gamma * std::pow(t, 2.0));
}
+ double Gradient(const double t) const {
+ return gamma * exp(gamma * std::pow(t, 2.0));
+ }
+
/**
* Obtain the normalization constant of the Gaussian kernel.
*
diff --git a/src/mlpack/methods/mean_shift/mean_shift.hpp b/src/mlpack/methods/mean_shift/mean_shift.hpp
index 6a0c809..c45a759 100644
--- a/src/mlpack/methods/mean_shift/mean_shift.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift.hpp
@@ -28,7 +28,7 @@ namespace meanshift /** Mean Shift clustering. */ {
* arma::Col<size_t> assignments; // Cluster assignments.
* arma::mat centroids; // Cluster centroids.
*
- * MeanShift<arma::mat, kernel::GaussianKernel> meanShift();
+ * MeanShift<> meanShift();
* meanShift.Cluster(dataset, assignments, centroids);
* @endcode
*
@@ -36,7 +36,6 @@ namespace meanshift /** Mean Shift clustering. */ {
*/
template<typename KernelType = kernel::GaussianKernel,
- typename MetricType = metric::EuclideanDistance,
typename MatType = arma::mat>
class MeanShift
{
@@ -45,25 +44,21 @@ class MeanShift
* Create a Mean Shift object and set the parameters which Mean Shift
* will be run with.
*
- * @param duplicateThresh If distance of two centroids is less than it, one will be removed. If this value is negative, an estimation will be given when clustering.
+ * @param radius If distance of two centroids is less than it, one will be removed. If this value isn't positive, an estimation will be given when clustering.
* @param maxIterations Maximum number of iterations allowed before giving up
- * @param stopThresh If the 2-norm of the mean shift vector is less than stopThresh,
* iterations will terminate.
* @param kernel Optional KernelType object.
- * @param metric Optional the metric to calculate distance.
*/
- MeanShift(const double duplicateThresh = -1,
+ MeanShift(const double radius = 0,
const size_t maxIterations = 1000,
- const double stopThresh = 1e-3,
- const KernelType kernel = KernelType(),
- const MetricType metric = MetricType());
+ const KernelType kernel = KernelType());
/**
- * Give an estimation of duplicate thresh based on given dataset.
+ * Give an estimation of radius based on given dataset.
* @param data Dataset for estimation.
*/
- double estimateDuplicateThresh(const MatType& data);
+ double estimateRadius(const MatType& data);
/**
* Perform Mean Shift clustering on the data, returning a list of cluster
@@ -83,46 +78,34 @@ class MeanShift
//! Set the maximum number of iterations.
size_t& MaxIterations() { return maxIterations; }
- //! Get the stop thresh.
- double StopThresh() const { return stopThresh; }
- //! Set the stop thresh.
- double& StopThresh() { return stopThresh; }
+ //! Get the radius.
+ double Radius() const { return radius; }
+ //! Set the radius.
+ void Radius(double radius);
//! Get the kernel.
const KernelType& Kernel() const { return kernel; }
//! Modify the kernel.
KernelType& Kernel() { return kernel; }
- //! Get the metric.
- const MetricType& Metric() const { return metric; }
- //! Modify the metric.
- MetricType& Metric() { return metric; }
-
- //! Get the duplicate thresh.
- double DuplicateThresh() const { return duplicateThresh; }
- //! Set the duplicate thresh.
- double& DuplicateThresh() { return duplicateThresh; }
-
private:
- // If distance of two centroids is less than duplicateThresh, one will be removed.
- double duplicateThresh;
+ /**
+ * If distance of two centroids is less than radius, one will be removed.
+ * Points with distance to current centroid less than radius will be used
+ * to calculate new centroid.
+ */
+ double radius;
+
+ // By storing radius * radius, we can speed up a little.
+ double squaredRadius;
//! Maximum number of iterations before giving up.
size_t maxIterations;
- /**
- * If the 2-norm of the mean shift vector is less than stopThresh,
- * iterations will terminate.
- */
- double stopThresh;
-
//! Instantiated kernel.
KernelType kernel;
- //! Instantiated metric.
- MetricType metric;
-
};
}; // namespace meanshift
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index 09a54a5..86fb811 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -20,40 +20,40 @@ namespace meanshift {
* Construct the Mean Shift object.
*/
template<typename KernelType,
- typename MetricType,
typename MatType>
MeanShift<
KernelType,
- MetricType,
MatType>::
-MeanShift(const double duplicateThresh,
+MeanShift(const double radius,
const size_t maxIterations,
- const double stopThresh,
- const KernelType kernel,
- const MetricType metric) :
- duplicateThresh(duplicateThresh),
+ const KernelType kernel) :
maxIterations(maxIterations),
- stopThresh(stopThresh),
- kernel(kernel),
- metric(metric)
+ kernel(kernel)
{
- // Nothing to do.
+ Radius(radius);
}
+template<typename KernelType,
+ typename MatType>
+void MeanShift<
+ KernelType,
+ MatType>::
+Radius(double radius) {
+ this->radius = radius;
+ squaredRadius = radius * radius;
+}
-// Estimate duplicate thresh based on given dataset.
+// Estimate radius based on given dataset.
template<typename KernelType,
- typename MetricType,
typename MatType>
double MeanShift<
KernelType,
- MetricType,
MatType>::
-estimateDuplicateThresh(const MatType &data) {
+estimateRadius(const MatType &data) {
neighbor::NeighborSearch<
neighbor::NearestNeighborSort,
- MetricType,
+ metric::EuclideanDistance,
tree::BinarySpaceTree<bound::HRectBound<2>,
neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> >
> neighborSearch(data);
@@ -76,27 +76,24 @@ estimateDuplicateThresh(const MatType &data) {
}
- /**
- * Perform Mean Shift clustering on the data, returning a list of cluster
- * assignments and centroids.
- */
+/**
+ * Perform Mean Shift clustering on the data, returning a list of cluster
+ * assignments and centroids.
+ */
template<typename KernelType,
- typename MetricType,
typename MatType>
inline void MeanShift<
KernelType,
- MetricType,
MatType>::
Cluster(const MatType& data,
arma::Col<size_t>& assignments,
arma::mat& centroids) {
- if (duplicateThresh < 0) {
- // An invalid duplicate thresh is given, an estimation is needed.
- duplicateThresh = estimateDuplicateThresh(data);
+ if (radius <= 0) {
+ // An invalid radius is given, an estimation is needed.
+ Radius(estimateRadius(data));
}
-
// all centroids before remove duplicate ones.
arma::mat allCentroids(data.n_rows, data.n_cols);
assignments.set_size(data.n_cols);
@@ -110,19 +107,30 @@ Cluster(const MatType& data,
for (size_t completedIterations = 0; completedIterations < maxIterations;
completedIterations++) {
- // new centroid
+ // to store new centroid
arma::Col<double> newCentroid = arma::zeros(data.n_rows, 1);
double sumWeight = 0;
+
+ // Go through all the points
for (size_t j = 0; j < data.n_cols; ++j) {
- // calc weight for each point
- double weight = kernel.Evaluate(metric.Evaluate(allCentroids.col(i),
- data.col(j)));
- sumWeight += weight;
+ // Calculate the distance between old centroid and current point.
+ double squaredDist = metric::SquaredEuclideanDistance::
+ Evaluate(allCentroids.col(i), data.col(j));
+
+ // If current point is near the old centroid
+ if (squaredDist < squaredRadius) {
+
+ // calculate weight for current point
+ double weight = kernel.Gradient(squaredDist / squaredRadius);
- // update new centroid.
- newCentroid += weight * data.col(j);
+ sumWeight += weight;
+
+ // update new centroid.
+ newCentroid += weight * data.col(j);
+
+ }
}
@@ -134,13 +142,14 @@ Cluster(const MatType& data,
// update the centroid.
allCentroids.col(i) = newCentroid;
- if (arma::norm(mhVector, 2) < stopThresh) {
+ // If the 2-norm of mean shift vector is small enough, it has converged.
+ if (arma::norm(mhVector, 2) < 1e-3 * radius) {
// Determine if the new centroid is duplicate with old ones.
bool isDuplicated = false;
for (size_t k = 0; k < centroids.n_cols; ++k) {
arma::Col<double> delta = allCentroids.col(i) - centroids.col(k);
- if (norm(delta, 2) < duplicateThresh) {
+ if (norm(delta, 2) < radius) {
isDuplicated = true;
assignments(i) = k;
break;
diff --git a/src/mlpack/methods/mean_shift/mean_shift_main.cpp b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
index d25970f..c53d29f 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_main.cpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
@@ -35,23 +35,22 @@ PARAM_STRING("centroid_file", "If specified, the centroids of each cluster will"
// Mean Shift configuration options.
PARAM_INT("max_iterations", "Maximum number of iterations before Mean Shift "
"terminates.", "m", 1000);
-PARAM_DOUBLE("stopThresh", "If the 2-norm of the mean shift vector "
- "is less than stopThresh, iterations will terminate. ", "s", 1e-3);
+
PARAM_DOUBLE("bandwidth", "bandwidth of Gaussian kernel ", "b", 1.0);
-PARAM_DOUBLE("duplicateThresh", "If distance of two centroids is less than duplicate thresh "
+PARAM_DOUBLE("radius", "If distance of two centroids is less than radius "
"one will be removed. "
- "If it's negative, an estimation will be given. ", "d", -1.0);
-
+ "If it isn't positive, an estimation will be given. "
+ "Points with distance to current centroid less than radius "
+ "will be used to calculate new centroid. ", "r", 0);
int main(int argc, char** argv) {
CLI::ParseCommandLine(argc, argv);
const string inputFile = CLI::GetParam<string>("inputFile");
- const double stopThresh = CLI::GetParam<double>("stopThresh");
+ const double radius = CLI::GetParam<double>("radius");
const double bandwidth = CLI::GetParam<double>("bandwidth");
const int maxIterations = CLI::GetParam<int>("max_iterations");
- const double duplicateThresh = CLI::GetParam<double>("duplicateThresh");
if (maxIterations < 0) {
Log::Fatal << "Invalid value for maximum iterations (" << maxIterations <<
@@ -72,8 +71,7 @@ int main(int argc, char** argv) {
kernel::GaussianKernel kernel(bandwidth);
- MeanShift<> meanShift(duplicateThresh,
- maxIterations, stopThresh, kernel);
+ MeanShift<> meanShift(radius, maxIterations, kernel);
Timer::Start("clustering");
meanShift.Cluster(dataset, assignments, centroids);
Timer::Stop("clustering");
diff --git a/src/mlpack/tests/mean_shift_test.cpp b/src/mlpack/tests/mean_shift_test.cpp
index 6088310..f5c071a 100644
--- a/src/mlpack/tests/mean_shift_test.cpp
+++ b/src/mlpack/tests/mean_shift_test.cpp
@@ -85,23 +85,4 @@ BOOST_AUTO_TEST_CASE(MeanShiftSimpleTest) {
}
-/**
- * When duplicate thresh is set to 0, any centroids shouldn't removed.
- */
-BOOST_AUTO_TEST_CASE(ZeroDuplicateThreshTest) {
-
- // Set the duplicate thresh to 0
- MeanShift<> meanShift(0);
-
- arma::Col<size_t> assignments;
- arma::mat centroids;
- meanShift.Cluster((arma::mat) trans(meanShiftData), assignments, centroids);
-
- /**
- * Make sure the number of centroids is equal to
- * the number of vectors in dataset.
- */
- BOOST_REQUIRE_EQUAL(centroids.n_cols, meanShiftData.n_rows);
-}
-
BOOST_AUTO_TEST_SUITE_END();
\ No newline at end of file
More information about the mlpack-git mailing list