[mlpack-git] master: Style overhaul, and clarify some comments. (481aa09)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:49 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4

>---------------------------------------------------------------

commit 481aa091d3315e5f1794dd316467e4ce6c05a332
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Mar 27 17:13:02 2015 +0000

    Style overhaul, and clarify some comments.


>---------------------------------------------------------------

481aa091d3315e5f1794dd316467e4ce6c05a332
 src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 147 ++++++++++------------
 1 file changed, 67 insertions(+), 80 deletions(-)

diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index b769190..4584ed0 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -22,50 +22,40 @@ namespace meanshift {
 /**
   * Construct the Mean Shift object.
   */
-template<typename KernelType,
-         typename MatType>
-MeanShift<
-  KernelType,
-  MatType>::
-MeanShift(const double radius,
-          const size_t maxIterations,
-          const KernelType kernel) :
+template<typename KernelType, typename MatType>
+MeanShift<KernelType, MatType>::MeanShift(const double radius,
+                                          const size_t maxIterations,
+                                          const KernelType kernel) :
     maxIterations(maxIterations),
     kernel(kernel)
 {
+  // Set the radius.  A non-positive radius is estimated later, in Cluster().
   Radius(radius);
 }
 
-template<typename KernelType,
-         typename MatType>
-void MeanShift<
-  KernelType,
-  MatType>::
-Radius(double radius) {
+template<typename KernelType, typename MatType>
+void MeanShift<KernelType, MatType>::Radius(double radius)
+{
   this->radius = radius;
 }
 
 // Estimate radius based on given dataset.
-template<typename KernelType,
-         typename MatType>
-double MeanShift<
-  KernelType,
-  MatType>::
-EstimateRadius(const MatType &data) {
-
+template<typename KernelType, typename MatType>
+double MeanShift<KernelType, MatType>::EstimateRadius(const MatType& data)
+{
   neighbor::NeighborSearch<
-    neighbor::NearestNeighborSort,
-    metric::EuclideanDistance,
-    tree::BinarySpaceTree<bound::HRectBound<2>,
-          neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> >
-    > neighborSearch(data);
+      neighbor::NearestNeighborSort,
+      metric::EuclideanDistance,
+      tree::BinarySpaceTree<bound::HRectBound<2>,
+            neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> >
+  > neighborSearch(data);
 
   /**
-   * For each point in dataset,
-   * select nNeighbors nearest points and get nNeighbors distances.
-   * Use the maximum distance to estimate the duplicate thresh.
+   * For each point in the dataset, find its nNeighbors nearest neighbors and
+   * take the largest of those distances.  The radius is estimated as the mean
+   * of these maximum distances (it also serves as the duplicate threshold).
    */
-  size_t nNeighbors = (int)(data.n_cols * 0.3);
+  size_t nNeighbors = size_t(data.n_cols * 0.3);
   arma::Mat<size_t> neighbors;
   arma::mat distances;
   neighborSearch.Search(nNeighbors, neighbors, distances);
@@ -74,102 +64,103 @@ EstimateRadius(const MatType &data) {
   arma::rowvec maxDistances = max(distances);
 
   // Calculate and return the radius.
-  return sum(maxDistances) / (double)data.n_cols;
-
+  return sum(maxDistances) / (double) data.n_cols;
 }
 
 // General way to calculate the weight of a data point.
-template <typename KernelType,
-          typename MatType>
-bool
-MeanShift<
-  KernelType,
-  MatType>::
-CalcWeight(const arma::colvec& centroid, const arma::colvec& point,
-           double& weight) {
-
+template<typename KernelType, typename MatType>
+bool MeanShift<KernelType, MatType>::CalcWeight(
+    const arma::colvec& centroid,
+    const arma::colvec& point,
+    double& weight)
+{
   double distance = EuclideanDistance::Evaluate(centroid, point);
-  if (distance >= radius || distance == 0) {
+  if (distance >= radius || distance == 0)
     return false;
-  }
+
   distance /= radius;
   weight = kernel.Gradient(distance) / distance;
   return true;
-
 }
 
 /**
  * Perform Mean Shift clustering on the data, returning a list of cluster
  * assignments and centroids.
  */
-template<typename KernelType,
-         typename MatType>
-inline void MeanShift<
-    KernelType,
-    MatType>::
-Cluster(const MatType& data,
-        arma::Col<size_t>& assignments,
-        arma::mat& centroids) {
-
-  if (radius <= 0) {
+template<typename KernelType, typename MatType>
+inline void MeanShift<KernelType, MatType>::Cluster(
+    const MatType& data,
+    arma::Col<size_t>& assignments,
+    arma::mat& centroids)
+{
+  if (radius <= 0)
+  {
     // An invalid radius is given, an estimation is needed.
     Radius(EstimateRadius(data));
   }
 
-  // all centroids before remove duplicate ones.
+  // Holds all centroids, before removing duplicate ones.
   arma::mat allCentroids(data.n_rows, data.n_cols);
   assignments.set_size(data.n_cols);
 
-  // for each point in dataset, perform mean shift algorithm.
-  for (size_t i = 0; i < data.n_cols; ++i) {
-
-    //initial centroid is the point itself.
+  // For each point in dataset, perform mean shift algorithm.
+  for (size_t i = 0; i < data.n_cols; ++i)
+  {
+    // Initial centroid is the point itself.
     allCentroids.col(i) = data.col(i);
 
     for (size_t completedIterations = 0; completedIterations < maxIterations;
-         completedIterations++) {
-
-      // to store new centroid
+         completedIterations++)
+    {
+      // Accumulates the kernel-weighted sum of points within the radius.
       arma::Col<double> newCentroid = arma::zeros(data.n_rows, 1);
 
       double sumWeight = 0;
 
       // Go through all the points
-      for (size_t j = 0; j < data.n_cols; ++j) {
-
+      for (size_t j = 0; j < data.n_cols; ++j)
+      {
         double weight = 0;
-        if (CalcWeight(allCentroids.col(i), data.col(j), weight)) {
+        if (CalcWeight(allCentroids.col(i), data.col(j), weight))
+        {
           sumWeight += weight;
           newCentroid += weight * data.col(j);
         }
-
       }
 
-      newCentroid /= sumWeight;
-
-      // calc the mean shift vector.
-      arma::Col<double> mhVector = newCentroid - allCentroids.col(i);
+      // Zero total weight (e.g. no neighbors in range): centroid stays put.
+      if (sumWeight == 0)
+        newCentroid = allCentroids.col(i);
+      else
+        newCentroid /= sumWeight;
 
       // If the mean shift vector is small enough, it has converged.
-      if (EuclideanDistance::Evaluate(newCentroid, allCentroids.col(i)) < 1e-3 * radius) {
-
+      if (EuclideanDistance::Evaluate(newCentroid, allCentroids.col(i)) <
+          1e-3 * radius)
+      {
         // Determine if the new centroid is duplicate with old ones.
         bool isDuplicated = false;
-        for (size_t k = 0; k < centroids.n_cols; ++k) {
+        for (size_t k = 0; k < centroids.n_cols; ++k)
+        {
           arma::Col<double> delta = allCentroids.col(i) - centroids.col(k);
-          if (norm(delta, 2) < radius) {
+          if (norm(delta, 2) < radius)
+          {
             isDuplicated = true;
             assignments(i) = k;
             break;
           }
         }
 
-        if (!isDuplicated) {
-          // this centroid is a new centroid.
-          if (centroids.n_cols == 0) {
+        if (!isDuplicated)
+        {
+          // This centroid is a new centroid.
+          if (centroids.n_cols == 0)
+          {
             centroids.insert_cols(0, allCentroids.col(i));
             assignments(i) = 0;
-          } else {
+          }
+          else
+          {
             centroids.insert_cols(centroids.n_cols, allCentroids.col(i));
             assignments(i) = centroids.n_cols - 1;
           }
@@ -179,14 +170,10 @@ Cluster(const MatType& data,
         break;
       }
 
-
-      // update the centroid.
+      // Not converged yet: move the centroid and run another iteration.
       allCentroids.col(i) = newCentroid;
-
     }
-
   }
-
 }
 
 } // namespace meanshift



More information about the mlpack-git mailing list