[mlpack-git] master: Style overhaul, and clarify some comments. (481aa09)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:49 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4
>---------------------------------------------------------------
commit 481aa091d3315e5f1794dd316467e4ce6c05a332
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Mar 27 17:13:02 2015 +0000
Style overhaul, and clarify some comments.
>---------------------------------------------------------------
481aa091d3315e5f1794dd316467e4ce6c05a332
src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 147 ++++++++++------------
1 file changed, 67 insertions(+), 80 deletions(-)
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index b769190..4584ed0 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -22,50 +22,40 @@ namespace meanshift {
/**
* Construct the Mean Shift object.
*/
-template<typename KernelType,
- typename MatType>
-MeanShift<
- KernelType,
- MatType>::
-MeanShift(const double radius,
- const size_t maxIterations,
- const KernelType kernel) :
+template<typename KernelType, typename MatType>
+MeanShift<KernelType, MatType>::MeanShift(const double radius,
+ const size_t maxIterations,
+ const KernelType kernel) :
maxIterations(maxIterations),
kernel(kernel)
{
+ // Set the radius; a non-positive radius is estimated later, in Cluster().
Radius(radius);
}
-template<typename KernelType,
- typename MatType>
-void MeanShift<
- KernelType,
- MatType>::
-Radius(double radius) {
+template<typename KernelType, typename MatType>
+void MeanShift<KernelType, MatType>::Radius(double radius)
+{
this->radius = radius;
}
// Estimate radius based on given dataset.
-template<typename KernelType,
- typename MatType>
-double MeanShift<
- KernelType,
- MatType>::
-EstimateRadius(const MatType &data) {
-
+template<typename KernelType, typename MatType>
+double MeanShift<KernelType, MatType>::EstimateRadius(const MatType &data)
+{
neighbor::NeighborSearch<
- neighbor::NearestNeighborSort,
- metric::EuclideanDistance,
- tree::BinarySpaceTree<bound::HRectBound<2>,
- neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> >
- > neighborSearch(data);
+ neighbor::NearestNeighborSort,
+ metric::EuclideanDistance,
+ tree::BinarySpaceTree<bound::HRectBound<2>,
+ neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> >
+ > neighborSearch(data);
/**
- * For each point in dataset,
- * select nNeighbors nearest points and get nNeighbors distances.
- * Use the maximum distance to estimate the duplicate thresh.
+ * For each point in the dataset, find its nNeighbors nearest neighbors and
+ * their distances. The average of each point's maximum neighbor distance is
+ * used as the radius (which Cluster() also uses as the duplicate threshold).
*/
- size_t nNeighbors = (int)(data.n_cols * 0.3);
+ size_t nNeighbors = size_t(data.n_cols * 0.3);
arma::Mat<size_t> neighbors;
arma::mat distances;
neighborSearch.Search(nNeighbors, neighbors, distances);
@@ -74,102 +64,103 @@ EstimateRadius(const MatType &data) {
arma::rowvec maxDistances = max(distances);
// Calculate and return the radius.
- return sum(maxDistances) / (double)data.n_cols;
-
+ return sum(maxDistances) / (double) data.n_cols;
}
// General way to calculate the weight of a data point.
-template <typename KernelType,
- typename MatType>
-bool
-MeanShift<
- KernelType,
- MatType>::
-CalcWeight(const arma::colvec& centroid, const arma::colvec& point,
- double& weight) {
-
+template<typename KernelType, typename MatType>
+bool MeanShift<KernelType, MatType>::CalcWeight(
+ const arma::colvec& centroid,
+ const arma::colvec& point,
+ double& weight)
+{
double distance = EuclideanDistance::Evaluate(centroid, point);
- if (distance >= radius || distance == 0) {
+ if (distance >= radius || distance == 0)
return false;
- }
+
distance /= radius;
weight = kernel.Gradient(distance) / distance;
return true;
-
}
/**
* Perform Mean Shift clustering on the data, returning a list of cluster
* assignments and centroids.
*/
-template<typename KernelType,
- typename MatType>
-inline void MeanShift<
- KernelType,
- MatType>::
-Cluster(const MatType& data,
- arma::Col<size_t>& assignments,
- arma::mat& centroids) {
-
- if (radius <= 0) {
+template<typename KernelType, typename MatType>
+inline void MeanShift<KernelType, MatType>::Cluster(
+ const MatType& data,
+ arma::Col<size_t>& assignments,
+ arma::mat& centroids)
+{
+ if (radius <= 0)
+ {
// An invalid radius is given, an estimation is needed.
Radius(EstimateRadius(data));
}
- // all centroids before remove duplicate ones.
+ // Holds the current centroid for each point, before duplicates are merged.
arma::mat allCentroids(data.n_rows, data.n_cols);
assignments.set_size(data.n_cols);
- // for each point in dataset, perform mean shift algorithm.
- for (size_t i = 0; i < data.n_cols; ++i) {
- //initial centroid is the point itself.
+ // For each point in the dataset, run the mean shift iterations starting there.
+ for (size_t i = 0; i < data.n_cols; ++i)
+ {
+ // Initial centroid is the point itself.
allCentroids.col(i) = data.col(i);
for (size_t completedIterations = 0; completedIterations < maxIterations;
- completedIterations++) {
-
- // to store new centroid
+ completedIterations++)
+ {
+ // Weighted sum of points within radius, divided by sumWeight below (NOTE(review): sumWeight can be 0 if no other point is within radius).
arma::Col<double> newCentroid = arma::zeros(data.n_rows, 1);
double sumWeight = 0;
// Go through all the points
- for (size_t j = 0; j < data.n_cols; ++j) {
-
+ for (size_t j = 0; j < data.n_cols; ++j)
+ {
double weight = 0;
- if (CalcWeight(allCentroids.col(i), data.col(j), weight)) {
+ if (CalcWeight(allCentroids.col(i), data.col(j), weight))
+ {
sumWeight += weight;
newCentroid += weight * data.col(j);
}
-
}
newCentroid /= sumWeight;
- // calc the mean shift vector.
+ // Calculate the mean shift vector (NOTE(review): mhVector looks unused below).
arma::Col<double> mhVector = newCentroid - allCentroids.col(i);
// If the mean shift vector is small enough, it has converged.
- if (EuclideanDistance::Evaluate(newCentroid, allCentroids.col(i)) < 1e-3 * radius) {
-
+ if (EuclideanDistance::Evaluate(newCentroid, allCentroids.col(i)) <
+ 1e-3 * radius)
+ {
// Determine if the new centroid is duplicate with old ones.
bool isDuplicated = false;
- for (size_t k = 0; k < centroids.n_cols; ++k) {
+ for (size_t k = 0; k < centroids.n_cols; ++k)
+ {
arma::Col<double> delta = allCentroids.col(i) - centroids.col(k);
- if (norm(delta, 2) < radius) {
+ if (norm(delta, 2) < radius)
+ {
isDuplicated = true;
assignments(i) = k;
break;
}
}
- if (!isDuplicated) {
- // this centroid is a new centroid.
- if (centroids.n_cols == 0) {
+ if (!isDuplicated)
+ {
+ // No existing centroid lies within radius, so record a new centroid.
+ if (centroids.n_cols == 0)
+ {
centroids.insert_cols(0, allCentroids.col(i));
assignments(i) = 0;
- } else {
+ }
+ else
+ {
centroids.insert_cols(centroids.n_cols, allCentroids.col(i));
assignments(i) = centroids.n_cols - 1;
}
@@ -179,14 +170,10 @@ Cluster(const MatType& data,
break;
}
-
- // update the centroid.
+ // Move the centroid to the new weighted mean for the next iteration.
allCentroids.col(i) = newCentroid;
-
}
-
}
-
}
} // namespace meanshift
More information about the mlpack-git mailing list