[mlpack-git] master: only consider points within radius (6e1e8a9)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:28 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4

>---------------------------------------------------------------

commit 6e1e8a91127408f7238e7e25112e6926e164f394
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date:   Sat Jan 31 20:38:34 2015 +0800

    only consider points within radius


>---------------------------------------------------------------

6e1e8a91127408f7238e7e25112e6926e164f394
 src/mlpack/core/kernels/gaussian_kernel.hpp       |  4 ++
 src/mlpack/methods/mean_shift/mean_shift.hpp      | 55 ++++++----------
 src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 79 +++++++++++++----------
 src/mlpack/methods/mean_shift/mean_shift_main.cpp | 16 ++---
 src/mlpack/tests/mean_shift_test.cpp              | 19 ------
 5 files changed, 74 insertions(+), 99 deletions(-)

diff --git a/src/mlpack/core/kernels/gaussian_kernel.hpp b/src/mlpack/core/kernels/gaussian_kernel.hpp
index 9e52638..22afb5e 100644
--- a/src/mlpack/core/kernels/gaussian_kernel.hpp
+++ b/src/mlpack/core/kernels/gaussian_kernel.hpp
@@ -75,6 +75,10 @@ class GaussianKernel
     return exp(gamma * std::pow(t, 2.0));
   }
   
+  double Gradient(const double t) const {
+    return gamma * exp(gamma * std::pow(t, 2.0));
+  }
+
   /**
    * Obtain the normalization constant of the Gaussian kernel.
    *
diff --git a/src/mlpack/methods/mean_shift/mean_shift.hpp b/src/mlpack/methods/mean_shift/mean_shift.hpp
index 6a0c809..c45a759 100644
--- a/src/mlpack/methods/mean_shift/mean_shift.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift.hpp
@@ -28,7 +28,7 @@ namespace meanshift /** Mean Shift clustering. */ {
  * arma::Col<size_t> assignments; // Cluster assignments.
  * arma::mat centroids; // Cluster centroids.
  *
- * MeanShift<arma::mat, kernel::GaussianKernel> meanShift();
+ * MeanShift<> meanShift();
  * meanShift.Cluster(dataset, assignments, centroids);
  * @endcode
  *
@@ -36,7 +36,6 @@ namespace meanshift /** Mean Shift clustering. */ {
  */
   
 template<typename KernelType = kernel::GaussianKernel,
-         typename MetricType = metric::EuclideanDistance,
          typename MatType = arma::mat>
 class MeanShift
 {
@@ -45,25 +44,21 @@ class MeanShift
    * Create a Mean Shift object and set the parameters which Mean Shift
    * will be run with.
    *
-   * @param duplicateThresh If distance of two centroids is less than it, one will be removed. If this value is negative, an estimation will be given when clustering.
+   * @param radius If the distance between two centroids is less than radius, one of them will be removed. If this value isn't positive, an estimate will be computed when clustering.
    * @param maxIterations Maximum number of iterations allowed before giving up
-   * @param stopThresh If the 2-norm of the mean shift vector is less than stopThresh, 
    *        iterations will terminate.
    * @param kernel Optional KernelType object.
-   * @param metric Optional the metric to calculate distance.
    */
-  MeanShift(const double duplicateThresh = -1,
+  MeanShift(const double radius = 0,
             const size_t maxIterations = 1000,
-            const double stopThresh = 1e-3,
-            const KernelType kernel = KernelType(),
-            const MetricType metric = MetricType());
+            const KernelType kernel = KernelType());
   
   
   /**
-   * Give an estimation of duplicate thresh based on given dataset.
+   * Estimate the radius based on the given dataset.
    * @param data Dataset for estimation.
    */
-  double estimateDuplicateThresh(const MatType& data);
+  double estimateRadius(const MatType& data);
   
   /**
    * Perform Mean Shift clustering on the data, returning a list of cluster
@@ -83,46 +78,34 @@ class MeanShift
   //! Set the maximum number of iterations.
   size_t& MaxIterations() { return maxIterations; }
   
-  //! Get the stop thresh.
-  double StopThresh() const { return stopThresh; }
-  //! Set the stop thresh.
-  double& StopThresh() { return stopThresh; }
+  //! Get the radius.
+  double Radius() const { return radius; }
+  //! Set the radius.
+  void Radius(double radius);
   
   //! Get the kernel.
   const KernelType& Kernel() const { return kernel; }
   //! Modify the kernel.
   KernelType& Kernel() { return kernel; }
   
-  //! Get the metric.
-  const MetricType& Metric() const { return metric; }
-  //! Modify the metric.
-  MetricType& Metric() { return metric; }
-  
-  //! Get the duplicate thresh.
-  double DuplicateThresh() const { return duplicateThresh; }
-  //! Set the duplicate thresh.
-  double& DuplicateThresh() { return duplicateThresh; }
-  
  private:
   
-  // If distance of two centroids is less than duplicateThresh, one will be removed.
-  double duplicateThresh;
+  /**
+   * If the distance between two centroids is less than radius, one of them
+   * will be removed. Points whose distance to the current centroid is less
+   * than radius will be used to calculate the new centroid.
+   */
+  double radius;
+  
+  // Cached radius * radius, so squared distances can be compared directly.
+  double squaredRadius;
   
   //! Maximum number of iterations before giving up.
   size_t maxIterations;
   
-  /** 
-   * If the 2-norm of the mean shift vector is less than stopThresh,
-   *  iterations will terminate.
-   */
-  double stopThresh;
-  
   //! Instantiated kernel.
   KernelType kernel;
   
-  //! Instantiated metric.
-  MetricType metric;
-  
 };
 
 }; // namespace meanshift
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index 09a54a5..86fb811 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -20,40 +20,40 @@ namespace meanshift {
   * Construct the Mean Shift object.
   */
 template<typename KernelType,
-         typename MetricType,
          typename MatType>
 MeanShift<
   KernelType,
-  MetricType,
   MatType>::
-MeanShift(const double duplicateThresh,
+MeanShift(const double radius,
           const size_t maxIterations,
-          const double stopThresh,
-          const KernelType kernel,
-          const MetricType metric) :
-    duplicateThresh(duplicateThresh),
+          const KernelType kernel) :
     maxIterations(maxIterations),
-    stopThresh(stopThresh),
-    kernel(kernel),
-    metric(metric)
+    kernel(kernel)
 {
-  // Nothing to do.
+  Radius(radius);
 }
   
+template<typename KernelType,
+         typename MatType>
+void MeanShift<
+  KernelType,
+  MatType>::
+Radius(double radius) {
+  this->radius = radius;
+  squaredRadius = radius * radius;
+}
   
-// Estimate duplicate thresh based on given dataset.
+// Estimate radius based on given dataset.
 template<typename KernelType,
-         typename MetricType,
          typename MatType>
 double MeanShift<
   KernelType,
-  MetricType,
   MatType>::
-estimateDuplicateThresh(const MatType &data) {
+estimateRadius(const MatType &data) {
   
   neighbor::NeighborSearch<
     neighbor::NearestNeighborSort,
-    MetricType,
+    metric::EuclideanDistance,
     tree::BinarySpaceTree<bound::HRectBound<2>,
           neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> >
     > neighborSearch(data);
@@ -76,27 +76,24 @@ estimateDuplicateThresh(const MatType &data) {
   
 }
 
- /**
-  * Perform Mean Shift clustering on the data, returning a list of cluster
-  * assignments and centroids.
-  */
+/**
+ * Perform Mean Shift clustering on the data, returning a list of cluster
+ * assignments and centroids.
+ */
 template<typename KernelType,
-         typename MetricType,
          typename MatType>
 inline void MeanShift<
     KernelType,
-    MetricType,
     MatType>::
 Cluster(const MatType& data,
         arma::Col<size_t>& assignments,
         arma::mat& centroids) {
   
-  if (duplicateThresh < 0) {
-    // An invalid duplicate thresh is given, an estimation is needed.
-    duplicateThresh = estimateDuplicateThresh(data);
+  if (radius <= 0) {
+    // A non-positive radius was given, so estimate one from the data.
+    Radius(estimateRadius(data));
   }
   
-  
   // all centroids before remove duplicate ones.
   arma::mat allCentroids(data.n_rows, data.n_cols);
   assignments.set_size(data.n_cols);
@@ -110,19 +107,30 @@ Cluster(const MatType& data,
     for (size_t completedIterations = 0; completedIterations < maxIterations;
          completedIterations++) {
       
-      // new centroid
+      // to store new centroid
       arma::Col<double> newCentroid = arma::zeros(data.n_rows, 1);
       
       double sumWeight = 0;
+      
+      // Go through all the points
       for (size_t j = 0; j < data.n_cols; ++j) {
         
-        // calc weight for each point
-        double weight = kernel.Evaluate(metric.Evaluate(allCentroids.col(i),
-                                                       data.col(j)));
-        sumWeight += weight;
+        // Calculate the distance between old centroid and current point.
+        double squaredDist = metric::SquaredEuclideanDistance::
+                            Evaluate(allCentroids.col(i), data.col(j));
+        
+        // If current point is near the old centroid
+        if (squaredDist < squaredRadius) {
+          
+          // calculate weight for current point
+          double weight = kernel.Gradient(squaredDist / squaredRadius);
           
-        // update new centroid.
-        newCentroid += weight * data.col(j);
+          sumWeight += weight;
+          
+          // update new centroid.
+          newCentroid += weight * data.col(j);
+          
+        }
         
       }
       
@@ -134,13 +142,14 @@ Cluster(const MatType& data,
       // update the centroid.
       allCentroids.col(i) = newCentroid;
       
-      if (arma::norm(mhVector, 2) < stopThresh) {
+      // If the 2-norm of mean shift vector is small enough, it has converged.
+      if (arma::norm(mhVector, 2) < 1e-3 * radius) {
         
         // Determine if the new centroid is duplicate with old ones.
         bool isDuplicated = false;
         for (size_t k = 0; k < centroids.n_cols; ++k) {
           arma::Col<double> delta = allCentroids.col(i) - centroids.col(k);
-          if (norm(delta, 2) < duplicateThresh) {
+          if (norm(delta, 2) < radius) {
             isDuplicated = true;
             assignments(i) = k;
             break;
diff --git a/src/mlpack/methods/mean_shift/mean_shift_main.cpp b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
index d25970f..c53d29f 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_main.cpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
@@ -35,23 +35,22 @@ PARAM_STRING("centroid_file", "If specified, the centroids of each cluster will"
 // Mean Shift configuration options.
 PARAM_INT("max_iterations", "Maximum number of iterations before Mean Shift "
           "terminates.", "m", 1000);
-PARAM_DOUBLE("stopThresh", "If the 2-norm of the mean shift vector "
-             "is less than stopThresh, iterations will terminate. ", "s", 1e-3);
+
 PARAM_DOUBLE("bandwidth", "bandwidth of Gaussian kernel ", "b", 1.0);
-PARAM_DOUBLE("duplicateThresh", "If distance of two centroids is less than duplicate thresh "
+PARAM_DOUBLE("radius", "If distance of two centroids is less than radius "
              "one will be removed. "
-             "If it's negative, an estimation will be given. ", "d", -1.0);
-
+             "If it isn't positive, an estimation will be given. "
+             "Points with distance to current centroid less than radius "
+             "will be used to calculate new centroid. ", "r", 0);
 
 int main(int argc, char** argv) {
   
   CLI::ParseCommandLine(argc, argv);
   
   const string inputFile = CLI::GetParam<string>("inputFile");
-  const double stopThresh = CLI::GetParam<double>("stopThresh");
+  const double radius = CLI::GetParam<double>("radius");
   const double bandwidth = CLI::GetParam<double>("bandwidth");
   const int maxIterations = CLI::GetParam<int>("max_iterations");
-  const double duplicateThresh = CLI::GetParam<double>("duplicateThresh");
   
   if (maxIterations < 0) {
     Log::Fatal << "Invalid value for maximum iterations (" << maxIterations <<
@@ -72,8 +71,7 @@ int main(int argc, char** argv) {
   
   kernel::GaussianKernel kernel(bandwidth);
   
-  MeanShift<> meanShift(duplicateThresh,
-                                                         maxIterations, stopThresh, kernel);
+  MeanShift<> meanShift(radius, maxIterations, kernel);
   Timer::Start("clustering");
   meanShift.Cluster(dataset, assignments, centroids);
   Timer::Stop("clustering");
diff --git a/src/mlpack/tests/mean_shift_test.cpp b/src/mlpack/tests/mean_shift_test.cpp
index 6088310..f5c071a 100644
--- a/src/mlpack/tests/mean_shift_test.cpp
+++ b/src/mlpack/tests/mean_shift_test.cpp
@@ -85,23 +85,4 @@ BOOST_AUTO_TEST_CASE(MeanShiftSimpleTest) {
   
 }
 
-/**
- * When duplicate thresh is set to 0, any centroids shouldn't removed.
- */
-BOOST_AUTO_TEST_CASE(ZeroDuplicateThreshTest) {
-    
-    // Set the duplicate thresh to 0
-    MeanShift<> meanShift(0);
-    
-    arma::Col<size_t> assignments;
-    arma::mat centroids;
-    meanShift.Cluster((arma::mat) trans(meanShiftData), assignments, centroids);
-    
-    /**
-     * Make sure the number of centroids is equal to 
-     * the number of vectors in dataset.
-     */
-    BOOST_REQUIRE_EQUAL(centroids.n_cols, meanShiftData.n_rows);
-}
-
 BOOST_AUTO_TEST_SUITE_END();
\ No newline at end of file



More information about the mlpack-git mailing list