[mlpack-git] master: implement mean shift algorithm (daf77e7)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:12 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4
>---------------------------------------------------------------
commit daf77e78f3932e5e536d6ec34dc7ffd7d33a390b
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date: Sun Jan 18 10:33:46 2015 +0800
implement mean shift algorithm
>---------------------------------------------------------------
daf77e78f3932e5e536d6ec34dc7ffd7d33a390b
src/mlpack/methods/CMakeLists.txt | 1 +
.../{adaboost => mean_shift}/CMakeLists.txt | 17 ++-
src/mlpack/methods/mean_shift/mean_shift.hpp | 127 +++++++++++++++++++++
src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 126 ++++++++++++++++++++
src/mlpack/methods/mean_shift/mean_shift_main.cpp | 107 +++++++++++++++++
5 files changed, 368 insertions(+), 10 deletions(-)
diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt
index 6215b2c..2b8b09a 100644
--- a/src/mlpack/methods/CMakeLists.txt
+++ b/src/mlpack/methods/CMakeLists.txt
@@ -11,6 +11,7 @@ set(DIRS
hmm
kernel_pca
kmeans
+ mean_shift
lars
linear_regression
local_coordinate_coding
diff --git a/src/mlpack/methods/adaboost/CMakeLists.txt b/src/mlpack/methods/mean_shift/CMakeLists.txt
similarity index 68%
copy from src/mlpack/methods/adaboost/CMakeLists.txt
copy to src/mlpack/methods/mean_shift/CMakeLists.txt
index 1320b9d..cb42e52 100644
--- a/src/mlpack/methods/adaboost/CMakeLists.txt
+++ b/src/mlpack/methods/mean_shift/CMakeLists.txt
@@ -1,10 +1,8 @@
-cmake_minimum_required(VERSION 2.8)
-
# Define the files we need to compile.
# Anything not in this list will not be compiled into MLPACK.
set(SOURCES
- adaboost.hpp
- adaboost_impl.hpp
+ mean_shift.hpp
+ mean_shift_impl.hpp
)
# Add directory name to sources.
@@ -16,12 +14,11 @@ endforeach()
# the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
-
-add_executable(adaboost
- adaboost_main.cpp
+# The main Mean Shift executable.
+add_executable(mean_shift
+ mean_shift_main.cpp
)
-target_link_libraries(adaboost
+target_link_libraries(mean_shift
mlpack
)
-
-install(TARGETS adaboost RUNTIME DESTINATION bin)
+install(TARGETS mean_shift RUNTIME DESTINATION bin)
diff --git a/src/mlpack/methods/mean_shift/mean_shift.hpp b/src/mlpack/methods/mean_shift/mean_shift.hpp
new file mode 100644
index 0000000..b572bc0
--- /dev/null
+++ b/src/mlpack/methods/mean_shift/mean_shift.hpp
@@ -0,0 +1,127 @@
+/**
+ * @file mean_shift.hpp
+ * @author Shangtong Zhang
+ *
+ * Mean Shift clustering
+ */
+
+#ifndef __MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
+#define __MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
+
+#include <mlpack/core.hpp>
+
+#include <mlpack/core/metrics/lmetric.hpp>
+
+namespace mlpack {
+namespace meanshift /** Mean Shift clustering. */ {
+
+/**
+ * This class implements Mean Shift clustering.  For each point in the dataset,
+ * the mean shift procedure is applied until it converges or the maximum number
+ * of iterations is reached; centroids lying within `radius` of each other are
+ * then merged so that each cluster is represented by a single centroid.
+ *
+ * A simple example of how to run Mean Shift clustering is shown below.
+ *
+ * @code
+ * extern arma::mat data; // Dataset we want to run Mean Shift on.
+ * arma::Col<size_t> assignments; // Cluster assignments.
+ * arma::mat centroids; // Cluster centroids.
+ * extern size_t maxIterations; // Maximum number of iterations.
+ * extern double stopThresh; // Convergence threshold on the shift's 2-norm.
+ * extern double radius; // Neighbourhood radius and centroid merge distance.
+ *
+ * MeanShift<arma::mat, metric::EuclideanDistance> meanShift(maxIterations,
+ *     stopThresh, radius);
+ * meanShift.Cluster(data, assignments, centroids);
+ * @endcode
+ *
+ * @tparam MatType Type of the data matrix (arma::mat by default).
+ * @tparam MetricType The distance metric to use for this MeanShift; see
+ *     metric::LMetric for an example.
+ */
+template<typename MatType = arma::mat,
+         typename MetricType = metric::EuclideanDistance>
+class MeanShift
+{
+ public:
+  /**
+   * Create a Mean Shift object and set the parameters which Mean Shift
+   * will be run with.
+   *
+   * @param maxIterations Maximum number of iterations allowed before giving up.
+   * @param stopThresh If the 2-norm of the mean shift vector is less than
+   *     stopThresh, iterations will terminate.
+   * @param radius When iterating, take points within distance of radius into
+   *     consideration; two centroids within distance of radius will be
+   *     regarded as one centroid.
+   * @param metric Optional MetricType object; for when the metric has state
+   *     it needs to store.
+   */
+  MeanShift(const size_t maxIterations,
+            const double stopThresh,
+            const double radius,
+            const MetricType metric = MetricType());
+
+  /**
+   * Perform Mean Shift clustering on the data, returning a list of cluster
+   * assignments and centroids.
+   *
+   * @param data Dataset to cluster (one point per column).
+   * @param assignments Vector to store cluster assignments in; entry i is the
+   *     index (column of `centroids`) of the cluster point i belongs to.
+   * @param centroids Matrix in which centroids are stored, one per column.
+   */
+  void Cluster(const MatType& data,
+               arma::Col<size_t>& assignments,
+               arma::mat& centroids);
+
+  //! Get the maximum number of iterations.
+  size_t MaxIterations() const { return maxIterations; }
+  //! Set the maximum number of iterations.
+  size_t& MaxIterations() { return maxIterations; }
+
+  //! Get the stopping threshold on the 2-norm of the mean shift vector.
+  double StopThresh() const { return stopThresh; }
+  //! Set the stopping threshold on the 2-norm of the mean shift vector.
+  double& StopThresh() { return stopThresh; }
+
+  //! Get the neighbourhood radius (also the centroid merge distance).
+  double Radius() const { return radius; }
+  //! Set the neighbourhood radius (also the centroid merge distance).
+  double& Radius() { return radius; }
+
+  //! Get the distance metric.
+  const MetricType& Metric() const { return metric; }
+  //! Modify the distance metric.
+  MetricType& Metric() { return metric; }
+
+ private:
+  //! Maximum number of iterations before giving up.
+  size_t maxIterations;
+
+  //! If the 2-norm of the mean shift vector is less than stopThresh,
+  //! iterations will terminate.
+  double stopThresh;
+
+  //! When iterating, take points within distance of radius into
+  //! consideration; two centroids within distance of radius will be
+  //! regarded as one centroid.
+  double radius;
+
+  //! Instantiated distance metric.
+  MetricType metric;
+};
+
+}; // namespace meanshift
+}; // namespace mlpack
+
+// Include implementation.
+#include "mean_shift_impl.hpp"
+
+#endif // __MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
\ No newline at end of file
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
new file mode 100644
index 0000000..a486064
--- /dev/null
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -0,0 +1,126 @@
+/**
+ * @file mean_shift_impl.hpp
+ * @author Shangtong Zhang
+ *
+ * Mean Shift clustering
+ */
+
+#include "mean_shift.hpp"
+#include <mlpack/core/metrics/lmetric.hpp>
+
+namespace mlpack {
+namespace meanshift {
+
+/**
+ * Construct the Mean Shift object.  The parameters (and an optional metric
+ * instance, for metrics that carry state) are stored for use by Cluster().
+ */
+template<typename MatType, typename MetricType>
+MeanShift<MatType, MetricType>::MeanShift(const size_t maxIterations,
+                                          const double stopThresh,
+                                          const double radius,
+                                          const MetricType metric) :
+    maxIterations(maxIterations),
+    stopThresh(stopThresh),
+    radius(radius),
+    metric(metric)
+{
+  // Every member is set in the initializer list; no further work needed.
+}
+
+/**
+ * Perform Mean Shift clustering on the data, returning a list of cluster
+ * assignments and centroids.
+ *
+ * Each point is processed independently: starting from the point itself, the
+ * current centroid is repeatedly moved to the mean of all data points lying
+ * within `radius` of it (as measured by `metric`), until the 2-norm of the
+ * shift drops below `stopThresh` or `maxIterations` shifts have been made
+ * (so maxIterations == 0 leaves every centroid at its starting point).
+ * Converged centroids lying within `radius` of an already-accepted centroid
+ * are then merged into it.
+ *
+ * Any previous contents of `centroids` are discarded.
+ */
+template<typename MatType,
+         typename MetricType>
+inline void MeanShift<
+    MatType,
+    MetricType>::
+Cluster(const MatType& data,
+        arma::Col<size_t>& assignments,
+        arma::mat& centroids)
+{
+  // Converged centroid for every point, before duplicates are merged.
+  arma::mat allCentroids(data.n_rows, data.n_cols);
+  assignments.set_size(data.n_cols);
+
+  // Start from an empty centroid set of the right dimensionality; otherwise
+  // stale columns from a previous call would be treated as existing clusters
+  // (and a row-count mismatch would make insert_cols() throw).
+  centroids.set_size(data.n_rows, 0);
+
+  // For each point in the dataset, perform the mean shift procedure.
+  for (size_t i = 0; i < data.n_cols; ++i)
+  {
+    // The initial centroid is the point itself.
+    allCentroids.col(i) = data.col(i);
+
+    // Perform at most maxIterations shifts (the old `>` test did one extra).
+    for (size_t completedIterations = 0; completedIterations < maxIterations;
+         ++completedIterations)
+    {
+      // Mean shift vector: average offset from the centroid to its neighbours.
+      arma::vec mhVector = arma::zeros<arma::vec>(data.n_rows);
+
+      // Number of neighbouring points.
+      size_t vecCount = 0;
+      for (size_t j = 0; j < data.n_cols; ++j)
+      {
+        // Find neighbouring points.
+        const double dist = metric.Evaluate(data.col(j), allCentroids.col(i));
+        if (dist < radius)
+        {
+          ++vecCount;
+          mhVector += data.col(j) - allCentroids.col(i);
+        }
+      }
+
+      // No neighbours (e.g. radius <= 0): there is nothing to average, and
+      // dividing by zero would turn the centroid into NaN.  Stop shifting.
+      if (vecCount == 0)
+        break;
+
+      mhVector /= static_cast<double>(vecCount);
+
+      // Update the centroid.
+      allCentroids.col(i) += mhVector;
+
+      if (arma::norm(mhVector, 2) < stopThresh)
+        break;
+    }
+  }
+
+  // Merge duplicate centroids.
+  for (size_t i = 0; i < allCentroids.n_cols; ++i)
+  {
+    bool isDuplicated = false;
+
+    // If this centroid is within radius of an already-accepted centroid
+    // (under the same metric used while shifting), assign the point to that
+    // centroid instead of creating a new one.
+    for (size_t j = 0; j < centroids.n_cols; ++j)
+    {
+      if (metric.Evaluate(allCentroids.col(i), centroids.col(j)) < radius)
+      {
+        isDuplicated = true;
+        assignments(i) = j;
+        break;
+      }
+    }
+
+    if (!isDuplicated)
+    {
+      // This is a new centroid; append it as the last column.
+      centroids.insert_cols(centroids.n_cols, allCentroids.col(i));
+      assignments(i) = centroids.n_cols - 1;
+    }
+  }
+}
+
+}; // namespace meanshift
+}; // namespace mlpack
\ No newline at end of file
diff --git a/src/mlpack/methods/mean_shift/mean_shift_main.cpp b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
new file mode 100644
index 0000000..85b3013
--- /dev/null
+++ b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
@@ -0,0 +1,107 @@
+/**
+ * @file mean_shift_main.cpp
+ * @author Shangtong Zhang
+ *
+ * Executable for running Mean Shift
+ */
+
+#include <mlpack/core.hpp>
+#include "mean_shift.hpp"
+
+using namespace mlpack;
+using namespace mlpack::meanshift;
+using namespace std;
+
+// Define parameters for the executable.
+PROGRAM_INFO("Mean Shift Clustering", "This program performs mean shift clustering "
+    "on the given dataset, storing the learned cluster assignments either as "
+    "a column of labels in the file containing the input dataset or in a "
+    "separate file. "
+    "\n\n");
+
+// Required options.
+PARAM_STRING_REQ("inputFile", "Input dataset to perform clustering on.", "i");
+PARAM_DOUBLE_REQ("stopThresh", "If the 2-norm of the mean shift vector "
+    "is less than stopThresh, iterations will terminate. ", "s");
+PARAM_DOUBLE_REQ("radius", "When iterating, take points within distance of "
+    "radius into consideration and two centroids within distance "
+    "of radius will be regarded as one centroid. ", "r");
+
+// Output options.
+PARAM_FLAG("in_place", "If specified, a column containing the learned cluster "
+    "assignments will be added to the input dataset file. In this case, "
+    "--output_file is overridden.", "P");
+PARAM_STRING("output_file", "File to write output labels or labeled data to.",
+    "o", "");
+PARAM_STRING("centroid_file", "If specified, the centroids of each cluster will"
+    " be written to the given file.", "C", "");
+
+// Mean Shift configuration options.
+PARAM_INT("max_iterations", "Maximum number of iterations before Mean Shift "
+    "terminates.", "m", 1000);
+
+
+/**
+ * Entry point: load the dataset, run Mean Shift, and save the labeled dataset
+ * (in place or to --output_file) and/or the centroids (--centroid_file).
+ */
+int main(int argc, char** argv)
+{
+  CLI::ParseCommandLine(argc, argv);
+
+  const string inputFile = CLI::GetParam<string>("inputFile");
+  const double stopThresh = CLI::GetParam<double>("stopThresh");
+  const double radius = CLI::GetParam<double>("radius");
+  const int maxIterations = CLI::GetParam<int>("max_iterations");
+
+  if (maxIterations < 0)
+  {
+    Log::Fatal << "Invalid value for maximum iterations (" << maxIterations <<
+        ")! Must be greater than or equal to 0." << endl;
+  }
+
+  // Make sure we have an output file if we're not doing the work in-place.
+  if (!CLI::HasParam("in_place") && !CLI::HasParam("output_file") &&
+      !CLI::HasParam("centroid_file"))
+  {
+    Log::Warn << "--output_file, --in_place, and --centroid_file are not set; "
+        << "no results will be saved." << std::endl;
+  }
+
+  arma::mat dataset;
+  data::Load(inputFile, dataset, true); // Fatal upon failure.
+  arma::mat centroids;
+  arma::Col<size_t> assignments;
+
+  // maxIterations was validated as non-negative above.
+  MeanShift<arma::mat, metric::EuclideanDistance> meanShift(
+      static_cast<size_t>(maxIterations), stopThresh, radius);
+  Timer::Start("clustering");
+  meanShift.Cluster(dataset, assignments, centroids);
+  Timer::Stop("clustering");
+
+  // The labeled dataset is only written when a destination was requested;
+  // previously the non-in-place branch always tried to save to "" when
+  // --output_file was absent.
+  if (CLI::HasParam("in_place") || CLI::HasParam("output_file"))
+  {
+    // Append the assignments, converted to double, as a new last row.
+    dataset.insert_rows(dataset.n_rows,
+        arma::trans(arma::conv_to<arma::vec>::from(assignments)));
+
+    if (CLI::HasParam("in_place"))
+    {
+      // --in_place overrides --output_file.
+      data::Save(inputFile, dataset);
+    }
+    else
+    {
+      // Save to the separate output file.
+      data::Save(CLI::GetParam<string>("output_file"), dataset);
+    }
+  }
+
+  // Should we write the centroids to a file?
+  if (CLI::HasParam("centroid_file"))
+    data::Save(CLI::GetParam<string>("centroid_file"), centroids);
+
+  return 0;
+}
\ No newline at end of file
More information about the mlpack-git
mailing list