[mlpack-git] master: optimization (0347c76)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:24 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4

>---------------------------------------------------------------

commit 0347c760a3297215e5a4b299ff4d932248e23312
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date:   Tue Jan 27 22:11:55 2015 +0800

    optimization


>---------------------------------------------------------------

0347c760a3297215e5a4b299ff4d932248e23312
 src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 57 ++++++++++-------------
 src/mlpack/methods/mean_shift/mean_shift_main.cpp |  3 +-
 2 files changed, 25 insertions(+), 35 deletions(-)

diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index 15b5abb..fdcaac2 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -119,8 +119,6 @@ Cluster(const MatType& data,
       
       newCentroid /= sumWeight;
       
-      completedIterations ++;
-      
       // calc the mean shift vector.
       arma::Col<double> mhVector = newCentroid - allCentroids.col(i);
       
@@ -128,42 +126,35 @@ Cluster(const MatType& data,
       allCentroids.col(i) = newCentroid;
       
       if (arma::norm(mhVector, 2) < stopThresh) {
-        break;
-      }
-      
-    }
-    
-  }
         
-  // remove duplicate centroids.
-  for (size_t i = 0; i < allCentroids.n_cols; ++i) {
-    
-    bool isDuplicated = false;
-    
-    /** 
-     * if a centroid is a neighbouring point of existing points,
-     * remove it and update corresponding assignments.
-     */
-    for (size_t j = 0; j < centroids.n_cols; ++j) {
-      arma::Col<double> delta = allCentroids.col(i) - centroids.col(j);
-      if (norm(delta, 2) < duplicateThresh) {
-        isDuplicated = true;
-        assignments(i) = j;
+        // Determine if the new centroid is duplicate with old ones.
+        bool isDuplicated = false;
+        for (size_t k = 0; k < centroids.n_cols; ++k) {
+          arma::Col<double> delta = allCentroids.col(i) - centroids.col(k);
+          if (norm(delta, 2) < duplicateThresh) {
+            isDuplicated = true;
+            assignments(i) = k;
+            break;
+          }
+        }
+        
+        if (!isDuplicated) {
+          // this centroid is a new centroid.
+          if (centroids.n_cols == 0) {
+            centroids.insert_cols(0, allCentroids.col(i));
+            assignments(i) = 0;
+          } else {
+            centroids.insert_cols(centroids.n_cols, allCentroids.col(i));
+            assignments(i) = centroids.n_cols - 1;
+          }
+        }
+        
+        // Get out of the loop.
         break;
       }
-    }
       
-    if (!isDuplicated) {
-      
-      // this centroid is a new centroid.
-      if (centroids.n_cols == 0) {
-        centroids.insert_cols(0, allCentroids.col(i));
-        assignments(i) = 0;
-      } else {
-        centroids.insert_cols(centroids.n_cols, allCentroids.col(i));
-        assignments(i) = centroids.n_cols - 1;
-      }
     }
+    
   }
   
 }
diff --git a/src/mlpack/methods/mean_shift/mean_shift_main.cpp b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
index c77f4f7..d25970f 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_main.cpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
@@ -70,8 +70,7 @@ int main(int argc, char** argv) {
   arma::mat centroids;
   arma::Col<size_t> assignments;
   
-  kernel::GaussianKernel kernel;
-  kernel.Bandwidth(bandwidth);
+  kernel::GaussianKernel kernel(bandwidth);
   
   MeanShift<> meanShift(duplicateThresh,
                                                          maxIterations, stopThresh, kernel);



More information about the mlpack-git mailing list