[mlpack-svn] r14924 - mlpack/trunk/src/mlpack/methods/gmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Apr 18 20:05:12 EDT 2013


Author: rcurtin
Date: 2013-04-18 20:05:11 -0400 (Thu, 18 Apr 2013)
New Revision: 14924

Modified:
   mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
Log:
Allow Bradley-Fayyad initialization for k-means as initialization for EM
algorithm.


Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp	2013-04-18 23:38:17 UTC (rev 14923)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm_main.cpp	2013-04-19 00:05:11 UTC (rev 14924)
@@ -4,8 +4,17 @@
  *
  * This program trains a mixture of Gaussians on a given data matrix.
  */
+#include <mlpack/core.hpp>
+
 #include "gmm.hpp"
 
+#include <mlpack/methods/kmeans/refined_start.hpp>
+
+using namespace mlpack;
+using namespace mlpack::gmm;
+using namespace mlpack::util;
+using namespace mlpack::kmeans;
+
 PROGRAM_INFO("Gaussian Mixture Model (GMM) Training",
     "This program takes a parametric estimate of a Gaussian mixture model (GMM)"
     " using the EM algorithm to find the maximum likelihood estimate.  The "
@@ -44,9 +53,14 @@
 PARAM_DOUBLE("noise", "Variance of zero-mean Gaussian noise to add to data.",
     "N", 0);
 
-using namespace mlpack;
-using namespace mlpack::gmm;
-using namespace mlpack::util;
+// Parameters for k-means initialization (only read when --refined_start is set).
+PARAM_FLAG("refined_start", "During the initialization, use refined initial "
+    "positions for k-means clustering (Bradley and Fayyad, 1998).", "r");
+PARAM_INT("samplings", "If using --refined_start, specify the number of "
+    "samplings used for initial points.", "S", 100);
+PARAM_DOUBLE("percentage", "If using --refined_start, specify the percentage of"
+    " the dataset used for each sampling (should be between 0.0 and 1.0).",
+    "p", 0.02);
 
 int main(int argc, char* argv[])
 {
@@ -84,18 +98,58 @@
   const size_t maxIterations = (size_t) CLI::GetParam<int>("max_iterations");
   const double tolerance = CLI::GetParam<double>("tolerance");
   const bool forcePositive = !CLI::HasParam("no_force_positive");
-  EMFit<> em(maxIterations, tolerance, forcePositive);
 
-  // Calculate mixture of Gaussians.
-  GMM<> gmm(size_t(gaussians), dataPoints.n_rows, em);
+  // This gets a bit weird because we need different types depending on whether
+  // --refined_start is specified.
+  double likelihood;
+  if (CLI::HasParam("refined_start"))
+  {
+    const int samplings = CLI::GetParam<int>("samplings");
+    const double percentage = CLI::GetParam<double>("percentage");
 
-  // Compute the parameters of the model using the EM algorithm.
-  Timer::Start("em");
-  double likelihood = gmm.Estimate(dataPoints, CLI::GetParam<int>("trials"));
-  Timer::Stop("em");
+    if (samplings <= 0)
+      Log::Fatal << "Number of samplings (" << samplings << ") must be greater"
+          << " than 0!" << std::endl;
 
+    if (percentage <= 0.0 || percentage > 1.0)
+      Log::Fatal << "Percentage for sampling (" << percentage << ") must be "
+          << "greater than 0.0 and less than or equal to 1.0!" << std::endl;
+
+    typedef KMeans<metric::SquaredEuclideanDistance, RefinedStart> KMeansType;
+
+    // These are default parameters.
+    KMeansType k(1000, 1.0, metric::SquaredEuclideanDistance(),
+        RefinedStart(samplings, percentage));
+
+    EMFit<KMeansType> em(maxIterations, tolerance, forcePositive, k);
+
+    GMM<EMFit<KMeansType> > gmm(size_t(gaussians), dataPoints.n_rows, em);
+
+    // Compute the parameters of the model using the EM algorithm.
+    Timer::Start("em");
+    likelihood = gmm.Estimate(dataPoints, CLI::GetParam<int>("trials"));
+    Timer::Stop("em");
+
+    // Save results.
+    gmm.Save(CLI::GetParam<std::string>("output_file"));
+  }
+  else
+  {
+    EMFit<> em(maxIterations, tolerance, forcePositive);
+
+    // Calculate mixture of Gaussians.
+    GMM<> gmm(size_t(gaussians), dataPoints.n_rows, em);
+
+    // Compute the parameters of the model using the EM algorithm.
+    Timer::Start("em");
+    likelihood = gmm.Estimate(dataPoints, CLI::GetParam<int>("trials"));
+    Timer::Stop("em");
+
+    // Save results.
+    gmm.Save(CLI::GetParam<std::string>("output_file"));
+  }
+
   Log::Info << "Log-likelihood of estimate: " << likelihood << ".\n";
 
-  // Save results.
-  gmm.Save(CLI::GetParam<std::string>("output_file"));
+
 }




More information about the mlpack-svn mailing list