[mlpack-svn] r16959 - mlpack/trunk/src/mlpack/core/dists

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Aug 5 09:29:32 EDT 2014


Author: michaelfox99
Date: Tue Aug  5 09:29:32 2014
New Revision: 16959

Log:
Implemented Save, Load


Modified:
   mlpack/trunk/src/mlpack/core/dists/gaussian_distribution.hpp

Modified: mlpack/trunk/src/mlpack/core/dists/gaussian_distribution.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/dists/gaussian_distribution.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/dists/gaussian_distribution.hpp	Tue Aug  5 09:29:32 2014
@@ -1,6 +1,7 @@
 /**
  * @file gaussian_distribution.hpp
  * @author Ryan Curtin
+ * @author Michael Fox
  *
  * Implementation of the Gaussian distribution.
  */
@@ -8,8 +9,6 @@
 #define __MLPACK_METHODS_HMM_DISTRIBUTIONS_GAUSSIAN_DISTRIBUTION_HPP
 
 #include <mlpack/core.hpp>
-// Should be somewhere else, maybe in core.
-#include <mlpack/methods/gmm/phi.hpp>
 
 namespace mlpack {
 namespace distribution {
@@ -52,11 +51,17 @@
   /**
    * Return the probability of the given observation.
    */
-  double Probability(const arma::vec& observation) const
-  {
-    return mlpack::gmm::phi(observation, mean, covariance);
-  }
-
+  double Probability(const arma::vec& observation) const;
+  
+  /**
+   * Calculates the multivariate Gaussian probability density function for each
+   * data point (column) in the given matrix.
+   *
+   * @param x List of observations.
+   * @param probabilities Output probabilities for each input observation.
+   */
+  void Probability(const arma::mat& x, arma::vec& probabilities) const;
+  
   /**
    * Return a randomly generated observation according to the probability
    * distribution defined by this object.
@@ -80,22 +85,69 @@
   void Estimate(const arma::mat& observations,
                 const arma::vec& probabilities);
 
-  //! Return the mean.
+  /**
+   * Return the mean.
+   */
   const arma::vec& Mean() const { return mean; }
-  //! Return a modifiable copy of the mean.
+
+  /**
+   * Return a modifiable reference to the mean.
+   */
   arma::vec& Mean() { return mean; }
 
-  //! Return the covariance matrix.
+  /**
+   * Return the covariance matrix.
+   */
   const arma::mat& Covariance() const { return covariance; }
-  //! Return a modifiable copy of the covariance.
+
+  /**
+   * Return a modifiable reference to the covariance.
+   */
   arma::mat& Covariance() { return covariance; }
 
   /**
    * Returns a string representation of this object.
    */
   std::string ToString() const;
+    
+  /**
+   * Save this distribution to, or load it from, a SaveRestoreUtility object.
+   */
+  void Save(util::SaveRestoreUtility& n) const;
+  void Load(const util::SaveRestoreUtility& n);
+  static std::string const Type() { return "GaussianDistribution"; }
+  
+  
+    
 };
 
+/**
+ * Calculates the multivariate Gaussian probability density function for each
+ * data point (column) in the given matrix.
+ *
+ * @param x List of observations, one per column.
+ * @param probabilities Output probabilities for each input observation.
+ */
+inline void GaussianDistribution::Probability(const arma::mat& x,
+                                              arma::vec& probabilities) const
+{
+  // Column i of 'diffs' is the difference between x.col(i) and the mean.
+  const arma::mat diffs = x - arma::repmat(mean, 1, x.n_cols);
+
+  // Only diag(diffs' * cov^-1 * diffs) is needed.  Solving cov * z = diffs is
+  // more stable than forming inv(cov); sum(diffs % z) gives that diagonal.
+  const arma::mat z = arma::solve(covariance, diffs);
+  const arma::rowvec mahalanobis = arma::sum(diffs % z, 0);
+
+  // Log space avoids det() under/overflow; cov assumed positive definite.
+  double logDetCov, detSign;
+  arma::log_det(logDetCov, detSign, covariance);
+  const double logNorm = -0.5 * (mean.n_elem * std::log(2 * M_PI) + logDetCov);
+
+  probabilities = arma::exp(logNorm - 0.5 * mahalanobis.t());
+}
+  
+
 }; // namespace distribution
 }; // namespace mlpack
 



More information about the mlpack-svn mailing list