[mlpack-svn] r14924 - mlpack/trunk/src/mlpack/methods/gmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Apr 18 20:05:12 EDT 2013
Author: rcurtin
Date: 2013-04-18 20:05:11 -0400 (Thu, 18 Apr 2013)
New Revision: 14924
Modified:
mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
Log:
Allow Bradley-Fayyad initialization for k-means as the initialization for the EM
algorithm.
Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp 2013-04-18 23:38:17 UTC (rev 14923)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp 2013-04-19 00:05:11 UTC (rev 14924)
@@ -4,8 +4,17 @@
*
* This program trains a mixture of Gaussians on a given data matrix.
*/
+#include <mlpack/core.hpp>
+
#include "gmm.hpp"
+#include <mlpack/methods/kmeans/refined_start.hpp>
+
+using namespace mlpack;
+using namespace mlpack::gmm;
+using namespace mlpack::util;
+using namespace mlpack::kmeans;
+
PROGRAM_INFO("Gaussian Mixture Model (GMM) Training",
"This program takes a parametric estimate of a Gaussian mixture model (GMM)"
" using the EM algorithm to find the maximum likelihood estimate. The "
@@ -44,9 +53,14 @@
PARAM_DOUBLE("noise", "Variance of zero-mean Gaussian noise to add to data.",
"N", 0);
-using namespace mlpack;
-using namespace mlpack::gmm;
-using namespace mlpack::util;
+// Parameters for k-means initialization.
+PARAM_FLAG("refined_start", "During the initialization, use refined initial "
+ "positions for k-means clustering (Bradley and Fayyad, 1998).", "r");
+PARAM_INT("samplings", "If using --refined_start, specify the number of "
+ "samplings used for initial points.", "S", 100);
+PARAM_DOUBLE("percentage", "If using --refined_start, specify the percentage of"
+ " the dataset used for each sampling (should be between 0.0 and 1.0).",
+ "p", 0.02);
int main(int argc, char* argv[])
{
@@ -84,18 +98,58 @@
const size_t maxIterations = (size_t) CLI::GetParam<int>("max_iterations");
const double tolerance = CLI::GetParam<double>("tolerance");
const bool forcePositive = !CLI::HasParam("no_force_positive");
- EMFit<> em(maxIterations, tolerance, forcePositive);
- // Calculate mixture of Gaussians.
- GMM<> gmm(size_t(gaussians), dataPoints.n_rows, em);
+ // This gets a bit weird because we need different types depending on whether
+ // --refined_start is specified.
+ double likelihood;
+ if (CLI::HasParam("refined_start"))
+ {
+ const int samplings = CLI::GetParam<int>("samplings");
+ const double percentage = CLI::GetParam<double>("percentage");
- // Compute the parameters of the model using the EM algorithm.
- Timer::Start("em");
- double likelihood = gmm.Estimate(dataPoints, CLI::GetParam<int>("trials"));
- Timer::Stop("em");
+ if (samplings <= 0)
+ Log::Fatal << "Number of samplings (" << samplings << ") must be greater"
+ << " than 0!" << std::endl;
+ if (percentage <= 0.0 || percentage > 1.0)
+ Log::Fatal << "Percentage for sampling (" << percentage << ") must be "
+ << "greater than 0.0 and less than or equal to 1.0!" << std::endl;
+
+ typedef KMeans<metric::SquaredEuclideanDistance, RefinedStart> KMeansType;
+
+ // These are default parameters.
+ KMeansType k(1000, 1.0, metric::SquaredEuclideanDistance(),
+ RefinedStart(samplings, percentage));
+
+ EMFit<KMeansType> em(maxIterations, tolerance, forcePositive, k);
+
+ GMM<EMFit<KMeansType> > gmm(size_t(gaussians), dataPoints.n_rows, em);
+
+ // Compute the parameters of the model using the EM algorithm.
+ Timer::Start("em");
+ likelihood = gmm.Estimate(dataPoints, CLI::GetParam<int>("trials"));
+ Timer::Stop("em");
+
+ // Save results.
+ gmm.Save(CLI::GetParam<std::string>("output_file"));
+ }
+ else
+ {
+ EMFit<> em(maxIterations, tolerance, forcePositive);
+
+ // Calculate mixture of Gaussians.
+ GMM<> gmm(size_t(gaussians), dataPoints.n_rows, em);
+
+ // Compute the parameters of the model using the EM algorithm.
+ Timer::Start("em");
+ likelihood = gmm.Estimate(dataPoints, CLI::GetParam<int>("trials"));
+ Timer::Stop("em");
+
+ // Save results.
+ gmm.Save(CLI::GetParam<std::string>("output_file"));
+ }
+
Log::Info << "Log-likelihood of estimate: " << likelihood << ".\n";
- // Save results.
- gmm.Save(CLI::GetParam<std::string>("output_file"));
+
}
More information about the mlpack-svn
mailing list