[mlpack-git] master: Format fixes for mean shift. (98c0c48)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jun 17 17:02:32 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ee7c82dba945db7c5469485c61d626eb0a4629b0...98c0c483a3547c8f49cdfe38670a603bd29036a0
>---------------------------------------------------------------
commit 98c0c483a3547c8f49cdfe38670a603bd29036a0
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Jun 17 16:54:11 2015 -0400
Format fixes for mean shift.
>---------------------------------------------------------------
98c0c483a3547c8f49cdfe38670a603bd29036a0
src/mlpack/methods/mean_shift/mean_shift.hpp | 58 +++++++++++++----------
src/mlpack/methods/mean_shift/mean_shift_impl.hpp | 24 +++++-----
2 files changed, 44 insertions(+), 38 deletions(-)
diff --git a/src/mlpack/methods/mean_shift/mean_shift.hpp b/src/mlpack/methods/mean_shift/mean_shift.hpp
index 1bd0d8b..55db83f 100644
--- a/src/mlpack/methods/mean_shift/mean_shift.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift.hpp
@@ -15,18 +15,17 @@
#include <boost/utility.hpp>
namespace mlpack {
-namespace meanshift /** Mean Shift clustering. */ {
+namespace meanshift /** Mean shift clustering. */ {
/**
- * This class implements Mean Shift clustering.
- * For each point in dataset, apply mean shift algorithm until maximum
- * iterations or convergence.
- * Then remove duplicate centroids.
+ * This class implements mean shift clustering. For each point in the
+ * dataset, apply the mean shift algorithm until it converges or the maximum
+ * number of iterations is reached. Then remove duplicate centroids.
*
- * A simple example of how to run Mean Shift clustering is shown below.
+ * A simple example of how to run mean shift clustering is shown below.
*
* @code
- * extern arma::mat data; // Dataset we want to run Mean Shift on.
+ * extern arma::mat dataset; // Dataset we want to run mean shift on.
* arma::Col<size_t> assignments; // Cluster assignments.
* arma::mat centroids; // Cluster centroids.
*
@@ -34,11 +33,11 @@ namespace meanshift /** Mean Shift clustering. */ {
* meanShift.Cluster(dataset, assignments, centroids);
* @endcode
*
- * @tparam UseKernel Use kernel or mean to calculate new centroid
+ * @tparam UseKernel Use kernel or mean to calculate new centroid.
* If false, KernelType will be ignored.
- * @tparam KernelType the kernel to use.
+ * @tparam KernelType The kernel to use.
+ * @tparam MatType The type of matrix the data is stored in.
*/
-
template<bool UseKernel = false,
typename KernelType = kernel::GaussianKernel,
typename MatType = arma::mat>
@@ -46,28 +45,30 @@ class MeanShift
{
public:
/**
- * Create a Mean Shift object and set the parameters which Mean Shift
- * will be run with.
+ * Create a mean shift object and set the parameters which mean shift will be
+ * run with.
*
- * @param radius If distance of two centroids is less than it, one will be removed. If this value isn't positive, an estimation will be given when clustering.
+ * @param radius If the distance between two centroids is less than this
+ *     value, one of them will be removed. If this value isn't positive, an
+ *     estimate will be made when clustering.
* @param maxIterations Maximum number of iterations allowed before giving up
- * iterations will terminate.
+ *     on convergence.
* @param kernel Optional KernelType object.
*/
MeanShift(const double radius = 0,
const size_t maxIterations = 1000,
const KernelType kernel = KernelType());
-
/**
* Give an estimation of radius based on given dataset.
+ *
* @param data Dataset for estimation.
- * @param ratio How many neighbors to use
+ * @param ratio Fraction (not percentage) of dataset points to use as neighbors.
*/
- double EstimateRadius(const MatType& data, double ratio = 0.2);
+ double EstimateRadius(const MatType& data, const double ratio = 0.2);
/**
- * Perform Mean Shift clustering on the data, returning a list of cluster
+ * Perform mean shift clustering on the data, returning a list of cluster
* assignments and centroids.
*
* @tparam MatType Type of matrix.
@@ -96,17 +97,23 @@ class MeanShift
KernelType& Kernel() { return kernel; }
private:
-
/**
* To speed up, we can generate some seeds from data set and use
- * them as initial centroids rather than all the points in the data set.
+ * them as initial centroids rather than all the points in the data set. The
+ * basic idea here is that we will place our points into hypercube bins of
+ * side length binSize, and any bins that contain fewer than minFreq points
+ * will be removed as possible seeds. Usually, 1 is a sufficient parameter
+ * for minFreq, and the bin size can be set equal to the estimated radius.
*
- * @param data The reference data set
- * @param binSize It can be set equal to the estimated radius
- * @param minFreq Usually 1 is enough
- * @param seed Store generated sedds
+ * @param data The reference data set.
+ * @param binSize Width of hypercube bins.
+ * @param minFreq Minimum number of points in bin.
+ * @param seeds Matrix to store generated seeds in.
*/
- void GenSeeds(const MatType& data, double binSize, int minFreq, MatType& seeds);
+ void GenSeeds(const MatType& data,
+ const double binSize,
+ const int minFreq,
+ MatType& seeds);
/**
* Use kernel to calculate new centroid given dataset and valid neighbors.
@@ -150,7 +157,6 @@ class MeanShift
//! Instantiated kernel.
KernelType kernel;
-
};
} // namespace meanshift
diff --git a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
index aafedd1..634dcaf 100644
--- a/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
+++ b/src/mlpack/methods/mean_shift/mean_shift_impl.hpp
@@ -49,12 +49,13 @@ double MeanShift<UseKernel, KernelType, MatType>::
EstimateRadius(const MatType& data, double ratio)
{
neighbor::AllkNN neighborSearch(data);
+
/**
* For each point in dataset, select nNeighbors nearest points and get
* nNeighbors distances. Use the maximum distance to estimate the duplicate
* threshhold.
*/
- size_t nNeighbors = size_t(data.n_cols * ratio);
+ const size_t nNeighbors = size_t(data.n_cols * ratio);
arma::Mat<size_t> neighbors;
arma::mat distances;
neighborSearch.Search(nNeighbors, neighbors, distances);
@@ -66,7 +67,7 @@ EstimateRadius(const MatType& data, double ratio)
return sum(maxDistances) / (double) data.n_cols;
}
-// Class to compare two vector
+// Class to compare two vectors.
template <typename VecType>
class less
{
@@ -83,12 +84,12 @@ class less
}
};
-// Generate seeds from given data set
+// Generate seeds from the given data set.
template<bool UseKernel, typename KernelType, typename MatType>
void MeanShift<UseKernel, KernelType, MatType>::GenSeeds(
const MatType& data,
- double binSize,
- int minFreq,
+ const double binSize,
+ const int minFreq,
MatType& seeds)
{
typedef arma::colvec VecType;
@@ -102,7 +103,7 @@ void MeanShift<UseKernel, KernelType, MatType>::GenSeeds(
allSeeds[binnedPoint]++;
}
- // Remove seeds with too few points
+ // Remove seeds with too few points.
std::map<VecType, int, less<VecType> >::iterator it;
for (it = allSeeds.begin(); it != allSeeds.end(); ++it)
{
@@ -134,6 +135,7 @@ CalculateCentroid(const MatType& data,
centroid += weight * data.unsafe_col(neighbors[i]);
}
}
+
if (sumWeight != 0)
{
centroid /= sumWeight;
@@ -153,14 +155,12 @@ CalculateCentroid(const MatType& data,
arma::colvec& centroid)
{
for (size_t i = 0; i < neighbors.size(); ++i)
- {
centroid += data.unsafe_col(neighbors[i]);
- }
+
centroid /= neighbors.size();
return true;
}
-
/**
* Perform Mean Shift clustering on the data set, returning a list of cluster
* assignments and centroids.
@@ -217,8 +217,8 @@ inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
newCentroid = allCentroids.unsafe_col(i);
// If the mean shift vector is small enough, it has converged.
- if (metric::EuclideanDistance::Evaluate(newCentroid, allCentroids.unsafe_col(i)) <
- 1e-3 * radius)
+ if (metric::EuclideanDistance::Evaluate(newCentroid,
+ allCentroids.unsafe_col(i)) < 1e-3 * radius)
{
// Determine if the new centroid is duplicate with old ones.
bool isDuplicated = false;
@@ -245,7 +245,7 @@ inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
}
}
- // Assign centroids to each point
+ // Assign centroids to each point.
neighbor::AllkNN neighborSearcher(centroids);
arma::mat neighborDistances;
arma::Mat<size_t> resultingNeighbors;
More information about the mlpack-git
mailing list