[mlpack-svn] r12065 - mlpack/trunk/src/mlpack/methods/gmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Mar 27 12:35:37 EDT 2012
Author: rcurtin
Date: 2012-03-27 12:35:37 -0400 (Tue, 27 Mar 2012)
New Revision: 12065
Modified:
mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp
mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
Log:
Add a Classify() method.
Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp 2012-03-27 16:08:59 UTC (rev 12064)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm.cpp 2012-03-27 16:35:37 UTC (rev 12065)
@@ -276,6 +276,34 @@
return;
}
+/**
+ * Label each column of 'observations' with the index j of the component for
+ * which Probability(column, j) is largest; 'labels' is resized to n_cols.
+ */
+void GMM::Classify(const arma::mat& observations,
+ arma::Col<size_t>& labels) const
+{
+ // Not the best way to do this: Probability() runs for every (i, j) pair.
+
+ // We should not have to fill this with values: each labels[i] is overwritten
+ // below, provided newProb >= probability holds for at least one j.
+ labels.set_size(observations.n_cols);
+ for (size_t i = 0; i < observations.n_cols; ++i)
+ {
+ // Find maximum probability component; '>=' means ties go to the larger j.
+ double probability = 0;
+ for (size_t j = 0; j < gaussians; ++j)
+ {
+ double newProb = Probability(observations.unsafe_col(i), j);
+ if (newProb >= probability)
+ {
+ probability = newProb;
+ labels[i] = j;
+ }
+ }
+ }
+}
+
double GMM::Loglikelihood(const arma::mat& data,
const std::vector<arma::vec>& meansL,
const std::vector<arma::mat>& covariancesL,
Modified: mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp 2012-03-27 16:08:59 UTC (rev 12064)
+++ mlpack/trunk/src/mlpack/methods/gmm/gmm.hpp 2012-03-27 16:35:37 UTC (rev 12065)
@@ -162,6 +162,25 @@
void Estimate(const arma::mat& observations,
const arma::vec& probabilities);
+ /**
+ * Classify the given observations as being from an individual component in
+ * this GMM. The resultant classifications are stored in the 'labels' object,
+ * and each label will be between 0 and (Gaussians() - 1). Supposing that a
+ * point was classified with label 2, and that our GMM object was called
+ * 'gmm', one could access the relevant Gaussian distribution as follows:
+ *
+ * @code
+ * arma::vec mean = gmm.Means()[2];
+ * arma::mat covariance = gmm.Covariances()[2];
+ * double priorWeight = gmm.Weights()[2];
+ * @endcode
+ *
+ * @param observations Observations to classify; each column is one point.
+ * @param labels Resized to observations.n_cols; labels[i] is column i's label.
+ */
+ void Classify(const arma::mat& observations,
+ arma::Col<size_t>& labels) const;
+
private:
/**
* This function computes the loglikelihood of the given model. This function
More information about the mlpack-svn
mailing list