[mlpack-git] master: add metric (8121657)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:08 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4

>---------------------------------------------------------------

commit 8121657608f6190cdd5e11c52232a6abe025c4c5
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date:   Tue Jan 27 22:25:58 2015 +0800

    add metric


>---------------------------------------------------------------

8121657608f6190cdd5e11c52232a6abe025c4c5
 src/mlpack/methods/mean_shift/mean_shift.hpp      | 14 +++++++++++++-
 src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 17 +++++++++++++----
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/src/mlpack/methods/mean_shift/mean_shift.hpp b/src/mlpack/methods/mean_shift/mean_shift.hpp
index d866b82..6a0c809 100644
--- a/src/mlpack/methods/mean_shift/mean_shift.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift.hpp
@@ -10,6 +10,7 @@
 
 #include <mlpack/core.hpp>
 #include <mlpack/core/kernels/gaussian_kernel.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
 
 namespace mlpack {
 namespace meanshift /** Mean Shift clustering. */ {
@@ -35,6 +36,7 @@ namespace meanshift /** Mean Shift clustering. */ {
  */
   
 template<typename KernelType = kernel::GaussianKernel,
+         typename MetricType = metric::EuclideanDistance,
          typename MatType = arma::mat>
 class MeanShift
 {
@@ -48,11 +50,13 @@ class MeanShift
    * @param stopThresh If the 2-norm of the mean shift vector is less than stopThresh, 
    *        iterations will terminate.
    * @param kernel Optional KernelType object.
+   * @param metric Optional MetricType object used to calculate distances.
    */
   MeanShift(const double duplicateThresh = -1,
             const size_t maxIterations = 1000,
             const double stopThresh = 1e-3,
-            const KernelType kernel = KernelType());
+            const KernelType kernel = KernelType(),
+            const MetricType metric = MetricType());
   
   
   /**
@@ -89,6 +93,11 @@ class MeanShift
   //! Modify the kernel.
   KernelType& Kernel() { return kernel; }
   
+  //! Get the metric.
+  const MetricType& Metric() const { return metric; }
+  //! Modify the metric.
+  MetricType& Metric() { return metric; }
+  
   //! Get the duplicate thresh.
   double DuplicateThresh() const { return duplicateThresh; }
   //! Set the duplicate thresh.
@@ -111,6 +120,9 @@ class MeanShift
   //! Instantiated kernel.
   KernelType kernel;
   
+  //! Instantiated metric.
+  MetricType metric;
+  
 };
 
 }; // namespace meanshift
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index fdcaac2..09a54a5 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -20,18 +20,22 @@ namespace meanshift {
   * Construct the Mean Shift object.
   */
 template<typename KernelType,
+         typename MetricType,
          typename MatType>
 MeanShift<
   KernelType,
+  MetricType,
   MatType>::
 MeanShift(const double duplicateThresh,
           const size_t maxIterations,
           const double stopThresh,
-          const KernelType kernel) :
+          const KernelType kernel,
+          const MetricType metric) :
     duplicateThresh(duplicateThresh),
     maxIterations(maxIterations),
     stopThresh(stopThresh),
-    kernel(kernel)
+    kernel(kernel),
+    metric(metric)
 {
   // Nothing to do.
 }
@@ -39,15 +43,17 @@ MeanShift(const double duplicateThresh,
   
 // Estimate duplicate thresh based on given dataset.
 template<typename KernelType,
+         typename MetricType,
          typename MatType>
 double MeanShift<
   KernelType,
+  MetricType,
   MatType>::
 estimateDuplicateThresh(const MatType &data) {
   
   neighbor::NeighborSearch<
     neighbor::NearestNeighborSort,
-    metric::EuclideanDistance,
+    MetricType,
     tree::BinarySpaceTree<bound::HRectBound<2>,
           neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> >
     > neighborSearch(data);
@@ -75,9 +81,11 @@ estimateDuplicateThresh(const MatType &data) {
   * assignments and centroids.
   */
 template<typename KernelType,
+         typename MetricType,
          typename MatType>
 inline void MeanShift<
     KernelType,
+    MetricType,
     MatType>::
 Cluster(const MatType& data,
         arma::Col<size_t>& assignments,
@@ -109,7 +117,8 @@ Cluster(const MatType& data,
       for (size_t j = 0; j < data.n_cols; ++j) {
         
         // calc weight for each point
-        double weight = kernel.Evaluate(allCentroids.col(i), data.col(j));
+        double weight = kernel.Evaluate(metric.Evaluate(allCentroids.col(i),
+                                                       data.col(j)));
         sumWeight += weight;
         
         // update new centroid.



More information about the mlpack-git mailing list