[mlpack-git] master: optimization (0347c76)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:24 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4
>---------------------------------------------------------------
commit 0347c760a3297215e5a4b299ff4d932248e23312
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date: Tue Jan 27 22:11:55 2015 +0800
optimization
>---------------------------------------------------------------
0347c760a3297215e5a4b299ff4d932248e23312
src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 57 ++++++++++-------------
src/mlpack/methods/mean_shift/mean_shift_main.cpp | 3 +-
2 files changed, 25 insertions(+), 35 deletions(-)
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index 15b5abb..fdcaac2 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -119,8 +119,6 @@ Cluster(const MatType& data,
newCentroid /= sumWeight;
- completedIterations ++;
-
// calc the mean shift vector.
arma::Col<double> mhVector = newCentroid - allCentroids.col(i);
@@ -128,42 +126,35 @@ Cluster(const MatType& data,
allCentroids.col(i) = newCentroid;
if (arma::norm(mhVector, 2) < stopThresh) {
- break;
- }
-
- }
-
- }
- // remove duplicate centroids.
- for (size_t i = 0; i < allCentroids.n_cols; ++i) {
-
- bool isDuplicated = false;
-
- /**
- * if a centroid is a neighbouring point of existing points,
- * remove it and update corresponding assignments.
- */
- for (size_t j = 0; j < centroids.n_cols; ++j) {
- arma::Col<double> delta = allCentroids.col(i) - centroids.col(j);
- if (norm(delta, 2) < duplicateThresh) {
- isDuplicated = true;
- assignments(i) = j;
+ // Determine if the new centroid is duplicate with old ones.
+ bool isDuplicated = false;
+ for (size_t k = 0; k < centroids.n_cols; ++k) {
+ arma::Col<double> delta = allCentroids.col(i) - centroids.col(k);
+ if (norm(delta, 2) < duplicateThresh) {
+ isDuplicated = true;
+ assignments(i) = k;
+ break;
+ }
+ }
+
+ if (!isDuplicated) {
+ // this centroid is a new centroid.
+ if (centroids.n_cols == 0) {
+ centroids.insert_cols(0, allCentroids.col(i));
+ assignments(i) = 0;
+ } else {
+ centroids.insert_cols(centroids.n_cols, allCentroids.col(i));
+ assignments(i) = centroids.n_cols - 1;
+ }
+ }
+
+ // Get out of the loop.
break;
}
- }
- if (!isDuplicated) {
-
- // this centroid is a new centroid.
- if (centroids.n_cols == 0) {
- centroids.insert_cols(0, allCentroids.col(i));
- assignments(i) = 0;
- } else {
- centroids.insert_cols(centroids.n_cols, allCentroids.col(i));
- assignments(i) = centroids.n_cols - 1;
- }
}
+
}
}
diff --git a/src/mlpack/methods/mean_shift/mean_shift_main.cpp b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
index c77f4f7..d25970f 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_main.cpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
@@ -70,8 +70,7 @@ int main(int argc, char** argv) {
arma::mat centroids;
arma::Col<size_t> assignments;
- kernel::GaussianKernel kernel;
- kernel.Bandwidth(bandwidth);
+ kernel::GaussianKernel kernel(bandwidth);
MeanShift<> meanShift(duplicateThresh,
maxIterations, stopThresh, kernel);
More information about the mlpack-git mailing list