[mlpack-git] master: Fix merge failures. (d0a3e9a)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Apr 29 15:00:50 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/182d4a629c1b23f683dff7b284844e4e3e9f5cc4...e06ce8f3ac1170108c20c114e82cae10356d1301
>---------------------------------------------------------------
commit d0a3e9a3ae475ebc758e7309ba8c87a3b0b4d357
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Apr 29 14:57:50 2015 -0400
Fix merge failures.
>---------------------------------------------------------------
d0a3e9a3ae475ebc758e7309ba8c87a3b0b4d357
src/mlpack/methods/mean_shift/mean_shift.hpp | 55 ++++++++++++++++++-----
src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 9 ++--
src/mlpack/methods/mean_shift/mean_shift_main.cpp | 5 +--
3 files changed, 52 insertions(+), 17 deletions(-)
diff --git a/src/mlpack/methods/mean_shift/mean_shift.hpp b/src/mlpack/methods/mean_shift/mean_shift.hpp
index 271632e..1bd0d8b 100644
--- a/src/mlpack/methods/mean_shift/mean_shift.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift.hpp
@@ -34,10 +34,13 @@ namespace meanshift /** Mean Shift clustering. */ {
* meanShift.Cluster(dataset, assignments, centroids);
* @endcode
*
+ * @tparam UseKernel Use kernel or mean to calculate new centroid
+ * If false, KernelType will be ignored.
* @tparam KernelType the kernel to use.
*/
-template<typename KernelType = kernel::GaussianKernel,
+template<bool UseKernel = false,
+ typename KernelType = kernel::GaussianKernel,
typename MatType = arma::mat>
class MeanShift
{
@@ -59,8 +62,9 @@ class MeanShift
/**
* Give an estimation of radius based on given dataset.
* @param data Dataset for estimation.
+ * @param ratio How many neighbors to use
*/
- double EstimateRadius(const MatType& data);
+ double EstimateRadius(const MatType& data, double ratio = 0.2);
/**
* Perform Mean Shift clustering on the data, returning a list of cluster
@@ -73,7 +77,8 @@ class MeanShift
*/
void Cluster(const MatType& data,
arma::Col<size_t>& assignments,
- arma::mat& centroids);
+ arma::mat& centroids,
+ bool useSeeds = true);
//! Get the maximum number of iterations.
size_t MaxIterations() const { return maxIterations; }
@@ -93,15 +98,45 @@ class MeanShift
private:
/**
- * A general approach to calculate the weight for a point.
+ * To speed up clustering, we can generate some seeds from the data set and
+ * use them as initial centroids rather than all the points in the data set.
*
- * @param centroid The centroid to calculate the weight
- * @param point Calculate its weight
- * @param weight Store the weight
- * @return If true, the @point is near enough to the @centroid and @weight is valid,
- * If false, the @point is far from the @centroid and @weight is invalid.
+ * @param data The reference data set
+ * @param binSize It can be set equal to the estimated radius
+ * @param minFreq Usually 1 is enough
+ * @param seeds Store generated seeds
*/
- bool CalcWeight(const arma::colvec& centroid, const arma::colvec& point, double& weight);
+ void GenSeeds(const MatType& data, double binSize, int minFreq, MatType& seeds);
+
+ /**
+ * Use kernel to calculate new centroid given dataset and valid neighbors.
+ *
+ * @param data The whole dataset
+ * @param neighbors Valid neighbors
+ * @param distances Distances to neighbors
+ * @param centroid Store calculated centroid
+ */
+ template<bool ApplyKernel = UseKernel>
+ typename std::enable_if<ApplyKernel, bool>::type
+ CalculateCentroid(const MatType& data,
+ const std::vector<size_t>& neighbors,
+ const std::vector<double>& distances,
+ arma::colvec& centroid);
+
+ /**
+ * Use mean to calculate new centroid given dataset and valid neighbors.
+ *
+ * @param data The whole dataset
+ * @param neighbors Valid neighbors
+ * @param distances Distances to neighbors
+ * @param centroid Store calculated centroid
+ */
+ template<bool ApplyKernel = UseKernel>
+ typename std::enable_if<!ApplyKernel, bool>::type
+ CalculateCentroid(const MatType& data,
+ const std::vector<size_t>& neighbors,
+ const std::vector<double>&, /*unused*/
+ arma::colvec& centroid);
/**
* If distance of two centroids is less than radius, one will be removed.
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index 0edf692..b3d3b2f 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -174,7 +174,7 @@ inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
{
if (radius <= 0)
{
- // An invalid radius is given, an estimation is needed.
+ // An invalid radius is given; an estimation is needed.
Radius(EstimateRadius(data));
}
@@ -199,7 +199,7 @@ inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
// For each seed, perform mean shift algorithm.
for (size_t i = 0; i < pSeeds->n_cols; ++i)
{
- // Initial centroid is the point itself.
+ // Initial centroid is the seed itself.
allCentroids.col(i) = pSeeds->unsafe_col(i);
for (size_t completedIterations = 0; completedIterations < maxIterations;
completedIterations++)
@@ -209,6 +209,9 @@ inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
rangeSearcher.Search(allCentroids.unsafe_col(i), validRadius,
neighbors, distances);
+ if (neighbors[0].size() <= 1)
+ break;
+
// Calculate new centroid.
if (!CalculateCentroid(data, neighbors[0], distances[0], newCentroid))
newCentroid = allCentroids.unsafe_col(i);
@@ -222,7 +225,7 @@ inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
for (size_t k = 0; k < centroids.n_cols; ++k)
{
const double distance = metric::EuclideanDistance::Evaluate(
- allCentroids.col(i), centroids.col(k));
+ allCentroids.unsafe_col(i), centroids.unsafe_col(k));
if (distance < radius)
{
isDuplicated = true;
diff --git a/src/mlpack/methods/mean_shift/mean_shift_main.cpp b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
index 109e3b5..ea323d9 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_main.cpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_main.cpp
@@ -37,7 +37,6 @@ PARAM_STRING("centroid_file", "If specified, the centroids of each cluster will"
PARAM_INT("max_iterations", "Maximum number of iterations before Mean Shift "
"terminates.", "m", 1000);
-PARAM_DOUBLE("bandwidth", "bandwidth of Gaussian kernel ", "b", 1.0);
PARAM_DOUBLE("radius", "If distance of two centroids is less than radius "
"one will be removed. "
"If it isn't positive, an estimation will be given. "
@@ -71,9 +70,7 @@ int main(int argc, char** argv)
arma::mat centroids;
arma::Col<size_t> assignments;
- GaussianKernel kernel(bandwidth);
-
- MeanShift<> meanShift(radius, maxIterations, kernel);
+ MeanShift<> meanShift(radius, maxIterations);
Timer::Start("clustering");
Log::Info << "Performing mean shift clustering..." << endl;
More information about the mlpack-git
mailing list