[mlpack-git] master: support for selecting kernel (9543f79)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:22 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4

>---------------------------------------------------------------

commit 9543f793d203e5202c12d650a93f0d075af70ddd
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date:   Sun Jan 18 22:17:11 2015 +0800

    support for selecting kernel


>---------------------------------------------------------------

9543f793d203e5202c12d650a93f0d075af70ddd
 src/mlpack/methods/mean_shift/mean_shift.hpp      | 62 ++++++++++-------------
 src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 52 ++++++++++---------
 src/mlpack/methods/mean_shift/mean_shift_main.cpp | 20 +++++---
 3 files changed, 67 insertions(+), 67 deletions(-)

diff --git a/src/mlpack/methods/mean_shift/mean_shift.hpp b/src/mlpack/methods/mean_shift/mean_shift.hpp
index b572bc0..95331e6 100644
--- a/src/mlpack/methods/mean_shift/mean_shift.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift.hpp
@@ -10,7 +10,7 @@
 
 #include <mlpack/core.hpp>
 
-#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/kernels/gaussian_kernel.hpp>
 
 namespace mlpack {
 namespace meanshift /** Mean Shift clustering. */ {
@@ -27,21 +27,16 @@ namespace meanshift /** Mean Shift clustering. */ {
  * extern arma::mat data; // Dataset we want to run Mean Shift on.
  * arma::Col<size_t> assignments; // Cluster assignments.
  * arma::mat centroids; // Cluster centroids.
- * extern int maxIterations; // Maximum number of iterations.
- * extern double stopThresh; //
- * extern double radius; //
  *
- * MeanShift<arma::mat, metric::EuclideanDistance> meanShift(maxIterations, 
- *          stopThresh, radius);
+ * MeanShift<arma::mat, kernel::GaussianKernel> meanShift;
  * meanShift.Cluster(dataset, assignments, centroids);
  * @endcode
  *
- * @tparam MetricType The distance metric to use for this KMeans; see
- *     metric::LMetric for an example.
+ * @tparam KernelType The kernel used to weight data points; see kernel::GaussianKernel.
  */
   
 template<typename MatType = arma::mat,
-         typename MetricType = metric::EuclideanDistance>
+         typename KernelType = kernel::GaussianKernel>
 class MeanShift
 {
  public:
@@ -49,19 +44,16 @@ class MeanShift
    * Create a Mean Shift object and set the parameters which Mean Shift
    * will be run with.
    *
+   * @param duplicateThresh If the distance between two centroids is less than this, one is removed.
    * @param maxIterations Maximum number of iterations allowed before giving up
    * @param stopThresh If the 2-norm of the mean shift vector is less than stopThresh, 
-   *     iterations will terminate.
-   * @param radius When iterating, take points within distance of
-   *     radius into consideration and two centroids within distance
-   *     of radius will be ragarded as one centroid.
-   * @param metric Optional MetricType object; for when the metric has state
-   *     it needs to store.
+   *        iterations will terminate.
+   * @param kernel Optional KernelType object.
    */
-  MeanShift(const size_t maxIterations,
-            const double stopThresh,
-            const double radius,
-            const MetricType metric = MetricType());
+  MeanShift(const double duplicateThresh = 1.0,
+            const size_t maxIterations = 1000,
+            const double stopThresh = 1e-3,
+            const KernelType kernel = KernelType());
   
   
   /**
@@ -87,34 +79,32 @@ class MeanShift
   //! Set the stop thresh.
   double& StopThresh() { return stopThresh; }
   
-  //! Get the radius of the concerning points.
-  double Radius() const { return radius; }
-  //! Set the radius of the concerning points.
-  double& Radius() { return radius; }
+  //! Get the kernel.
+  const KernelType& Kernel() const { return kernel; }
+  //! Modify the kernel.
+  KernelType& Kernel() { return kernel; }
   
-  //! Get the distance metric.
-  const MetricType& Metric() const { return metric; }
-  //! Modify the distance metric.
-  MetricType& Metric() { return metric; }
+  //! Get the duplicate thresh.
+  double DuplicateThresh() const { return duplicateThresh; }
+  //! Set the duplicate thresh.
+  double& DuplicateThresh() { return duplicateThresh; }
   
  private:
   
+  //! If the distance between two centroids is less than duplicateThresh, one is removed.
+  double duplicateThresh;
+  
   //! Maximum number of iterations before giving up.
   size_t maxIterations;
   
-  /** If the 2-norm of the mean shift vector is less than stopThresh,
+  /** 
+   * If the 2-norm of the mean shift vector is less than stopThresh,
    *  iterations will terminate.
    */
   double stopThresh;
   
-  /** When iterating, take points within distance of
-   *     radius into consideration and two centroids within distance
-   *     of radius will be ragarded as one centroid.
-   */
-  double radius;
-  
-  //! Instantiated distance metric.
-  MetricType metric;
+  //! Instantiated kernel.
+  KernelType kernel;
   
 };
 
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index a486064..5b51341 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -7,6 +7,7 @@
 
 #include "mean_shift.hpp"
 #include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/kernels/gaussian_kernel.hpp>
 
 namespace mlpack {
 namespace meanshift {
@@ -15,18 +16,18 @@ namespace meanshift {
   * Construct the Mean Shift object.
   */
 template<typename MatType,
-         typename MetricType>
+         typename KernelType>
 MeanShift<
     MatType,
-    MetricType>::
-MeanShift(const size_t maxIterations,
+    KernelType>::
+MeanShift(const double duplicateThresh,
+          const size_t maxIterations,
           const double stopThresh,
-          const double radius,
-          const MetricType metric) :
+          const KernelType kernel) :
+    duplicateThresh(duplicateThresh),
     maxIterations(maxIterations),
     stopThresh(stopThresh),
-    radius(radius),
-    metric(metric)
+    kernel(kernel)
 {
   // Nothing to do.
 }
@@ -36,10 +37,10 @@ MeanShift(const size_t maxIterations,
   * assignments and centroids.
   */
 template<typename MatType,
-         typename MetricType>
+         typename KernelType>
 inline void MeanShift<
     MatType,
-    MetricType>::
+    KernelType>::
 Cluster(const MatType& data,
         arma::Col<size_t>& assignments,
         arma::mat& centroids) {
@@ -58,27 +59,30 @@ Cluster(const MatType& data,
     
     while (true) {
       
-      // mean shift vector.
-      arma::Col<double> mhVector = arma::zeros(data.n_rows, 1);
+      // New centroid: the kernel-weighted mean of all data points.
+      arma::Col<double> newCentroid = arma::zeros(data.n_rows, 1);
       
-      // number of neighbouring points.
-      int vecCount = 0;
+      double sumWeight = 0;
       for (size_t j = 0; j < data.n_cols; ++j) {
         
-        // find neighbouring points.
-        double dist = metric.Evaluate(data.col(j), allCentroids.col(i));
-        if (dist < radius) {
-          vecCount ++;
+        // Compute the kernel weight of this point relative to the current centroid.
+        double weight = kernel.Evaluate(allCentroids.col(i), data.col(j));
+        sumWeight += weight;
+        
+        // update new centroid.
+        newCentroid += weight * data.col(j);
         
-          // update mean shift vector.
-          mhVector += data.col(j) - allCentroids.col(i);
-        }
       }
-      mhVector /= vecCount;
+      
+      newCentroid /= sumWeight;
       
       completedIterations ++;
-      // update centroid.
-      allCentroids.col(i) += mhVector;
+      
+      // calc the mean shift vector.
+      arma::Col<double> mhVector = newCentroid - allCentroids.col(i);
+      
+      // update the centroid.
+      allCentroids.col(i) = newCentroid;
       
       if (arma::norm(mhVector, 2) < stopThresh ||
           completedIterations > maxIterations) {
@@ -100,7 +104,7 @@ Cluster(const MatType& data,
      */
     for (size_t j = 0; j < centroids.n_cols; ++j) {
       arma::Col<double> delta = allCentroids.col(i) - centroids.col(j);
-      if (norm(delta, 2) < radius) {
+      if (norm(delta, 2) < duplicateThresh) {
         isDuplicated = true;
         assignments(i) = j;
         break;
diff --git a/src/mlpack/methods/mean_shift/mean_shift_main.cpp b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
index 85b3013..983db63 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_main.cpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
@@ -6,6 +6,7 @@
  */
 
 #include <mlpack/core.hpp>
+#include <mlpack/core/kernels/gaussian_kernel.hpp>
 #include "mean_shift.hpp"
 
 using namespace mlpack;
@@ -21,11 +22,6 @@ PROGRAM_INFO("Mean Shift Clustering", "This program performs mean shift clusteri
 
 // Required options.
 PARAM_STRING_REQ("inputFile", "Input dataset to perform clustering on.", "i");
-PARAM_DOUBLE_REQ("stopThresh", "If the 2-norm of the mean shift vector "
-                 "is less than stopThresh, iterations will terminate. ", "s");
-PARAM_DOUBLE_REQ("radius", "When iterating, take points within distance of "
-                 "radius into consideration and two centroids within distance "
-                 "of radius will be ragarded as one centroid. ", "r");
 
 // Output options.
 PARAM_FLAG("in_place", "If specified, a column containing the learned cluster "
@@ -39,6 +35,11 @@ PARAM_STRING("centroid_file", "If specified, the centroids of each cluster will"
 // Mean Shift configuration options.
 PARAM_INT("max_iterations", "Maximum number of iterations before Mean Shift "
           "terminates.", "m", 1000);
+PARAM_DOUBLE("stopThresh", "If the 2-norm of the mean shift vector "
+             "is less than stopThresh, iterations will terminate. ", "s", 1e-3);
+PARAM_DOUBLE("bandwidth", "bandwidth of Gaussian kernel ", "b", 1.0);
+PARAM_DOUBLE("duplicateThresh", "If distance of two centroids is less than duplicate thresh "
+             "one will be removed. ", "d", 1.0);
 
 
 int main(int argc, char** argv) {
@@ -47,8 +48,9 @@ int main(int argc, char** argv) {
   
   const string inputFile = CLI::GetParam<string>("inputFile");
   const double stopThresh = CLI::GetParam<double>("stopThresh");
-  const double radius = CLI::GetParam<double>("radius");
+  const double bandwidth = CLI::GetParam<double>("bandwidth");
   const int maxIterations = CLI::GetParam<int>("max_iterations");
+  const double duplicateThresh = CLI::GetParam<double>("duplicateThresh");
   
   if (maxIterations < 0) {
     Log::Fatal << "Invalid value for maximum iterations (" << maxIterations <<
@@ -67,7 +69,11 @@ int main(int argc, char** argv) {
   arma::mat centroids;
   arma::Col<size_t> assignments;
   
-  MeanShift<arma::mat, metric::EuclideanDistance> meanShift(maxIterations, stopThresh, radius);
+  kernel::GaussianKernel kernel;
+  kernel.Bandwidth(bandwidth);
+  
+  MeanShift<arma::mat, kernel::GaussianKernel> meanShift(duplicateThresh,
+                                                         maxIterations, stopThresh, kernel);
   Timer::Start("clustering");
   meanShift.Cluster(dataset, assignments, centroids);
   Timer::Stop("clustering");



More information about the mlpack-git mailing list