[mlpack-svn] r12065 - mlpack/trunk/src/mlpack/methods/gmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Mar 27 12:35:37 EDT 2012


Author: rcurtin
Date: 2012-03-27 12:35:37 -0400 (Tue, 27 Mar 2012)
New Revision: 12065

Modified:
   mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp
   mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
Log:
Add a Classify() method.


Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp	2012-03-27 16:08:59 UTC (rev 12064)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp	2012-03-27 16:35:37 UTC (rev 12065)
@@ -276,6 +276,34 @@
   return;
 }
 
+/**
+ * Classify the given observations as being from an individual component in this
+ * GMM: each point (column) gets the component j that maximizes
+ * Probability(point, j); on ties the highest such j wins (because of '>=').
+ */
+void GMM::Classify(const arma::mat& observations,
+                   arma::Col<size_t>& labels) const
+{
+  // Not the most efficient approach: Probability() is called per (point, j).
+  labels.set_size(observations.n_cols);
+  for (size_t i = 0; i < observations.n_cols; ++i)
+  {
+    // Default to component 0 so the label is never left uninitialized, e.g. if
+    // every Probability() call returns NaN (NaN >= x is always false).
+    labels[i] = 0;
+    double probability = 0;
+    for (size_t j = 0; j < gaussians; ++j)
+    {
+      const double newProb = Probability(observations.unsafe_col(i), j);
+      if (newProb >= probability)
+      {
+        probability = newProb;
+        labels[i] = j;
+      }
+    }
+  }
+}
+
 double GMM::Loglikelihood(const arma::mat& data,
                           const std::vector<arma::vec>& meansL,
                           const std::vector<arma::mat>& covariancesL,

Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp	2012-03-27 16:08:59 UTC (rev 12064)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp	2012-03-27 16:35:37 UTC (rev 12065)
@@ -162,6 +162,25 @@
   void Estimate(const arma::mat& observations,
                 const arma::vec& probabilities);
 
+  /**
+   * Classify the given observations as being from an individual component in
+   * this GMM.  'labels' is resized to one entry per observation (column), and
+   * each label will be between 0 and (Gaussians() - 1).  Supposing that a
+   * point was classified with label 2, and that our GMM object was called
+   * 'gmm', one could access the relevant Gaussian distribution as follows:
+   *
+   * @code
+   * arma::vec mean = gmm.Means()[2];
+   * arma::mat covariance = gmm.Covariances()[2];
+   * double priorWeight = gmm.Weights()[2];
+   * @endcode
+   *
+   * @param observations List of observations to classify, one per column.
+   * @param labels Output; resized to observations.n_cols and filled with labels.
+   */
+  void Classify(const arma::mat& observations,
+                arma::Col<size_t>& labels) const;
+
  private:
   /**
    * This function computes the loglikelihood of the given model.  This function




More information about the mlpack-svn mailing list