[mlpack-git] master: involve mean to calculate new centroid (6511938)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:59 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4

>---------------------------------------------------------------

commit 651193842da1aaa98be89f10f40ab100b5188697
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date:   Wed Apr 22 15:21:34 2015 +0800

    involve mean to calculate new centroid


>---------------------------------------------------------------

651193842da1aaa98be89f10f40ab100b5188697
 src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 98 +++++++++++++++--------
 1 file changed, 66 insertions(+), 32 deletions(-)

diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index 002051f..0edf692 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -25,10 +25,11 @@ namespace meanshift {
 /**
   * Construct the Mean Shift object.
   */
-template<typename KernelType, typename MatType>
-MeanShift<KernelType, MatType>::MeanShift(const double radius,
-                                          const size_t maxIterations,
-                                          const KernelType kernel) :
+template<bool UseKernel, typename KernelType, typename MatType>
+MeanShift<UseKernel, KernelType, MatType>::
+MeanShift(const double radius,
+          const size_t maxIterations,
+          const KernelType kernel) :
     radius(radius),
     maxIterations(maxIterations),
     kernel(kernel)
@@ -36,16 +37,16 @@ MeanShift<KernelType, MatType>::MeanShift(const double radius,
   // Nothing to do.
 }
 
-template<typename KernelType, typename MatType>
-void MeanShift<KernelType, MatType>::Radius(double radius)
+template<bool UseKernel, typename KernelType, typename MatType>
+void MeanShift<UseKernel, KernelType, MatType>::Radius(double radius)
 {
   this->radius = radius;
 }
 
 // Estimate radius based on given dataset.
-template<typename KernelType, typename MatType>
-double MeanShift<KernelType, MatType>::EstimateRadius(const MatType &data,
-                                                      double ratio)
+template<bool UseKernel, typename KernelType, typename MatType>
+double MeanShift<UseKernel, KernelType, MatType>::
+EstimateRadius(const MatType& data, double ratio)
 {
   neighbor::AllkNN neighborSearch(data);
   /**
@@ -82,9 +83,9 @@ class less
   }
 };
 
-// Generate seeds form given data set
-template<typename KernelType, typename MatType>
-void MeanShift<KernelType, MatType>::genSeeds(
+// Generate seeds from given data set
+template<bool UseKernel, typename KernelType, typename MatType>
+void MeanShift<UseKernel, KernelType, MatType>::GenSeeds(
     const MatType& data,
     double binSize,
     int minFreq,
@@ -112,12 +113,60 @@ void MeanShift<KernelType, MatType>::genSeeds(
   seeds = seeds * binSize;
 }
 
+// Calculate new centroid with given kernel.
+template<bool UseKernel, typename KernelType, typename MatType>
+template<bool ApplyKernel>
+typename std::enable_if<ApplyKernel, bool>::type
+MeanShift<UseKernel, KernelType, MatType>::
+CalculateCentroid(const MatType& data,
+                  const std::vector<size_t>& neighbors,
+                  const std::vector<double>& distances,
+                  arma::colvec& centroid)
+{
+  double sumWeight = 0;
+  for (size_t i = 0; i < neighbors.size(); ++i)
+  {
+    if (distances[i] > 0)
+    {
+      double dist = distances[i] / radius;
+      double weight = kernel.Gradient(dist) / dist;
+      sumWeight += weight;
+      centroid += weight * data.unsafe_col(neighbors[i]);
+    }
+  }
+  if (sumWeight != 0)
+  {
+    centroid /= sumWeight;
+    return true;
+  }
+  return false;
+}
+
+// Calculate new centroid by mean.
+template<bool UseKernel, typename KernelType, typename MatType>
+template<bool ApplyKernel>
+typename std::enable_if<!ApplyKernel, bool>::type
+MeanShift<UseKernel, KernelType, MatType>::
+CalculateCentroid(const MatType& data,
+                  const std::vector<size_t>& neighbors,
+                  const std::vector<double>&, /*unused*/
+                  arma::colvec& centroid)
+{
+  for (size_t i = 0; i < neighbors.size(); ++i)
+  {
+    centroid += data.unsafe_col(neighbors[i]);
+  }
+  centroid /= neighbors.size();
+  return true;
+}
+  
+  
 /**
  * Perform Mean Shift clustering on the data set, returning a list of cluster
  * assignments and centroids.
  */
-template<typename KernelType, typename MatType>
-inline void MeanShift<KernelType, MatType>::Cluster(
+template<bool UseKernel, typename KernelType, typename MatType>
+inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
     const MatType& data,
     arma::Col<size_t>& assignments,
     arma::mat& centroids,
@@ -133,7 +182,7 @@ inline void MeanShift<KernelType, MatType>::Cluster(
   const MatType* pSeeds = &data;
   if (useSeeds)
   {
-    genSeeds(data, radius, 1, seeds);
+    GenSeeds(data, radius, 1, seeds);
     pSeeds = &seeds;
   }
   
@@ -158,25 +207,10 @@ inline void MeanShift<KernelType, MatType>::Cluster(
       // Store new centroid in this.
       arma::colvec newCentroid(pSeeds->n_rows, arma::fill::zeros);
       
-      double sumWeight = 0;
       rangeSearcher.Search(allCentroids.unsafe_col(i), validRadius,
           neighbors, distances);
-      if (neighbors[0].size() <= 1)
-        break;
-      for (size_t j = 0; j < neighbors[0].size(); ++j)
-      {
-        if (distances[0][j] > 0)
-        {
-          distances[0][j] /= radius;
-          double weight = kernel.Gradient(distances[0][j]) / distances[0][j];
-          sumWeight += weight;
-          newCentroid += weight * data.unsafe_col(neighbors[0][j]);
-        }
-      }
-      
-      if (sumWeight != 0)
-        newCentroid /= sumWeight;
-      else
+      // Calculate new centroid.
+      if (!CalculateCentroid(data, neighbors[0], distances[0], newCentroid))
         newCentroid = allCentroids.unsafe_col(i);
 
       // If the mean shift vector is small enough, it has converged.



More information about the mlpack-git mailing list