[mlpack-git] master: generate seeds as initial centroids to speed up (04d234b)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 14:43:55 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ee384655c4462e422e343e9725437fd772ca4449...182d4a629c1b23f683dff7b284844e4e3e9f5cc4
>---------------------------------------------------------------
commit 04d234b789ca5c748f37d7f45b2fa8f3ed45f86a
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date: Sun Apr 12 00:01:13 2015 +0800
generate seeds as initial centroids to speed up
>---------------------------------------------------------------
04d234b789ca5c748f37d7f45b2fa8f3ed45f86a
src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 142 ++++++++++++++--------
1 file changed, 93 insertions(+), 49 deletions(-)
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index 5d3efbc..002051f 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -12,6 +12,9 @@
#include <mlpack/core/metrics/lmetric.hpp>
#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
#include <mlpack/methods/neighbor_search/neighbor_search_stat.hpp>
+#include <mlpack/methods/range_search/range_search.hpp>
+
+#include "map"
// In case it hasn't been included yet.
#include "mean_shift.hpp"
@@ -41,16 +44,16 @@ void MeanShift<KernelType, MatType>::Radius(double radius)
// Estimate radius based on given dataset.
template<typename KernelType, typename MatType>
-double MeanShift<KernelType, MatType>::EstimateRadius(const MatType &data)
+double MeanShift<KernelType, MatType>::EstimateRadius(const MatType &data,
+ double ratio)
{
neighbor::AllkNN neighborSearch(data);
-
/**
* For each point in dataset, select nNeighbors nearest points and get
* nNeighbors distances. Use the maximum distance to estimate the duplicate
* threshhold.
*/
- size_t nNeighbors = size_t(data.n_cols * 0.3);
+ size_t nNeighbors = size_t(data.n_cols * ratio);
arma::Mat<size_t> neighbors;
arma::mat distances;
neighborSearch.Search(nNeighbors, neighbors, distances);
@@ -62,31 +65,63 @@ double MeanShift<KernelType, MatType>::EstimateRadius(const MatType &data)
return sum(maxDistances) / (double) data.n_cols;
}
-// General way to calculate the weight of a data point.
-template<typename KernelType, typename MatType>
-bool MeanShift<KernelType, MatType>::CalcWeight(
- const arma::colvec& centroid,
- const arma::colvec& point,
- double& weight)
+// Comparison functor that orders two vectors lexicographically.
+template <typename VecType>
+class less
{
- double distance = metric::EuclideanDistance::Evaluate(centroid, point);
- if (distance >= radius || distance == 0)
+ public:
+ bool operator()(const VecType& first, const VecType& second) const
+ {
+ for (size_t i = 0; i < first.n_rows; ++i)
+ {
+ if (first[i] == second[i])
+ continue;
+ return first(i) < second(i);
+ }
return false;
+ }
+};
+
+// Generate seeds from the given data set.
+template<typename KernelType, typename MatType>
+void MeanShift<KernelType, MatType>::genSeeds(
+ const MatType& data,
+ double binSize,
+ int minFreq,
+ MatType& seeds)
+{
+ typedef arma::colvec VecType;
+ std::map<VecType, int, less<VecType> > allSeeds;
+ for (size_t i = 0; i < data.n_cols; ++i)
+ {
+ VecType binnedPoint = arma::floor(data.unsafe_col(i) / binSize);
+ if (allSeeds.find(binnedPoint) == allSeeds.end())
+ allSeeds[binnedPoint] = 1;
+ else
+ allSeeds[binnedPoint]++;
+ }
+
+ // Remove seeds with too few points
+ std::map<VecType, int, less<VecType> >::iterator it;
+ for (it = allSeeds.begin(); it != allSeeds.end(); ++it)
+ {
+ if (it->second >= minFreq)
+ seeds.insert_cols(seeds.n_cols, it->first);
+ }
- distance /= radius;
- weight = kernel.Gradient(distance) / distance;
- return true;
+ seeds = seeds * binSize;
}
/**
- * Perform Mean Shift clustering on the data, returning a list of cluster
+ * Perform Mean Shift clustering on the data set, returning a list of cluster
* assignments and centroids.
*/
template<typename KernelType, typename MatType>
inline void MeanShift<KernelType, MatType>::Cluster(
const MatType& data,
arma::Col<size_t>& assignments,
- arma::mat& centroids)
+ arma::mat& centroids,
+ bool useSeeds)
{
if (radius <= 0)
{
@@ -94,44 +129,59 @@ inline void MeanShift<KernelType, MatType>::Cluster(
Radius(EstimateRadius(data));
}
- // Holds all centroids, before removing duplicate ones.
- arma::mat allCentroids(data.n_rows, data.n_cols);
+ MatType seeds;
+ const MatType* pSeeds = &data;
+ if (useSeeds)
+ {
+ genSeeds(data, radius, 1, seeds);
+ pSeeds = &seeds;
+ }
+
+ // Holds all centroids before removing duplicate ones.
+ arma::mat allCentroids(pSeeds->n_rows, pSeeds->n_cols);
+
assignments.set_size(data.n_cols);
+ range::RangeSearch<> rangeSearcher(data);
+ math::Range validRadius(0, radius);
+ std::vector<std::vector<size_t> > neighbors;
+ std::vector<std::vector<double> > distances;
- // For each point in dataset, perform mean shift algorithm.
- for (size_t i = 0; i < data.n_cols; ++i)
+ // For each seed, perform mean shift algorithm.
+ for (size_t i = 0; i < pSeeds->n_cols; ++i)
{
// Initial centroid is the point itself.
- allCentroids.col(i) = data.col(i);
-
+ allCentroids.col(i) = pSeeds->unsafe_col(i);
for (size_t completedIterations = 0; completedIterations < maxIterations;
completedIterations++)
{
// Store new centroid in this.
- arma::Col<double> newCentroid = arma::zeros(data.n_rows, 1);
+ arma::colvec newCentroid(pSeeds->n_rows, arma::fill::zeros);
double sumWeight = 0;
-
- // Go through all the points
- for (size_t j = 0; j < data.n_cols; ++j)
+ rangeSearcher.Search(allCentroids.unsafe_col(i), validRadius,
+ neighbors, distances);
+ if (neighbors[0].size() <= 1)
+ break;
+ for (size_t j = 0; j < neighbors[0].size(); ++j)
{
- double weight = 0;
- if (CalcWeight(allCentroids.col(i), data.col(j), weight))
+ if (distances[0][j] > 0)
{
+ distances[0][j] /= radius;
+ double weight = kernel.Gradient(distances[0][j]) / distances[0][j];
sumWeight += weight;
- newCentroid += weight * data.col(j);
+ newCentroid += weight * data.unsafe_col(neighbors[0][j]);
}
}
- newCentroid /= sumWeight;
-
- // Calculate the mean shift vector.
- arma::Col<double> mhVector = newCentroid - allCentroids.col(i);
+ if (sumWeight != 0)
+ newCentroid /= sumWeight;
+ else
+ newCentroid = allCentroids.unsafe_col(i);
// If the mean shift vector is small enough, it has converged.
- if (metric::EuclideanDistance::Evaluate(newCentroid, allCentroids.col(i))
- < 1e-3 * radius)
+ if (metric::EuclideanDistance::Evaluate(newCentroid, allCentroids.unsafe_col(i)) <
+ 1e-3 * radius)
{
// Determine if the new centroid is duplicate with old ones.
bool isDuplicated = false;
@@ -142,25 +192,12 @@ inline void MeanShift<KernelType, MatType>::Cluster(
if (distance < radius)
{
isDuplicated = true;
- assignments(i) = k;
break;
}
}
if (!isDuplicated)
- {
- // This centroid is a new centroid.
- if (centroids.n_cols == 0)
- {
- centroids.insert_cols(0, allCentroids.col(i));
- assignments(i) = 0;
- }
- else
- {
- centroids.insert_cols(centroids.n_cols, allCentroids.col(i));
- assignments(i) = centroids.n_cols - 1;
- }
- }
+ centroids.insert_cols(centroids.n_cols, allCentroids.unsafe_col(i));
// Get out of the loop.
break;
@@ -170,6 +207,13 @@ inline void MeanShift<KernelType, MatType>::Cluster(
allCentroids.col(i) = newCentroid;
}
}
+
+ // Assign centroids to each point
+ neighbor::AllkNN neighborSearcher(centroids, data);
+ arma::mat neighborDistances;
+ arma::Mat<size_t> resultingNeighbors;
+ neighborSearcher.Search(1, resultingNeighbors, neighborDistances);
+ assignments = resultingNeighbors.t();
}
} // namespace meanshift
More information about the mlpack-git mailing list