[mlpack-git] master: Fix merge failures. (d0a3e9a)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 15:00:50 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/182d4a629c1b23f683dff7b284844e4e3e9f5cc4...e06ce8f3ac1170108c20c114e82cae10356d1301

>---------------------------------------------------------------

commit d0a3e9a3ae475ebc758e7309ba8c87a3b0b4d357
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Apr 29 14:57:50 2015 -0400

    Fix merge failures.


>---------------------------------------------------------------

d0a3e9a3ae475ebc758e7309ba8c87a3b0b4d357
 src/mlpack/methods/mean_shift/mean_shift.hpp      | 55 ++++++++++++++++++-----
 src/mlpack/methods/mean_shift/mean_shift_impl.hpp |  9 ++--
 src/mlpack/methods/mean_shift/mean_shift_main.cpp |  5 +--
 3 files changed, 52 insertions(+), 17 deletions(-)

diff --git a/src/mlpack/methods/mean_shift/mean_shift.hpp b/src/mlpack/methods/mean_shift/mean_shift.hpp
index 271632e..1bd0d8b 100644
--- a/src/mlpack/methods/mean_shift/mean_shift.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift.hpp
@@ -34,10 +34,13 @@ namespace meanshift /** Mean Shift clustering. */ {
  * meanShift.Cluster(dataset, assignments, centroids);
  * @endcode
  *
+ * @tparam UseKernel If true, use the kernel to calculate the new centroid;
+ *         if false, use the mean and KernelType will be ignored.
  * @tparam KernelType the kernel to use.
  */
   
-template<typename KernelType = kernel::GaussianKernel,
+template<bool UseKernel = false,
+         typename KernelType = kernel::GaussianKernel,
          typename MatType = arma::mat>
 class MeanShift
 {
@@ -59,8 +62,9 @@ class MeanShift
   /**
    * Give an estimation of radius based on given dataset.
    * @param data Dataset for estimation.
+   * @param ratio How many neighbors to use
    */
-  double EstimateRadius(const MatType& data);
+  double EstimateRadius(const MatType& data, double ratio = 0.2);
   
   /**
    * Perform Mean Shift clustering on the data, returning a list of cluster
@@ -73,7 +77,8 @@ class MeanShift
    */
   void Cluster(const MatType& data,
                arma::Col<size_t>& assignments,
-               arma::mat& centroids);
+               arma::mat& centroids,
+               bool useSeeds = true);
   
   //! Get the maximum number of iterations.
   size_t MaxIterations() const { return maxIterations; }
@@ -93,15 +98,45 @@ class MeanShift
  private:
   
   /**
-   * A general approach to calculate the weight for a point.
+   * To speed up, we can generate some seeds from data set and use
+   * them as initial centroids rather than all the points in the data set.
    *
-   * @param centroid The centroid to calculate the weight
-   * @param point Calculate its weight
-   * @param weight Store the weight
-   * @return If true, the @point is near enough to the @centroid and @weight is valid,
-   *         If false, the @point is far from the @centroid and @weight is invalid.
+   * @param data The reference data set
+   * @param binSize It can be set equal to the estimated radius
+   * @param minFreq Usually 1 is enough
+   * @param seeds Store generated seeds
    */
-  bool CalcWeight(const arma::colvec& centroid, const arma::colvec& point, double& weight);
+  void GenSeeds(const MatType& data, double binSize, int minFreq, MatType& seeds);
+  
+  /**
+   * Use kernel to calculate new centroid given dataset and valid neighbors.
+   *
+   * @param data The whole dataset
+   * @param neighbors Valid neighbors
+   * @param distances Distances to neighbors
+   * @param centroid Store calculated centroid
+   */
+  template<bool ApplyKernel = UseKernel>
+  typename std::enable_if<ApplyKernel, bool>::type
+  CalculateCentroid(const MatType& data,
+                    const std::vector<size_t>& neighbors,
+                    const std::vector<double>& distances,
+                    arma::colvec& centroid);
+  
+  /**
+   * Use mean to calculate new centroid given dataset and valid neighbors.
+   *
+   * @param data The whole dataset
+   * @param neighbors Valid neighbors
+   * @param distances Distances to neighbors
+   * @param centroid Store calculated centroid
+   */
+  template<bool ApplyKernel = UseKernel>
+  typename std::enable_if<!ApplyKernel, bool>::type
+  CalculateCentroid(const MatType& data,
+                    const std::vector<size_t>& neighbors,
+                    const std::vector<double>&, /*unused*/
+                    arma::colvec& centroid);
   
   /**
    * If distance of two centroids is less than radius, one will be removed.
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index 0edf692..b3d3b2f 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -174,7 +174,7 @@ inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
 {
   if (radius <= 0)
   {
-    // An invalid radius is given, an estimation is needed.
+    // An invalid radius is given; an estimation is needed.
     Radius(EstimateRadius(data));
   }
 
@@ -199,7 +199,7 @@ inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
   // For each seed, perform mean shift algorithm.
   for (size_t i = 0; i < pSeeds->n_cols; ++i)
   {
-    // Initial centroid is the point itself.
+    // Initial centroid is the seed itself.
     allCentroids.col(i) = pSeeds->unsafe_col(i);
     for (size_t completedIterations = 0; completedIterations < maxIterations;
          completedIterations++)
@@ -209,6 +209,9 @@ inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
       
       rangeSearcher.Search(allCentroids.unsafe_col(i), validRadius,
           neighbors, distances);
+      if (neighbors[0].size() <= 1)
+        break;
+      
       // Calculate new centroid.
       if (!CalculateCentroid(data, neighbors[0], distances[0], newCentroid))
         newCentroid = allCentroids.unsafe_col(i);
@@ -222,7 +225,7 @@ inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
         for (size_t k = 0; k < centroids.n_cols; ++k)
         {
           const double distance = metric::EuclideanDistance::Evaluate(
-              allCentroids.col(i), centroids.col(k));
+              allCentroids.unsafe_col(i), centroids.unsafe_col(k));
           if (distance < radius)
           {
             isDuplicated = true;
diff --git a/src/mlpack/methods/mean_shift/mean_shift_main.cpp b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
index 109e3b5..ea323d9 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_main.cpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
@@ -37,7 +37,6 @@ PARAM_STRING("centroid_file", "If specified, the centroids of each cluster will"
 PARAM_INT("max_iterations", "Maximum number of iterations before Mean Shift "
           "terminates.", "m", 1000);
 
-PARAM_DOUBLE("bandwidth", "bandwidth of Gaussian kernel ", "b", 1.0);
 PARAM_DOUBLE("radius", "If distance of two centroids is less than radius "
              "one will be removed. "
              "If it isn't positive, an estimation will be given. "
@@ -71,9 +70,7 @@ int main(int argc, char** argv)
   arma::mat centroids;
   arma::Col<size_t> assignments;
 
-  GaussianKernel kernel(bandwidth);
-
-  MeanShift<> meanShift(radius, maxIterations, kernel);
+  MeanShift<> meanShift(radius, maxIterations);
 
   Timer::Start("clustering");
   Log::Info << "Performing mean shift clustering..." << endl;



More information about the mlpack-git mailing list