[mlpack-git] master: estimation of duplicate thresh (1ea40f2)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:10 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4

>---------------------------------------------------------------

commit 1ea40f22f08840fbf65f35e139cfa5fd301b5a73
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date:   Tue Jan 27 22:00:04 2015 +0800

    estimation of duplicate thresh


>---------------------------------------------------------------

1ea40f22f08840fbf65f35e139cfa5fd301b5a73
 src/mlpack/methods/mean_shift/mean_shift.hpp      | 10 ++++-
 src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 55 ++++++++++++++++++++---
 src/mlpack/methods/mean_shift/mean_shift_main.cpp |  3 +-
 3 files changed, 59 insertions(+), 9 deletions(-)

diff --git a/src/mlpack/methods/mean_shift/mean_shift.hpp b/src/mlpack/methods/mean_shift/mean_shift.hpp
index 09a27d4..d866b82 100644
--- a/src/mlpack/methods/mean_shift/mean_shift.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift.hpp
@@ -43,19 +43,25 @@ class MeanShift
    * Create a Mean Shift object and set the parameters which Mean Shift
    * will be run with.
    *
-   * @param duplicateThresh If distance of two centroids is less than it, one will be removed.
+   * @param duplicateThresh If the distance between two centroids is less than this value, one of them will be removed. If this value is negative, it is estimated from the data when clustering.
    * @param maxIterations Maximum number of iterations allowed before giving up
    * @param stopThresh If the 2-norm of the mean shift vector is less than stopThresh, 
    *        iterations will terminate.
    * @param kernel Optional KernelType object.
    */
-  MeanShift(const double duplicateThresh = 1.0,
+  MeanShift(const double duplicateThresh = -1,
             const size_t maxIterations = 1000,
             const double stopThresh = 1e-3,
             const KernelType kernel = KernelType());
   
   
   /**
+   * Estimate the duplicate threshold from the given dataset.
+   * @param data Dataset used for the estimation.
+   */
+  double estimateDuplicateThresh(const MatType& data);
+  
+  /**
    * Perform Mean Shift clustering on the data, returning a list of cluster
    * assignments and centroids.
    * 
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index 0e4d8c7..15b5abb 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -7,6 +7,11 @@
 
 #include "mean_shift.hpp"
 #include <mlpack/core/kernels/gaussian_kernel.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+#include <mlpack/methods/neighbor_search/neighbor_search_stat.hpp>
+
+
 
 namespace mlpack {
 namespace meanshift {
@@ -31,7 +36,41 @@ MeanShift(const double duplicateThresh,
   // Nothing to do.
 }
   
-/**
+  
+// Estimate the duplicate threshold: the mean over all points of each point's largest nearest-neighbor distance.
+template<typename KernelType,
+         typename MatType>
+double MeanShift<
+  KernelType,
+  MatType>::
+estimateDuplicateThresh(const MatType &data) {
+  // Nearest-neighbor searcher built over the dataset, using Euclidean distance.
+  neighbor::NeighborSearch<
+    neighbor::NearestNeighborSort,
+    metric::EuclideanDistance,
+    tree::BinarySpaceTree<bound::HRectBound<2>,
+          neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> >
+    > neighborSearch(data);
+  // NOTE(review): nNeighbors truncates to 0 when data has fewer than 4 columns; confirm Search() handles k == 0.
+  /**
+   * For every point in the dataset, look up its nNeighbors nearest points,
+   * where nNeighbors is 30% of the number of columns, truncated to an integer.
+   * The largest of those distances per point feeds the threshold estimate.
+   */
+  size_t nNeighbors = (int)(data.n_cols * 0.3);
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+  neighborSearch.Search(nNeighbors, neighbors, distances);
+  
+  // Column-wise maximum of the distance matrix (presumably one column per query point; confirm).
+  arma::rowvec maxDistances = max(distances);
+  
+  // The threshold is the sum of those maxima divided by the number of columns in data.
+  return sum(maxDistances) / (double)data.n_cols;
+  
+}
+
+ /**
   * Perform Mean Shift clustering on the data, returning a list of cluster
   * assignments and centroids.
   */
@@ -44,6 +83,12 @@ Cluster(const MatType& data,
         arma::Col<size_t>& assignments,
         arma::mat& centroids) {
   
+  if (duplicateThresh < 0) {
+    // An invalid duplicate thresh is given, an estimation is needed.
+    duplicateThresh = estimateDuplicateThresh(data);
+  }
+  
+  
   // all centroids before remove duplicate ones.
   arma::mat allCentroids(data.n_rows, data.n_cols);
   assignments.set_size(data.n_cols);
@@ -51,12 +96,11 @@ Cluster(const MatType& data,
   // for each point in dataset, perform mean shift algorithm.
   for (size_t i = 0; i < data.n_cols; ++i) {
     
-    size_t completedIterations = 0;
-    
     //initial centroid is the point itself.
     allCentroids.col(i) = data.col(i);
     
-    while (true) {
+    for (size_t completedIterations = 0; completedIterations < maxIterations;
+         completedIterations++) {
       
       // new centroid
       arma::Col<double> newCentroid = arma::zeros(data.n_rows, 1);
@@ -83,8 +127,7 @@ Cluster(const MatType& data,
       // update the centroid.
       allCentroids.col(i) = newCentroid;
       
-      if (arma::norm(mhVector, 2) < stopThresh ||
-          completedIterations > maxIterations) {
+      if (arma::norm(mhVector, 2) < stopThresh) {
         break;
       }
       
diff --git a/src/mlpack/methods/mean_shift/mean_shift_main.cpp b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
index b85a176..c77f4f7 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_main.cpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
@@ -39,7 +39,8 @@ PARAM_DOUBLE("stopThresh", "If the 2-norm of the mean shift vector "
              "is less than stopThresh, iterations will terminate. ", "s", 1e-3);
 PARAM_DOUBLE("bandwidth", "bandwidth of Gaussian kernel ", "b", 1.0);
 PARAM_DOUBLE("duplicateThresh", "If distance of two centroids is less than duplicate thresh "
-             "one will be removed. ", "d", 1.0);
+             "one will be removed. "
+             "If it's negative, an estimation will be given. ", "d", -1.0);
 
 
 int main(int argc, char** argv) {



More information about the mlpack-git mailing list