[mlpack-git] master: generate seeds as initial centroids to speed up (04d234b)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:55 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4

>---------------------------------------------------------------

commit 04d234b789ca5c748f37d7f45b2fa8f3ed45f86a
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date:   Sun Apr 12 00:01:13 2015 +0800

    generate seeds as initial centroids to speed up


>---------------------------------------------------------------

04d234b789ca5c748f37d7f45b2fa8f3ed45f86a
 src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 142 ++++++++++++++--------
 1 file changed, 93 insertions(+), 49 deletions(-)

diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index 5d3efbc..002051f 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -12,6 +12,9 @@
 #include <mlpack/core/metrics/lmetric.hpp>
 #include <mlpack/methods/neighbor_search/neighbor_search.hpp>
 #include <mlpack/methods/neighbor_search/neighbor_search_stat.hpp>
+#include <mlpack/methods/range_search/range_search.hpp>
+
+#include "map"
 
 // In case it hasn't been included yet.
 #include "mean_shift.hpp"
@@ -41,16 +44,16 @@ void MeanShift<KernelType, MatType>::Radius(double radius)
 
 // Estimate radius based on given dataset.
 template<typename KernelType, typename MatType>
-double MeanShift<KernelType, MatType>::EstimateRadius(const MatType &data)
+double MeanShift<KernelType, MatType>::EstimateRadius(const MatType &data,
+                                                      double ratio)
 {
   neighbor::AllkNN neighborSearch(data);
-
   /**
    * For each point in dataset, select nNeighbors nearest points and get
    * nNeighbors distances.  Use the maximum distance to estimate the duplicate
    * threshhold.
    */
-  size_t nNeighbors = size_t(data.n_cols * 0.3);
+  size_t nNeighbors = size_t(data.n_cols * ratio);
   arma::Mat<size_t> neighbors;
   arma::mat distances;
   neighborSearch.Search(nNeighbors, neighbors, distances);
@@ -62,31 +65,63 @@ double MeanShift<KernelType, MatType>::EstimateRadius(const MatType &data)
   return sum(maxDistances) / (double) data.n_cols;
 }
 
-// General way to calculate the weight of a data point.
-template<typename KernelType, typename MatType>
-bool MeanShift<KernelType, MatType>::CalcWeight(
-    const arma::colvec& centroid,
-    const arma::colvec& point,
-    double& weight)
+// Class to compare two vectors lexicographically (element by element).
+template <typename VecType>
+class less
 {
-  double distance = metric::EuclideanDistance::Evaluate(centroid, point);
-  if (distance >= radius || distance == 0)
+ public:
+  bool operator()(const VecType& first, const VecType& second) const
+  {
+    for (size_t i = 0; i < first.n_rows; ++i)
+    {
+      if (first[i] == second[i])
+        continue;
+      return first(i) < second(i);
+    }
     return false;
+  }
+};
+
+// Generate seeds from the given data set
+template<typename KernelType, typename MatType>
+void MeanShift<KernelType, MatType>::genSeeds(
+    const MatType& data,
+    double binSize,
+    int minFreq,
+    MatType& seeds)
+{
+  typedef arma::colvec VecType;
+  std::map<VecType, int, less<VecType> > allSeeds;
+  for (size_t i = 0; i < data.n_cols; ++i)
+  {
+    VecType binnedPoint = arma::floor(data.unsafe_col(i) / binSize);
+    if (allSeeds.find(binnedPoint) == allSeeds.end())
+      allSeeds[binnedPoint] = 1;
+    else
+      allSeeds[binnedPoint]++;
+  }
+  
+  // Remove seeds with too few points
+  std::map<VecType, int, less<VecType> >::iterator it;
+  for (it = allSeeds.begin(); it != allSeeds.end(); ++it)
+  {
+    if (it->second >= minFreq)
+      seeds.insert_cols(seeds.n_cols, it->first);
+  }
   
-  distance /= radius;
-  weight = kernel.Gradient(distance) / distance;
-  return true;
+  seeds = seeds * binSize;
 }
 
 /**
- * Perform Mean Shift clustering on the data, returning a list of cluster
+ * Perform Mean Shift clustering on the data set, returning a list of cluster
  * assignments and centroids.
  */
 template<typename KernelType, typename MatType>
 inline void MeanShift<KernelType, MatType>::Cluster(
     const MatType& data,
     arma::Col<size_t>& assignments,
-    arma::mat& centroids)
+    arma::mat& centroids,
+    bool useSeeds)
 {
   if (radius <= 0)
   {
@@ -94,44 +129,59 @@ inline void MeanShift<KernelType, MatType>::Cluster(
     Radius(EstimateRadius(data));
   }
 
-  // Holds all centroids, before removing duplicate ones.
-  arma::mat allCentroids(data.n_rows, data.n_cols);
+  MatType seeds;
+  const MatType* pSeeds = &data;
+  if (useSeeds)
+  {
+    genSeeds(data, radius, 1, seeds);
+    pSeeds = &seeds;
+  }
+  
+  // Holds all centroids before removing duplicate ones.
+  arma::mat allCentroids(pSeeds->n_rows, pSeeds->n_cols);
+  
   assignments.set_size(data.n_cols);
 
+  range::RangeSearch<> rangeSearcher(data);
+  math::Range validRadius(0, radius);
+  std::vector<std::vector<size_t> > neighbors;
+  std::vector<std::vector<double> > distances;
 
-  // For each point in dataset, perform mean shift algorithm.
-  for (size_t i = 0; i < data.n_cols; ++i)
+  // For each seed, perform mean shift algorithm.
+  for (size_t i = 0; i < pSeeds->n_cols; ++i)
   {
     // Initial centroid is the point itself.
-    allCentroids.col(i) = data.col(i);
-
+    allCentroids.col(i) = pSeeds->unsafe_col(i);
     for (size_t completedIterations = 0; completedIterations < maxIterations;
          completedIterations++)
     {
       // Store new centroid in this.
-      arma::Col<double> newCentroid = arma::zeros(data.n_rows, 1);
+      arma::colvec newCentroid(pSeeds->n_rows, arma::fill::zeros);
       
       double sumWeight = 0;
-
-      // Go through all the points
-      for (size_t j = 0; j < data.n_cols; ++j)
+      rangeSearcher.Search(allCentroids.unsafe_col(i), validRadius,
+          neighbors, distances);
+      if (neighbors[0].size() <= 1)
+        break;
+      for (size_t j = 0; j < neighbors[0].size(); ++j)
       {
-        double weight = 0;
-        if (CalcWeight(allCentroids.col(i), data.col(j), weight))
+        if (distances[0][j] > 0)
         {
+          distances[0][j] /= radius;
+          double weight = kernel.Gradient(distances[0][j]) / distances[0][j];
           sumWeight += weight;
-          newCentroid += weight * data.col(j);
+          newCentroid += weight * data.unsafe_col(neighbors[0][j]);
         }
       }
       
-      newCentroid /= sumWeight;
-
-      // Calculate the mean shift vector.
-      arma::Col<double> mhVector = newCentroid - allCentroids.col(i);
+      if (sumWeight != 0)
+        newCentroid /= sumWeight;
+      else
+        newCentroid = allCentroids.unsafe_col(i);
 
       // If the mean shift vector is small enough, it has converged.
-      if (metric::EuclideanDistance::Evaluate(newCentroid, allCentroids.col(i))
-          < 1e-3 * radius)
+      if (metric::EuclideanDistance::Evaluate(newCentroid, allCentroids.unsafe_col(i)) <
+          1e-3 * radius)
       {
         // Determine if the new centroid is duplicate with old ones.
         bool isDuplicated = false;
@@ -142,25 +192,12 @@ inline void MeanShift<KernelType, MatType>::Cluster(
           if (distance < radius)
           {
             isDuplicated = true;
-            assignments(i) = k;
             break;
           }
         }
 
         if (!isDuplicated)
-        {
-          // This centroid is a new centroid.
-          if (centroids.n_cols == 0)
-          {
-            centroids.insert_cols(0, allCentroids.col(i));
-            assignments(i) = 0;
-          }
-          else
-          {
-            centroids.insert_cols(centroids.n_cols, allCentroids.col(i));
-            assignments(i) = centroids.n_cols - 1;
-          }
-        }
+          centroids.insert_cols(centroids.n_cols, allCentroids.unsafe_col(i));
         
         // Get out of the loop.
         break;
@@ -170,6 +207,13 @@ inline void MeanShift<KernelType, MatType>::Cluster(
       allCentroids.col(i) = newCentroid;
     }
   }
+  
+  // Assign centroids to each point
+  neighbor::AllkNN neighborSearcher(centroids, data);
+  arma::mat neighborDistances;
+  arma::Mat<size_t> resultingNeighbors;
+  neighborSearcher.Search(1, resultingNeighbors, neighborDistances);
+  assignments = resultingNeighbors.t();
 }
 
 } // namespace meanshift



More information about the mlpack-git mailing list