[mlpack-git] master: add metric (8121657)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:08 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4
>---------------------------------------------------------------
commit 8121657608f6190cdd5e11c52232a6abe025c4c5
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date: Tue Jan 27 22:25:58 2015 +0800
add metric
>---------------------------------------------------------------
8121657608f6190cdd5e11c52232a6abe025c4c5
src/mlpack/methods/mean_shift/mean_shift.hpp | 14 +++++++++++++-
src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 17 +++++++++++++----
2 files changed, 26 insertions(+), 5 deletions(-)
diff --git a/src/mlpack/methods/mean_shift/mean_shift.hpp b/src/mlpack/methods/mean_shift/mean_shift.hpp
index d866b82..6a0c809 100644
--- a/src/mlpack/methods/mean_shift/mean_shift.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift.hpp
@@ -10,6 +10,7 @@
#include <mlpack/core.hpp>
#include <mlpack/core/kernels/gaussian_kernel.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
namespace mlpack {
namespace meanshift /** Mean Shift clustering. */ {
@@ -35,6 +36,7 @@ namespace meanshift /** Mean Shift clustering. */ {
*/
template<typename KernelType = kernel::GaussianKernel,
+ typename MetricType = metric::EuclideanDistance,
typename MatType = arma::mat>
class MeanShift
{
@@ -48,11 +50,13 @@ class MeanShift
* @param stopThresh If the 2-norm of the mean shift vector is less than stopThresh,
* iterations will terminate.
* @param kernel Optional KernelType object.
+ * @param metric Optional MetricType object used to calculate distances.
*/
MeanShift(const double duplicateThresh = -1,
const size_t maxIterations = 1000,
const double stopThresh = 1e-3,
- const KernelType kernel = KernelType());
+ const KernelType kernel = KernelType(),
+ const MetricType metric = MetricType());
/**
@@ -89,6 +93,11 @@ class MeanShift
//! Modify the kernel.
KernelType& Kernel() { return kernel; }
+ //! Get the metric.
+ const MetricType& Metric() const { return metric; }
+ //! Modify the metric.
+ MetricType& Metric() { return metric; }
+
//! Get the duplicate thresh.
double DuplicateThresh() const { return duplicateThresh; }
//! Set the duplicate thresh.
@@ -111,6 +120,9 @@ class MeanShift
//! Instantiated kernel.
KernelType kernel;
+ //! Instantiated metric.
+ MetricType metric;
+
};
}; // namespace meanshift
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index fdcaac2..09a54a5 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -20,18 +20,22 @@ namespace meanshift {
* Construct the Mean Shift object.
*/
template<typename KernelType,
+ typename MetricType,
typename MatType>
MeanShift<
KernelType,
+ MetricType,
MatType>::
MeanShift(const double duplicateThresh,
const size_t maxIterations,
const double stopThresh,
- const KernelType kernel) :
+ const KernelType kernel,
+ const MetricType metric) :
duplicateThresh(duplicateThresh),
maxIterations(maxIterations),
stopThresh(stopThresh),
- kernel(kernel)
+ kernel(kernel),
+ metric(metric)
{
// Nothing to do.
}
@@ -39,15 +43,17 @@ MeanShift(const double duplicateThresh,
// Estimate duplicate thresh based on given dataset.
template<typename KernelType,
+ typename MetricType,
typename MatType>
double MeanShift<
KernelType,
+ MetricType,
MatType>::
estimateDuplicateThresh(const MatType &data) {
neighbor::NeighborSearch<
neighbor::NearestNeighborSort,
- metric::EuclideanDistance,
+ MetricType,
tree::BinarySpaceTree<bound::HRectBound<2>,
neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> >
> neighborSearch(data);
@@ -75,9 +81,11 @@ estimateDuplicateThresh(const MatType &data) {
* assignments and centroids.
*/
template<typename KernelType,
+ typename MetricType,
typename MatType>
inline void MeanShift<
KernelType,
+ MetricType,
MatType>::
Cluster(const MatType& data,
arma::Col<size_t>& assignments,
@@ -109,7 +117,8 @@ Cluster(const MatType& data,
for (size_t j = 0; j < data.n_cols; ++j) {
// calc weight for each point
- double weight = kernel.Evaluate(allCentroids.col(i), data.col(j));
+ double weight = kernel.Evaluate(metric.Evaluate(allCentroids.col(i),
+ data.col(j)));
sumWeight += weight;
// update new centroid.
More information about the mlpack-git
mailing list