[mlpack-git] master: support for selecting kernel (9543f79)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:22 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4
>---------------------------------------------------------------
commit 9543f793d203e5202c12d650a93f0d075af70ddd
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date: Sun Jan 18 22:17:11 2015 +0800
support for selecting kernel
>---------------------------------------------------------------
9543f793d203e5202c12d650a93f0d075af70ddd
src/mlpack/methods/mean_shift/mean_shift.hpp | 62 ++++++++++-------------
src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 52 ++++++++++---------
src/mlpack/methods/mean_shift/mean_shift_main.cpp | 20 +++++---
3 files changed, 67 insertions(+), 67 deletions(-)
diff --git a/src/mlpack/methods/mean_shift/mean_shift.hpp b/src/mlpack/methods/mean_shift/mean_shift.hpp
index b572bc0..95331e6 100644
--- a/src/mlpack/methods/mean_shift/mean_shift.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift.hpp
@@ -10,7 +10,7 @@
#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/kernels/gaussian_kernel.hpp>
namespace mlpack {
namespace meanshift /** Mean Shift clustering. */ {
@@ -27,21 +27,16 @@ namespace meanshift /** Mean Shift clustering. */ {
* extern arma::mat data; // Dataset we want to run Mean Shift on.
* arma::Col<size_t> assignments; // Cluster assignments.
* arma::mat centroids; // Cluster centroids.
- * extern int maxIterations; // Maximum number of iterations.
- * extern double stopThresh; //
- * extern double radius; //
*
- * MeanShift<arma::mat, metric::EuclideanDistance> meanShift(maxIterations,
- * stopThresh, radius);
+ * MeanShift<arma::mat, kernel::GaussianKernel> meanShift();
* meanShift.Cluster(dataset, assignments, centroids);
* @endcode
*
- * @tparam MetricType The distance metric to use for this KMeans; see
- * metric::LMetric for an example.
+ * @tparam KernelType the kernel to use.
*/
template<typename MatType = arma::mat,
- typename MetricType = metric::EuclideanDistance>
+ typename KernelType = kernel::GaussianKernel>
class MeanShift
{
public:
@@ -49,19 +44,16 @@ class MeanShift
* Create a Mean Shift object and set the parameters which Mean Shift
* will be run with.
*
+ * @param duplicateThresh If distance of two centroids is less than it, one will be removed.
* @param maxIterations Maximum number of iterations allowed before giving up
* @param stopThresh If the 2-norm of the mean shift vector is less than stopThresh,
- * iterations will terminate.
- * @param radius When iterating, take points within distance of
- * radius into consideration and two centroids within distance
- * of radius will be ragarded as one centroid.
- * @param metric Optional MetricType object; for when the metric has state
- * it needs to store.
+ * iterations will terminate.
+ * @param kernel Optional KernelType object.
*/
- MeanShift(const size_t maxIterations,
- const double stopThresh,
- const double radius,
- const MetricType metric = MetricType());
+ MeanShift(const double duplicateThresh = 1.0,
+ const size_t maxIterations = 1000,
+ const double stopThresh = 1e-3,
+ const KernelType kernel = KernelType());
/**
@@ -87,34 +79,32 @@ class MeanShift
//! Set the stop thresh.
double& StopThresh() { return stopThresh; }
- //! Get the radius of the concerning points.
- double Radius() const { return radius; }
- //! Set the radius of the concerning points.
- double& Radius() { return radius; }
+ //! Get the kernel.
+ const KernelType& Kernel() const { return kernel; }
+ //! Modify the kernel.
+ KernelType& Kernel() { return kernel; }
- //! Get the distance metric.
- const MetricType& Metric() const { return metric; }
- //! Modify the distance metric.
- MetricType& Metric() { return metric; }
+ //! Get the duplicate thresh.
+ double DuplicateThresh() const { return duplicateThresh; }
+ //! Set the duplicate thresh.
+ double& DuplicateThresh() { return duplicateThresh; }
private:
+ // If distance of two centroids is less than duplicateThresh, one will be removed.
+ double duplicateThresh;
+
//! Maximum number of iterations before giving up.
size_t maxIterations;
- /** If the 2-norm of the mean shift vector is less than stopThresh,
+ /**
+ * If the 2-norm of the mean shift vector is less than stopThresh,
* iterations will terminate.
*/
double stopThresh;
- /** When iterating, take points within distance of
- * radius into consideration and two centroids within distance
- * of radius will be ragarded as one centroid.
- */
- double radius;
-
- //! Instantiated distance metric.
- MetricType metric;
+ //! Instantiated kernel.
+ KernelType kernel;
};
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index a486064..5b51341 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -7,6 +7,7 @@
#include "mean_shift.hpp"
#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/kernels/gaussian_kernel.hpp>
namespace mlpack {
namespace meanshift {
@@ -15,18 +16,18 @@ namespace meanshift {
* Construct the Mean Shift object.
*/
template<typename MatType,
- typename MetricType>
+ typename KernelType>
MeanShift<
MatType,
- MetricType>::
-MeanShift(const size_t maxIterations,
+ KernelType>::
+MeanShift(const double duplicateThresh,
+ const size_t maxIterations,
const double stopThresh,
- const double radius,
- const MetricType metric) :
+ const KernelType kernel) :
+ duplicateThresh(duplicateThresh),
maxIterations(maxIterations),
stopThresh(stopThresh),
- radius(radius),
- metric(metric)
+ kernel(kernel)
{
// Nothing to do.
}
@@ -36,10 +37,10 @@ MeanShift(const size_t maxIterations,
* assignments and centroids.
*/
template<typename MatType,
- typename MetricType>
+ typename KernelType>
inline void MeanShift<
MatType,
- MetricType>::
+ KernelType>::
Cluster(const MatType& data,
arma::Col<size_t>& assignments,
arma::mat& centroids) {
@@ -58,27 +59,30 @@ Cluster(const MatType& data,
while (true) {
- // mean shift vector.
- arma::Col<double> mhVector = arma::zeros(data.n_rows, 1);
+ // new centroid
+ arma::Col<double> newCentroid = arma::zeros(data.n_rows, 1);
- // number of neighbouring points.
- int vecCount = 0;
+ double sumWeight = 0;
for (size_t j = 0; j < data.n_cols; ++j) {
- // find neighbouring points.
- double dist = metric.Evaluate(data.col(j), allCentroids.col(i));
- if (dist < radius) {
- vecCount ++;
+ // calc weight for each point
+ double weight = kernel.Evaluate(allCentroids.col(i), data.col(j));
+ sumWeight += weight;
+
+ // update new centroid.
+ newCentroid += weight * data.col(j);
- // update mean shift vector.
- mhVector += data.col(j) - allCentroids.col(i);
- }
}
- mhVector /= vecCount;
+
+ newCentroid /= sumWeight;
completedIterations ++;
- // update centroid.
- allCentroids.col(i) += mhVector;
+
+ // calc the mean shift vector.
+ arma::Col<double> mhVector = newCentroid - allCentroids.col(i);
+
+ // update the centroid.
+ allCentroids.col(i) = newCentroid;
if (arma::norm(mhVector, 2) < stopThresh ||
completedIterations > maxIterations) {
@@ -100,7 +104,7 @@ Cluster(const MatType& data,
*/
for (size_t j = 0; j < centroids.n_cols; ++j) {
arma::Col<double> delta = allCentroids.col(i) - centroids.col(j);
- if (norm(delta, 2) < radius) {
+ if (norm(delta, 2) < duplicateThresh) {
isDuplicated = true;
assignments(i) = j;
break;
diff --git a/src/mlpack/methods/mean_shift/mean_shift_main.cpp b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
index 85b3013..983db63 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_main.cpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
@@ -6,6 +6,7 @@
*/
#include <mlpack/core.hpp>
+#include <mlpack/core/kernels/gaussian_kernel.hpp>
#include "mean_shift.hpp"
using namespace mlpack;
@@ -21,11 +22,6 @@ PROGRAM_INFO("Mean Shift Clustering", "This program performs mean shift clusteri
// Required options.
PARAM_STRING_REQ("inputFile", "Input dataset to perform clustering on.", "i");
-PARAM_DOUBLE_REQ("stopThresh", "If the 2-norm of the mean shift vector "
- "is less than stopThresh, iterations will terminate. ", "s");
-PARAM_DOUBLE_REQ("radius", "When iterating, take points within distance of "
- "radius into consideration and two centroids within distance "
- "of radius will be ragarded as one centroid. ", "r");
// Output options.
PARAM_FLAG("in_place", "If specified, a column containing the learned cluster "
@@ -39,6 +35,11 @@ PARAM_STRING("centroid_file", "If specified, the centroids of each cluster will"
// Mean Shift configuration options.
PARAM_INT("max_iterations", "Maximum number of iterations before Mean Shift "
"terminates.", "m", 1000);
+PARAM_DOUBLE("stopThresh", "If the 2-norm of the mean shift vector "
+ "is less than stopThresh, iterations will terminate. ", "s", 1e-3);
+PARAM_DOUBLE("bandwidth", "bandwidth of Gaussian kernel ", "b", 1.0);
+PARAM_DOUBLE("duplicateThresh", "If distance of two centroids is less than duplicate thresh "
+ "one will be removed. ", "d", 1.0);
int main(int argc, char** argv) {
@@ -47,8 +48,9 @@ int main(int argc, char** argv) {
const string inputFile = CLI::GetParam<string>("inputFile");
const double stopThresh = CLI::GetParam<double>("stopThresh");
- const double radius = CLI::GetParam<double>("radius");
+ const double bandwidth = CLI::GetParam<double>("bandwidth");
const int maxIterations = CLI::GetParam<int>("max_iterations");
+ const double duplicateThresh = CLI::GetParam<double>("duplicateThresh");
if (maxIterations < 0) {
Log::Fatal << "Invalid value for maximum iterations (" << maxIterations <<
@@ -67,7 +69,11 @@ int main(int argc, char** argv) {
arma::mat centroids;
arma::Col<size_t> assignments;
- MeanShift<arma::mat, metric::EuclideanDistance> meanShift(maxIterations, stopThresh, radius);
+ kernel::GaussianKernel kernel;
+ kernel.Bandwidth(bandwidth);
+
+ MeanShift<arma::mat, kernel::GaussianKernel> meanShift(duplicateThresh,
+ maxIterations, stopThresh, kernel);
Timer::Start("clustering");
meanShift.Cluster(dataset, assignments, centroids);
Timer::Stop("clustering");
More information about the mlpack-git mailing list