[mlpack-git] master: estimation of duplicate thresh (1ea40f2)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:10 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4
>---------------------------------------------------------------
commit 1ea40f22f08840fbf65f35e139cfa5fd301b5a73
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date: Tue Jan 27 22:00:04 2015 +0800
estimation of duplicate thresh
>---------------------------------------------------------------
1ea40f22f08840fbf65f35e139cfa5fd301b5a73
src/mlpack/methods/mean_shift/mean_shift.hpp | 10 ++++-
src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 55 ++++++++++++++++++++---
src/mlpack/methods/mean_shift/mean_shift_main.cpp | 3 +-
3 files changed, 59 insertions(+), 9 deletions(-)
diff --git a/src/mlpack/methods/mean_shift/mean_shift.hpp b/src/mlpack/methods/mean_shift/mean_shift.hpp
index 09a27d4..d866b82 100644
--- a/src/mlpack/methods/mean_shift/mean_shift.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift.hpp
@@ -43,19 +43,25 @@ class MeanShift
* Create a Mean Shift object and set the parameters which Mean Shift
* will be run with.
*
- * @param duplicateThresh If distance of two centroids is less than it, one will be removed.
+ * @param duplicateThresh If the distance between two centroids is less than this value, one of them will be removed. If this value is negative, it will be estimated from the data when clustering.
* @param maxIterations Maximum number of iterations allowed before giving up
* @param stopThresh If the 2-norm of the mean shift vector is less than stopThresh,
* iterations will terminate.
* @param kernel Optional KernelType object.
*/
- MeanShift(const double duplicateThresh = 1.0,
+ MeanShift(const double duplicateThresh = -1,
const size_t maxIterations = 1000,
const double stopThresh = 1e-3,
const KernelType kernel = KernelType());
/**
+ * Estimate the duplicate threshold from the given dataset.
+ * @param data Dataset for estimation.
+ */
+ double estimateDuplicateThresh(const MatType& data);
+
+ /**
* Perform Mean Shift clustering on the data, returning a list of cluster
* assignments and centroids.
*
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index 0e4d8c7..15b5abb 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -7,6 +7,11 @@
#include "mean_shift.hpp"
#include <mlpack/core/kernels/gaussian_kernel.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+#include <mlpack/methods/neighbor_search/neighbor_search_stat.hpp>
+
+
namespace mlpack {
namespace meanshift {
@@ -31,7 +36,41 @@ MeanShift(const double duplicateThresh,
// Nothing to do.
}
-/**
+
+// Estimate the duplicate threshold from the given dataset.
+template<typename KernelType,
+ typename MatType>
+double MeanShift<
+ KernelType,
+ MatType>::
+estimateDuplicateThresh(const MatType &data) {
+
+ neighbor::NeighborSearch<
+ neighbor::NearestNeighborSort,
+ metric::EuclideanDistance,
+ tree::BinarySpaceTree<bound::HRectBound<2>,
+ neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> >
+ > neighborSearch(data);
+
+ /**
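+ * For each point in the dataset, find its nNeighbors nearest neighbors
+ * (30% of the number of points) and their distances. The threshold is the
+ * mean, over all points, of each point's largest neighbor distance.
+ * NOTE(review): nNeighbors is 0 when the dataset has fewer than 4 points.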
+ */
+ size_t nNeighbors = (int)(data.n_cols * 0.3);
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+ neighborSearch.Search(nNeighbors, neighbors, distances);
+
+ // Get max distance for each point.
+ arma::rowvec maxDistances = max(distances);
+
+ // Return the mean of the per-point maximum distances as the threshold.
+ return sum(maxDistances) / (double)data.n_cols;
+
+}
+
+ /**
* Perform Mean Shift clustering on the data, returning a list of cluster
* assignments and centroids.
*/
@@ -44,6 +83,12 @@ Cluster(const MatType& data,
arma::Col<size_t>& assignments,
arma::mat& centroids) {
+ if (duplicateThresh < 0) {
+ // A negative duplicate threshold was given, so estimate one from the data.
+ duplicateThresh = estimateDuplicateThresh(data);
+ }
+
+
// all centroids before remove duplicate ones.
arma::mat allCentroids(data.n_rows, data.n_cols);
assignments.set_size(data.n_cols);
@@ -51,12 +96,11 @@ Cluster(const MatType& data,
// for each point in dataset, perform mean shift algorithm.
for (size_t i = 0; i < data.n_cols; ++i) {
- size_t completedIterations = 0;
-
//initial centroid is the point itself.
allCentroids.col(i) = data.col(i);
- while (true) {
+ for (size_t completedIterations = 0; completedIterations < maxIterations;
+ completedIterations++) {
// new centroid
arma::Col<double> newCentroid = arma::zeros(data.n_rows, 1);
@@ -83,8 +127,7 @@ Cluster(const MatType& data,
// update the centroid.
allCentroids.col(i) = newCentroid;
- if (arma::norm(mhVector, 2) < stopThresh ||
- completedIterations > maxIterations) {
+ if (arma::norm(mhVector, 2) < stopThresh) {
break;
}
diff --git a/src/mlpack/methods/mean_shift/mean_shift_main.cpp b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
index b85a176..c77f4f7 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_main.cpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
@@ -39,7 +39,8 @@ PARAM_DOUBLE("stopThresh", "If the 2-norm of the mean shift vector "
"is less than stopThresh, iterations will terminate. ", "s", 1e-3);
PARAM_DOUBLE("bandwidth", "bandwidth of Gaussian kernel ", "b", 1.0);
PARAM_DOUBLE("duplicateThresh", "If distance of two centroids is less than duplicate thresh "
- "one will be removed. ", "d", 1.0);
+ "one will be removed. "
+ "If it's negative, an estimation will be given. ", "d", -1.0);
int main(int argc, char** argv) {
More information about the mlpack-git
mailing list