[mlpack-svn] r10307 - mlpack/trunk/src/mlpack/methods/gmm

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 16 22:29:03 EST 2011


Author: rcurtin
Date: 2011-11-16 22:29:03 -0500 (Wed, 16 Nov 2011)
New Revision: 10307

Modified:
   mlpack/trunk/src/mlpack/methods/gmm/phi.hpp
Log:
Add an overload of phi() which can do multiple data points at once; this way, we
can avoid recalculating the inverses of the covariances over and over and over
again.


Modified: mlpack/trunk/src/mlpack/methods/gmm/phi.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/gmm/phi.hpp	2011-11-16 22:54:11 UTC (rev 10306)
+++ mlpack/trunk/src/mlpack/methods/gmm/phi.hpp	2011-11-17 03:29:03 UTC (rev 10307)
@@ -92,12 +92,38 @@
     arma::mat inv_d = cinv * d_cov[i];
 
     g_cov[i] = f * dot(d_cov[i] * invDiff, invDiff) +
-        accu((cinv * d_cov[i]).diag()) / 2;
+        accu(inv_d.diag()) / 2;
   }
 
   return f;
 }
 
+/**
+ * Calculates the multivariate Gaussian probability density function for each
+ * data point (column) of x with respect to the given mean and covariance; on
+ * return, probabilities(i) holds the density of x.col(i).
+ */
+inline void phi(const arma::mat& x,
+                const arma::vec& mean,
+                const arma::mat& cov,
+                arma::vec& probabilities)
+{
+  // Column i of 'diffs' is x.col(i) - mean.  The row of ones needs one entry
+  // per data point (x.n_cols), not one per matrix element (x.n_elem).
+  arma::mat diffs = x - (mean * arma::ones<arma::rowvec>(x.n_cols));
+
+  // Only the diagonal of (diffs' * cov^-1 * diffs) is needed.  Compute the
+  // right-hand factor (pre-scaled by -0.5) so the loop below reads columns,
+  // not rows -- that is faster.  exponents(i) = -0.5 * d_i' * cov^-1 * d_i.
+  arma::mat rhs = -0.5 * inv(cov) * diffs;
+  arma::vec exponents(x.n_cols); // We will now fill this.
+  for (size_t i = 0; i < x.n_cols; i++)
+    exponents(i) = accu(diffs.col(i) % rhs.col(i));
+
+  probabilities = pow(2 * M_PI, (double) mean.n_elem / -2.0) *
+      pow(det(cov), -0.5) * exp(exponents);
+}
+
 }; // namespace gmm
 }; // namespace mlpack
 




More information about the mlpack-svn mailing list