[mlpack-git] master: implement mean shift algorithm (daf77e7)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:12 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4

>---------------------------------------------------------------

commit daf77e78f3932e5e536d6ec34dc7ffd7d33a390b
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date:   Sun Jan 18 10:33:46 2015 +0800

    implement mean shift algorithm


>---------------------------------------------------------------

daf77e78f3932e5e536d6ec34dc7ffd7d33a390b
 src/mlpack/methods/CMakeLists.txt                  |   1 +
 .../{adaboost => mean_shift}/CMakeLists.txt        |  17 ++-
 src/mlpack/methods/mean_shift/mean_shift.hpp       | 127 +++++++++++++++++++++
 src/mlpack/methods/mean_shift/mean_shift_impl.hpp  | 126 ++++++++++++++++++++
 src/mlpack/methods/mean_shift/mean_shift_main.cpp  | 107 +++++++++++++++++
 5 files changed, 368 insertions(+), 10 deletions(-)

diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt
index 6215b2c..2b8b09a 100644
--- a/src/mlpack/methods/CMakeLists.txt
+++ b/src/mlpack/methods/CMakeLists.txt
@@ -11,6 +11,7 @@ set(DIRS
   hmm
   kernel_pca
   kmeans
+  mean_shift
   lars
   linear_regression
   local_coordinate_coding
diff --git a/src/mlpack/methods/adaboost/CMakeLists.txt b/src/mlpack/methods/mean_shift/CMakeLists.txt
similarity index 68%
copy from src/mlpack/methods/adaboost/CMakeLists.txt
copy to src/mlpack/methods/mean_shift/CMakeLists.txt
index 1320b9d..cb42e52 100644
--- a/src/mlpack/methods/adaboost/CMakeLists.txt
+++ b/src/mlpack/methods/mean_shift/CMakeLists.txt
@@ -1,10 +1,8 @@
-cmake_minimum_required(VERSION 2.8)
-
 # Define the files we need to compile.
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
-  adaboost.hpp
-  adaboost_impl.hpp
+  mean_shift.hpp
+  mean_shift_impl.hpp
 )
 
 # Add directory name to sources.
@@ -16,12 +14,11 @@ endforeach()
 # the parent scope).
 set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
 
-
-add_executable(adaboost
-  adaboost_main.cpp
+# The main Mean Shift executable.
+add_executable(mean_shift
+  mean_shift_main.cpp
 )
-target_link_libraries(adaboost
+target_link_libraries(mean_shift
   mlpack
 )
-
-install(TARGETS adaboost RUNTIME DESTINATION bin)
+install(TARGETS mean_shift RUNTIME DESTINATION bin)
diff --git a/src/mlpack/methods/mean_shift/mean_shift.hpp b/src/mlpack/methods/mean_shift/mean_shift.hpp
new file mode 100644
index 0000000..b572bc0
--- /dev/null
+++ b/src/mlpack/methods/mean_shift/mean_shift.hpp
@@ -0,0 +1,127 @@
+/**
+ * @file mean_shift.hpp
+ * @author Shangtong Zhang
+ *
+ * Mean Shift clustering
+ */
+
+#ifndef __MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
+#define __MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
+
+#include <mlpack/core.hpp>
+
+#include <mlpack/core/metrics/lmetric.hpp>
+
+namespace mlpack {
+namespace meanshift /** Mean Shift clustering. */ {
+
+/**
+ * This class implements Mean Shift clustering.  For each point in the
+ * dataset, the mean shift procedure is applied until it converges or the
+ * maximum number of iterations is reached.  Duplicate centroids (centroids
+ * within distance `radius` of one another) are then merged into one.
+ *
+ * A simple example of how to run Mean Shift clustering is shown below.
+ *
+ * @code
+ * extern arma::mat data; // Dataset we want to run Mean Shift on.
+ * arma::Col<size_t> assignments; // Cluster assignments.
+ * arma::mat centroids; // Cluster centroids.
+ * extern size_t maxIterations; // Maximum number of iterations.
+ * extern double stopThresh; // Stop once the shift's 2-norm is below this.
+ * extern double radius; // Neighbourhood (and merging) radius.
+ *
+ * MeanShift<arma::mat, metric::EuclideanDistance> meanShift(maxIterations,
+ *     stopThresh, radius);
+ * meanShift.Cluster(data, assignments, centroids);
+ * @endcode
+ *
+ * @tparam MatType Type of the data matrix (e.g. arma::mat).
+ * @tparam MetricType The distance metric to use for this MeanShift; see
+ *     metric::LMetric for an example.
+ */
+template<typename MatType = arma::mat,
+         typename MetricType = metric::EuclideanDistance>
+class MeanShift
+{
+ public:
+  /**
+   * Create a Mean Shift object and set the parameters which Mean Shift
+   * will be run with.
+   *
+   * @param maxIterations Maximum number of iterations allowed before giving
+   *     up.
+   * @param stopThresh If the 2-norm of the mean shift vector is less than
+   *     stopThresh, iterations will terminate.
+   * @param radius When iterating, take points within distance of radius into
+   *     consideration; two centroids within distance of radius will be
+   *     regarded as one centroid.
+   * @param metric Optional MetricType object; for when the metric has state
+   *     it needs to store.  It is copied into this object.
+   */
+  MeanShift(const size_t maxIterations,
+            const double stopThresh,
+            const double radius,
+            const MetricType metric = MetricType());
+
+  /**
+   * Perform Mean Shift clustering on the data, returning a list of cluster
+   * assignments and centroids.
+   *
+   * @param data Dataset to cluster (one point per column).
+   * @param assignments Vector to store cluster assignments in; entry i is the
+   *     index of the column of `centroids` that point i was assigned to.
+   * @param centroids Matrix in which the centroids are stored, one per column.
+   */
+  void Cluster(const MatType& data,
+               arma::Col<size_t>& assignments,
+               arma::mat& centroids);
+
+  //! Get the maximum number of iterations.
+  size_t MaxIterations() const { return maxIterations; }
+  //! Set the maximum number of iterations.
+  size_t& MaxIterations() { return maxIterations; }
+
+  //! Get the convergence threshold on the mean shift vector's 2-norm.
+  double StopThresh() const { return stopThresh; }
+  //! Set the convergence threshold on the mean shift vector's 2-norm.
+  double& StopThresh() { return stopThresh; }
+
+  //! Get the neighbourhood/merging radius.
+  double Radius() const { return radius; }
+  //! Set the neighbourhood/merging radius.
+  double& Radius() { return radius; }
+
+  //! Get the distance metric.
+  const MetricType& Metric() const { return metric; }
+  //! Modify the distance metric.
+  MetricType& Metric() { return metric; }
+
+ private:
+  //! Maximum number of iterations before giving up.
+  size_t maxIterations;
+
+  /**
+   * If the 2-norm of the mean shift vector is less than stopThresh,
+   * iterations will terminate.
+   */
+  double stopThresh;
+
+  /**
+   * When iterating, take points within distance of radius into
+   * consideration; two centroids within distance of radius will be
+   * regarded as one centroid.
+   */
+  double radius;
+
+  //! Instantiated distance metric.
+  MetricType metric;
+};
+
+} // namespace meanshift
+} // namespace mlpack
+
+// Include implementation.
+#include "mean_shift_impl.hpp"
+
+#endif // __MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
\ No newline at end of file
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
new file mode 100644
index 0000000..a486064
--- /dev/null
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -0,0 +1,126 @@
+/**
+ * @file mean_shift_impl.hpp
+ * @author Shangtong Zhang
+ *
+ * Mean Shift clustering
+ */
+
+#include "mean_shift.hpp"
+#include <mlpack/core/metrics/lmetric.hpp>
+
+namespace mlpack {
+namespace meanshift {
+
+/**
+ * Construct the Mean Shift object.
+ */
+template<typename MatType,
+         typename MetricType>
+MeanShift<
+    MatType,
+    MetricType>::
+MeanShift(const size_t maxIterations,
+          const double stopThresh,
+          const double radius,
+          const MetricType metric) :
+    maxIterations(maxIterations),
+    stopThresh(stopThresh),
+    radius(radius),
+    metric(metric)
+{
+  // Nothing to do.
+}
+
+/**
+ * Perform Mean Shift clustering on the data, returning a list of cluster
+ * assignments and centroids.
+ *
+ * Every point is shifted independently.  Its centroid starts at the point
+ * itself and is repeatedly moved to the mean of all data points whose
+ * distance (under the metric) to the centroid is less than `radius`.
+ * Shifting stops when any of these holds:
+ *  - the 2-norm of the shift is less than stopThresh;
+ *  - maxIterations shifts have been performed;
+ *  - no data point lies within `radius` of the centroid.
+ * The shifted centroids are then merged greedily, in point order.  A centroid
+ * within `radius` of an already-accepted centroid is assigned to it;
+ * otherwise it is appended as a new cluster.
+ *
+ * Any previous contents of `centroids` are discarded.
+ */
+template<typename MatType,
+         typename MetricType>
+inline void MeanShift<
+    MatType,
+    MetricType>::
+Cluster(const MatType& data,
+        arma::Col<size_t>& assignments,
+        arma::mat& centroids)
+{
+  // Shifted centroid of every point, before duplicates are merged.
+  arma::mat allCentroids(data.n_rows, data.n_cols);
+  assignments.set_size(data.n_cols);
+
+  // Accumulator for the sum of neighbouring points; reused across iterations.
+  arma::vec neighbourSum(data.n_rows);
+
+  // For each point in the dataset, perform the mean shift procedure.
+  for (size_t i = 0; i < data.n_cols; ++i)
+  {
+    // The initial centroid is the point itself.
+    allCentroids.col(i) = data.col(i);
+
+    // At most maxIterations shifts (the original check allowed one extra).
+    for (size_t iteration = 0; iteration < maxIterations; ++iteration)
+    {
+      // Sum up all points within `radius` of the current centroid.
+      neighbourSum.zeros();
+      size_t neighbourCount = 0;
+      for (size_t j = 0; j < data.n_cols; ++j)
+      {
+        if (metric.Evaluate(data.col(j), allCentroids.col(i)) < radius)
+        {
+          neighbourSum += data.col(j);
+          ++neighbourCount;
+        }
+      }
+
+      // Without neighbours the mean is undefined; dividing by zero would
+      // turn the centroid into NaN, so leave it where it is.
+      if (neighbourCount == 0)
+        break;
+
+      // The mean shift vector moves the centroid onto its neighbours' mean.
+      const arma::vec shift =
+          neighbourSum / static_cast<double>(neighbourCount) -
+          allCentroids.col(i);
+      allCentroids.col(i) += shift;
+
+      if (arma::norm(shift, 2) < stopThresh)
+        break;
+    }
+  }
+
+  // Remove duplicate centroids.  Start from an empty centroid set so that
+  // stale caller-provided contents are not mistaken for existing clusters.
+  centroids.set_size(data.n_rows, 0);
+  for (size_t i = 0; i < allCentroids.n_cols; ++i)
+  {
+    bool isDuplicate = false;
+
+    // If this centroid is within `radius` of an accepted centroid, merge it.
+    // The configured metric is used, consistent with the neighbour search.
+    for (size_t j = 0; j < centroids.n_cols; ++j)
+    {
+      if (metric.Evaluate(allCentroids.col(i), centroids.col(j)) < radius)
+      {
+        isDuplicate = true;
+        assignments(i) = j;
+        break;
+      }
+    }
+
+    if (!isDuplicate)
+    {
+      // This is a new centroid; append it as the last column.
+      centroids.insert_cols(centroids.n_cols, allCentroids.col(i));
+      assignments(i) = centroids.n_cols - 1;
+    }
+  }
+}
+
+} // namespace meanshift
+} // namespace mlpack
\ No newline at end of file
diff --git a/src/mlpack/methods/mean_shift/mean_shift_main.cpp b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
new file mode 100644
index 0000000..85b3013
--- /dev/null
+++ b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
@@ -0,0 +1,107 @@
+/**
+ * @file mean_shift_main.cpp
+ * @author Shangtong Zhang
+ *
+ * Executable for running Mean Shift
+ */
+
+#include <mlpack/core.hpp>
+#include "mean_shift.hpp"
+
+using namespace mlpack;
+using namespace mlpack::meanshift;
+using namespace std;
+
+// Define parameters for the executable.
+PROGRAM_INFO("Mean Shift Clustering", "This program performs mean shift clustering "
+             "on the given dataset, storing the learned cluster assignments either as "
+             "a column of labels in the file containing the input dataset or in a "
+             "separate file. "
+             "\n\n");
+
+// Required options.
+PARAM_STRING_REQ("inputFile", "Input dataset to perform clustering on.", "i");
+PARAM_DOUBLE_REQ("stopThresh", "If the 2-norm of the mean shift vector "
+                 "is less than stopThresh, iterations will terminate. ", "s");
+PARAM_DOUBLE_REQ("radius", "When iterating, take points within distance of "
+                 "radius into consideration and two centroids within distance "
+                 "of radius will be regarded as one centroid. ", "r");
+
+// Output options.
+// The help text must name the real option (--output_file), not --outputFile.
+PARAM_FLAG("in_place", "If specified, a column containing the learned cluster "
+           "assignments will be added to the input dataset file.  In this case, "
+           "--output_file is overridden.", "P");
+PARAM_STRING("output_file", "File to write output labels or labeled data to.",
+             "o", "");
+PARAM_STRING("centroid_file", "If specified, the centroids of each cluster will"
+             " be written to the given file.", "C", "");
+
+// Mean Shift configuration options.
+PARAM_INT("max_iterations", "Maximum number of iterations before Mean Shift "
+          "terminates.", "m", 1000);
+
+
+int main(int argc, char** argv) {
+  CLI::ParseCommandLine(argc, argv);
+
+  const string inputFile = CLI::GetParam<string>("inputFile");
+  const double stopThresh = CLI::GetParam<double>("stopThresh");
+  const double radius = CLI::GetParam<double>("radius");
+  const int maxIterations = CLI::GetParam<int>("max_iterations");
+
+  if (maxIterations < 0) {
+    Log::Fatal << "Invalid value for maximum iterations (" << maxIterations <<
+    ")! Must be greater than or equal to 0." << endl;
+  }
+
+  // A non-positive radius admits no neighbours at all (the neighbour test is
+  // a strict "distance < radius"), so clustering would be meaningless.
+  if (radius <= 0) {
+    Log::Fatal << "Invalid value for radius (" << radius << ")! Must be "
+        << "greater than 0." << endl;
+  }
+
+  // Make sure we have an output file if we're not doing the work in-place.
+  if (!CLI::HasParam("in_place") && !CLI::HasParam("output_file") &&
+      !CLI::HasParam("centroid_file")) {
+    Log::Warn << "--output_file, --in_place, and --centroid_file are not set; "
+    << "no results will be saved." << std::endl;
+  }
+
+  arma::mat dataset;
+  data::Load(inputFile, dataset, true); // Fatal upon failure.
+  arma::mat centroids;
+  arma::Col<size_t> assignments;
+
+  // maxIterations was validated as non-negative above, so the cast is safe.
+  MeanShift<arma::mat, metric::EuclideanDistance> meanShift(
+      static_cast<size_t>(maxIterations), stopThresh, radius);
+  Timer::Start("clustering");
+  meanShift.Cluster(dataset, assignments, centroids);
+  Timer::Stop("clustering");
+
+  // Only build and save the labeled dataset if it has somewhere to go;
+  // previously an empty --output_file name was passed to data::Save().
+  if (CLI::HasParam("in_place") || CLI::HasParam("output_file")) {
+    // Add the column of assignments to the dataset; but we have to convert
+    // them to type double first.
+    arma::vec converted(assignments.n_elem);
+    for (size_t i = 0; i < assignments.n_elem; i++)
+      converted(i) = static_cast<double>(assignments(i));
+
+    dataset.insert_rows(dataset.n_rows, trans(converted));
+
+    // --in_place takes precedence over --output_file.
+    const string outputFile = CLI::HasParam("in_place") ? inputFile :
+        CLI::GetParam<string>("output_file");
+    data::Save(outputFile, dataset);
+  }
+
+  // Should we write the centroids to a file?
+  if (CLI::HasParam("centroid_file")) {
+    data::Save(CLI::GetParam<std::string>("centroid_file"), centroids);
+  }
+}
\ No newline at end of file



More information about the mlpack-git mailing list