[mlpack-git] master: involve mean to calculate new centroid (6511938)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:59 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4
>---------------------------------------------------------------
commit 651193842da1aaa98be89f10f40ab100b5188697
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date: Wed Apr 22 15:21:34 2015 +0800
involve mean to calculate new centroid
>---------------------------------------------------------------
651193842da1aaa98be89f10f40ab100b5188697
src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 98 +++++++++++++++--------
1 file changed, 66 insertions(+), 32 deletions(-)
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index 002051f..0edf692 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -25,10 +25,11 @@ namespace meanshift {
/**
* Construct the Mean Shift object.
*/
-template<typename KernelType, typename MatType>
-MeanShift<KernelType, MatType>::MeanShift(const double radius,
- const size_t maxIterations,
- const KernelType kernel) :
+template<bool UseKernel, typename KernelType, typename MatType>
+MeanShift<UseKernel, KernelType, MatType>::
+MeanShift(const double radius,
+ const size_t maxIterations,
+ const KernelType kernel) :
radius(radius),
maxIterations(maxIterations),
kernel(kernel)
@@ -36,16 +37,16 @@ MeanShift<KernelType, MatType>::MeanShift(const double radius,
// Nothing to do.
}
-template<typename KernelType, typename MatType>
-void MeanShift<KernelType, MatType>::Radius(double radius)
+// Set the radius. It is used as the seed bin size in Cluster() (via
+// GenSeeds) and to normalize neighbor distances in the kernel-weighted
+// CalculateCentroid().
+template<bool UseKernel, typename KernelType, typename MatType>
+void MeanShift<UseKernel, KernelType, MatType>::Radius(double radius)
{
this->radius = radius;
}
// Estimate radius based on given dataset.
-template<typename KernelType, typename MatType>
-double MeanShift<KernelType, MatType>::EstimateRadius(const MatType &data,
- double ratio)
+template<bool UseKernel, typename KernelType, typename MatType>
+double MeanShift<UseKernel, KernelType, MatType>::
+EstimateRadius(const MatType& data, double ratio)
{
neighbor::AllkNN neighborSearch(data);
/**
@@ -82,9 +83,9 @@ class less
}
};
-// Generate seeds form given data set
-template<typename KernelType, typename MatType>
-void MeanShift<KernelType, MatType>::genSeeds(
+// Generate seeds from given data set
+template<bool UseKernel, typename KernelType, typename MatType>
+void MeanShift<UseKernel, KernelType, MatType>::GenSeeds(
const MatType& data,
double binSize,
int minFreq,
@@ -112,12 +113,60 @@ void MeanShift<KernelType, MatType>::genSeeds(
seeds = seeds * binSize;
}
+// Calculate new centroid with given kernel.
+//
+// Kernel-weighted variant, enabled when ApplyKernel is true. Each neighbor
+// at a strictly positive distance contributes data.col(neighbors[i]) with
+// weight kernel.Gradient(d) / d, where d = distances[i] / radius. The
+// weighted sum is accumulated into centroid, so centroid should be
+// zero-initialized on entry (Cluster() passes a zero-filled column).
+//
+// Returns true if centroid now holds the weighted mean. Returns false if the
+// total weight is zero (e.g. no neighbor lies at a positive distance); in that
+// case centroid is left unnormalized and should not be used.
+template<bool UseKernel, typename KernelType, typename MatType>
+template<bool ApplyKernel>
+typename std::enable_if<ApplyKernel, bool>::type
+MeanShift<UseKernel, KernelType, MatType>::
+CalculateCentroid(const MatType& data,
+ const std::vector<size_t>& neighbors,
+ const std::vector<double>& distances,
+ arma::colvec& centroid)
+{
+ double sumWeight = 0;
+ for (size_t i = 0; i < neighbors.size(); ++i)
+ {
+ // Skip neighbors at distance 0 (points coinciding with the query), since
+ // the weight below divides by the distance.
+ if (distances[i] > 0)
+ {
+ // Normalize the distance by the radius before evaluating the kernel.
+ double dist = distances[i] / radius;
+ double weight = kernel.Gradient(dist) / dist;
+ sumWeight += weight;
+ centroid += weight * data.unsafe_col(neighbors[i]);
+ }
+ }
+ // Only normalize when there is a non-zero total weight to divide by.
+ if (sumWeight != 0)
+ {
+ centroid /= sumWeight;
+ return true;
+ }
+ return false;
+}
+
+// Calculate new centroid by mean.
+//
+// Unweighted variant, enabled when ApplyKernel is false: centroid becomes the
+// arithmetic mean of the data columns listed in neighbors (distances are
+// ignored). Columns are accumulated into centroid, so it should be
+// zero-initialized on entry.
+//
+// Returns false, leaving centroid untouched, when neighbors is empty. The
+// mean of an empty set is undefined, and dividing a zero vector by
+// neighbors.size() == 0 would produce NaN. The caller can then fall back
+// to the previous centroid. Returns true otherwise.
+template<bool UseKernel, typename KernelType, typename MatType>
+template<bool ApplyKernel>
+typename std::enable_if<!ApplyKernel, bool>::type
+MeanShift<UseKernel, KernelType, MatType>::
+CalculateCentroid(const MatType& data,
+ const std::vector<size_t>& neighbors,
+ const std::vector<double>&, /*unused*/
+ arma::colvec& centroid)
+{
+ // A range search around a shifted centroid or a generated seed may find no
+ // points at all; report failure instead of dividing by zero.
+ if (neighbors.empty())
+ return false;
+
+ for (size_t i = 0; i < neighbors.size(); ++i)
+ {
+ centroid += data.unsafe_col(neighbors[i]);
+ }
+ centroid /= neighbors.size();
+ return true;
+}
+
+
/**
* Perform Mean Shift clustering on the data set, returning a list of cluster
* assignments and centroids.
*/
-template<typename KernelType, typename MatType>
-inline void MeanShift<KernelType, MatType>::Cluster(
+template<bool UseKernel, typename KernelType, typename MatType>
+inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
const MatType& data,
arma::Col<size_t>& assignments,
arma::mat& centroids,
@@ -133,7 +182,7 @@ inline void MeanShift<KernelType, MatType>::Cluster(
const MatType* pSeeds = &data;
if (useSeeds)
{
- genSeeds(data, radius, 1, seeds);
+ GenSeeds(data, radius, 1, seeds);
pSeeds = &seeds;
}
@@ -158,25 +207,10 @@ inline void MeanShift<KernelType, MatType>::Cluster(
// Store new centroid in this.
arma::colvec newCentroid(pSeeds->n_rows, arma::fill::zeros);
- double sumWeight = 0;
rangeSearcher.Search(allCentroids.unsafe_col(i), validRadius,
neighbors, distances);
- if (neighbors[0].size() <= 1)
- break;
- for (size_t j = 0; j < neighbors[0].size(); ++j)
- {
- if (distances[0][j] > 0)
- {
- distances[0][j] /= radius;
- double weight = kernel.Gradient(distances[0][j]) / distances[0][j];
- sumWeight += weight;
- newCentroid += weight * data.unsafe_col(neighbors[0][j]);
- }
- }
-
- if (sumWeight != 0)
- newCentroid /= sumWeight;
- else
+ // Calculate new centroid.
+ if (!CalculateCentroid(data, neighbors[0], distances[0], newCentroid))
newCentroid = allCentroids.unsafe_col(i);
// If the mean shift vector is small enough, it has converged.
More information about the mlpack-git
mailing list