[mlpack-svn] r10183 - in mlpack/trunk/src/mlpack/methods: . mog

fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Nov 8 12:13:17 EST 2011


Author: rcurtin
Date: 2011-11-08 12:13:17 -0500 (Tue, 08 Nov 2011)
New Revision: 10183

Modified:
   mlpack/trunk/src/mlpack/methods/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/mog/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/mog/kmeans.cpp
   mlpack/trunk/src/mlpack/methods/mog/mog_em.cpp
   mlpack/trunk/src/mlpack/methods/mog/mog_em.hpp
   mlpack/trunk/src/mlpack/methods/mog/mog_l2e.cpp
   mlpack/trunk/src/mlpack/methods/mog/mog_l2e.hpp
   mlpack/trunk/src/mlpack/methods/mog/phi.hpp
Log:
Update the MoG (mixture of Gaussians) code, which will eventually be renamed to
GMM.  This is a work in progress and not yet complete, so the mog directory is
disabled in the build for now.


Modified: mlpack/trunk/src/mlpack/methods/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/CMakeLists.txt	2011-11-08 17:02:18 UTC (rev 10182)
+++ mlpack/trunk/src/mlpack/methods/CMakeLists.txt	2011-11-08 17:13:17 UTC (rev 10183)
@@ -8,7 +8,7 @@
   infomax_ica
   # kernel_pca # (required sparse and is known to not work or compile)
   linear_regression
-  mog
+  #mog (in progress)
   #mvu  # (currently known to not work)
   naive_bayes
   nca

Modified: mlpack/trunk/src/mlpack/methods/mog/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/mog/CMakeLists.txt	2011-11-08 17:02:18 UTC (rev 10182)
+++ mlpack/trunk/src/mlpack/methods/mog/CMakeLists.txt	2011-11-08 17:13:17 UTC (rev 10183)
@@ -42,3 +42,12 @@
 target_link_libraries(mog_l2e
   mlpack
 )
+
+# Test executable.
+add_executable(gmm_test
+  gmm_test.cpp
+)
+target_link_libraries(gmm_test
+  mlpack
+  boost_unit_test_framework
+)

Modified: mlpack/trunk/src/mlpack/methods/mog/kmeans.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/mog/kmeans.cpp	2011-11-08 17:02:18 UTC (rev 10182)
+++ mlpack/trunk/src/mlpack/methods/mog/kmeans.cpp	2011-11-08 17:13:17 UTC (rev 10183)
@@ -7,6 +7,8 @@
  */
 #include "kmeans.hpp"
 
+#include <mlpack/core/kernels/lmetric.hpp>
+
 namespace mlpack {
 namespace gmm {
 
@@ -15,145 +17,162 @@
             std::vector<arma::vec>& means,
             std::vector<arma::mat>& covars,
             arma::vec& weights) {
-  // Set size of vectors and matrices properly.
-  means.resize(value_of_k);
-  covars.resize(value_of_k);
-  for (size_t i = 0; i < value_of_k; i++) {
-    means[i].set_size(value_of_k);
-    covars[i].set_size(value_of_k, value_of_k);
-  }
-  weights.set_size(value_of_k);
+  // Make sure we have more points than clusters.
+  if (value_of_k > data.n_cols)
+    Log::Warn << "k-means: more clusters requested than points given.  Empty"
+        << " clusters may result." << std::endl;
 
-  std::vector<arma::vec> mu, mu_old;
-  double* tmpssq = new double[value_of_k];
-  double* sig = new double[value_of_k];
-  double* sig_best = new double[value_of_k];
-  size_t* y = new size_t[value_of_k];
-  arma::vec x, diff;
-  arma::mat ssq;
-  size_t i, j, k, n, t, dim;
-  double score, score_old, sum;
+  // Assignment of cluster of each point.
+  arma::Col<size_t> assignments(data.n_cols); // Col used so we have shuffle().
+  // Centroids of each cluster.  Each column corresponds to a centroid.
+  arma::mat centroids(data.n_rows, value_of_k);
+  // Counts of points in each cluster.
+  arma::Col<size_t> counts(value_of_k);
 
-  n = data.n_cols;
-  dim = data.n_rows;
-  mu.resize(value_of_k);
-  mu_old.resize(value_of_k);
-  ssq.set_size(n, value_of_k);
+  // First we must randomly partition the dataset.
+  assignments = arma::shuffle(arma::linspace<arma::Col<size_t> >(0,
+      value_of_k - 1, data.n_cols));
 
-  for (i = 0; i < value_of_k; i++) {
-    mu[i].set_size(dim);
-    mu_old[i].set_size(dim);
-  }
+  // Set counts correctly.
+  for (size_t i = 0; i < value_of_k; i++)
+    counts[i] = accu(assignments == i);
 
-  x.set_size(dim);
-  diff.set_size(dim);
+  size_t changed_assignments = 0;
+  do
+  {
+    // Update step.
+    // Calculate centroids based on given assignments.
+    centroids.zeros();
 
-  score_old = 999999;
+    for (size_t i = 0; i < data.n_cols; i++)
+      centroids.col(assignments[i]) += data.col(i);
 
-  // putting 5 random restarts to obtain the k-means
-  for (i = 0; i < 5; i++) {
-    t = -1;
-    for (k = 0; k < value_of_k; k++){
-      t = (t + 1 + (rand() % ((n - 1 - (value_of_k - k)) - (t + 1))));
-      mu[k] = data.col(t);
-      for(j = 0; j < n; j++) {
-        x = data.col(j);
-        diff = mu[k] - x;
-        ssq(j, k) = dot(diff, diff);
-      }
-    }
-    // This should be an Armadillo function, really.
-    double min_val = DBL_MAX;
-    for (i = 0; i < ssq.n_rows; i++) {
-      for (k = 0; k < ssq.n_cols; k++) {
-        if (ssq(i, k) < min_val) {
-          min_val = ssq(i, k);
-          y[i] = k;
+    for (size_t i = 0; i < value_of_k; i++)
+      centroids.col(i) /= counts[i];
+
+    // Assignment step.
+    // Find the closest centroid to each point.  We will keep track of how many
+    // assignments change.  When no assignments change, we are done.
+    changed_assignments = 0;
+    for (size_t i = 0; i < data.n_cols; i++)
+    {
+      // Find the closest centroid to this point.
+      double min_distance = std::numeric_limits<double>::infinity();
+      size_t closest_cluster = value_of_k; // Invalid value.
+
+      for (size_t j = 0; j < value_of_k; j++)
+      {
+        double distance = kernel::SquaredEuclideanDistance::Evaluate(
+            data.unsafe_col(i), centroids.unsafe_col(j));
+
+        if (distance < min_distance)
+        {
+          min_distance = distance;
+          closest_cluster = j;
         }
       }
+
+      // Reassign this point to the closest cluster.
+      if (assignments[i] != closest_cluster)
+      {
+        // Update counts.
+        counts[assignments[i]]--;
+        counts[closest_cluster]++;
+        // Update assignment.
+        assignments[i] = closest_cluster;
+        changed_assignments++;
+      }
     }
 
-    do {
-      for (k = 0; k < value_of_k; k++)
-        mu_old[k] = mu[k];
+    // Keep-bad-things-from-happening step.
+    // Ensure that no cluster is empty, and if so, take corrective action.
+    for (size_t i = 0; i < value_of_k; i++)
+    {
+      if (counts[i] == 0)
+      {
+        Log::Warn << "Cluster " << i << " is empty." << std::endl;
 
-      for(k = 0; k < value_of_k; k++) {
-        size_t p = 0;
-        mu[k].zeros();
-        for (j = 0; j < n; j++) {
-          x = data.col(j);
-          if (y[j] == k) {
-            mu[k] += x;
-            p++;
+        // Strategy: take the furthest point from the cluster with highest
+        // variance.  So, we need the variance of each cluster.
+        arma::vec variances;
+        variances.zeros(value_of_k);
+        for (size_t j = 0; j < data.n_cols; j++)
+          variances[assignments[j]] += var(data.col(j));
+
+        size_t cluster;
+        double max_var = 0;
+        for (size_t j = 0; j < value_of_k; j++)
+        {
+          if (variances[j] > max_var)
+          {
+            cluster = j;
+            max_var = variances[j];
           }
         }
 
-        if (p != 0)
-          mu[k] /= p;
+        // Now find the furthest point.
+        size_t point = data.n_cols; // Invalid.
+        double distance = 0;
+        for (size_t j = 0; j < data.n_cols; j++)
+        {
+          if (assignments[j] == cluster)
+          {
+            double d = kernel::SquaredEuclideanDistance::Evaluate(
+                data.unsafe_col(j), centroids.unsafe_col(cluster));
 
-        for (j = 0; j < n; j++) {
-          x = data.col(j);
-          diff = mu[k] - x;
-          ssq(j, k) = dot(diff, diff);
-        }
-      }
-      // This should be an Armadillo function, really.
-      min_val = DBL_MAX;
-      for (i = 0; i < ssq.n_rows; i++) {
-        for (k = 0; k < ssq.n_cols; k++) {
-          if (ssq(i, k) < min_val) {
-            min_val = ssq(i, k);
-            y[i] = k;
+            if (d >= distance)
+            {
+              distance = d;
+              point = j;
+            }
           }
         }
-      }
 
-      sum = 0;
-      for(k = 0; k < value_of_k; k++) {
-        diff = mu[k] - mu_old[k];
-        sum += dot(diff, diff);
-      }
+        Log::Warn << "Taking point " << point << " from cluster " << cluster
+            << std::endl;
+        Log::Warn << "Point: " << std::endl << data.col(point) << std::endl;
+        Log::Warn << "Cluster centroid: " << std::endl;
+        Log::Warn << centroids.col(cluster) << std::endl;
 
-    } while (sum != 0);
-
-    for (k = 0; k < value_of_k; k++) {
-      size_t p = 0;
-      tmpssq[k] = 0;
-      for (j = 0; j < n; j++) {
-        if (y[j] == k) {
-          tmpssq[k] += ssq(j, k);
-          p++;
-        }
+        // Take that point and add it to the empty cluster.
+        counts[cluster]--;
+        counts[i]++;
+        assignments[point] = i;
+        changed_assignments++;
       }
-      sig[k] = sqrt(tmpssq[k] / p);
     }
 
-    score = 0;
-    for(k = 0; k < value_of_k; k++) {
-      score += tmpssq[k];
-    }
-    score = score / n;
+  } while (changed_assignments > 0);
 
-    if (score < score_old) {
-      score_old = score;
-      for(k = 0; k < value_of_k; k++){
-        means[k] = mu[k];
-        sig_best[k] = sig[k];
+  // Now, with the centroids final, we need to find the covariance matrix of
+  // each cluster and then the a priori weight.  We also need to assign the
+  // means to be the centroids.  First, we must make sure the size of the
+  // vectors is correct.
+  means.resize(value_of_k);
+  covars.resize(value_of_k);
+  weights.set_size(value_of_k);
+  for (size_t i = 0; i < value_of_k; i++)
+  {
+    // Assign mean.
+    means[i] = centroids.col(i);
+
+    // Calculate covariance.
+    arma::mat data_subset(data.n_rows, accu(assignments == i));
+    size_t position = 0;
+    for (size_t j = 0; j < data.n_cols; j++)
+    {
+      if (assignments[j] == i)
+      {
+        data_subset.col(position) = data.col(j);
+        position++;
       }
     }
-  }
 
-  for (k = 0; k < value_of_k; k++) {
-    x.fill(sig_best[k]);
-    covars[k].diag() = x;
+    covars[i] = ccov(data_subset);
+
+    // Assign weight.
+    weights[i] = (double) accu(assignments == i) / (double) data.n_cols;
   }
-
-  weights.fill(1.0 / value_of_k);
-
-  delete[] tmpssq;
-  delete[] sig;
-  delete[] sig_best;
-  delete[] y;
 }
 
 }; // namespace gmm

Modified: mlpack/trunk/src/mlpack/methods/mog/mog_em.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/mog/mog_em.cpp	2011-11-08 17:02:18 UTC (rev 10182)
+++ mlpack/trunk/src/mlpack/methods/mog/mog_em.cpp	2011-11-08 17:13:17 UTC (rev 10183)
@@ -13,157 +13,146 @@
 using namespace mlpack;
 using namespace gmm;
 
-void MoGEM::ExpectationMaximization(const arma::mat& data_points) {
-  // Declaration of the variables */
-  size_t num_points;
-  size_t dim, num_gauss;
-  double sum, tmp;
-  std::vector<arma::vec> mu_temp, mu;
-  std::vector<arma::mat> sigma_temp, sigma;
-  arma::vec omega_temp, omega, x;
-  arma::mat cond_prob;
-  long double l, l_old, best_l, INFTY = 99999, TINY = 1.0e-10;
+void MoGEM::ExpectationMaximization(const arma::mat& data)
+{
+  // Create temporary models and set to the right size.
+  std::vector<arma::vec> means_trial(gaussians, arma::vec(dimension));
+  std::vector<arma::mat> covariances_trial(gaussians,
+      arma::mat(dimension, dimension));
+  arma::vec weights_trial(gaussians);
 
-  // Initializing values
-  dim = dimension();
-  num_gauss = number_of_gaussians();
-  num_points = data_points.n_cols;
+  arma::mat cond_prob(gaussians, data.n_cols);
 
-  // Initializing the number of the vectors and matrices
-  // according to the parameters input
-  mu_temp.resize(num_gauss);
-  mu.resize(num_gauss);
-  sigma_temp.resize(num_gauss);
-  sigma.resize(num_gauss);
-  omega_temp.set_size(num_gauss);
-  omega.set_size(num_gauss);
+  long double l, l_old, best_l, TINY = 1.0e-10;
 
-  // Allocating size to the vectors and matrices
-  // according to the dimensionality of the data
-  for(size_t i = 0; i < num_gauss; i++) {
-    mu_temp[i].set_size(dim);
-    mu[i].set_size(dim);
-    sigma_temp[i].set_size(dim, dim);
-    sigma[i].set_size(dim, dim);
-  }
-  x.set_size(dim);
-  cond_prob.set_size(num_gauss, num_points);
+  best_l = -DBL_MAX;
 
-  best_l = -INFTY;
-  size_t restarts = 0;
-  // performing 5 restarts and choosing the best from them
-  while (restarts < 5) {
+  // We will perform five trials, and then save the trial with the best result
+  // as our trained model.
+  for (size_t iteration = 0; iteration < 5; iteration++)
+  {
+    // Use k-means to find initial values for the parameters.
+    KMeans(data, gaussians, means_trial, covariances_trial, weights_trial);
 
-    // assign initial values to 'mu', 'sig' and 'omega' using k-means
-    KMeans(data_points, num_gauss, mu_temp, sigma_temp, omega_temp);
+    Log::Warn << "K-Means results:" << std::endl;
+    for (size_t i = 0; i < gaussians; i++)
+    {
+      Log::Warn << "Mean " << i << ":" << std::endl;
+      Log::Warn << means_trial[i] << std::endl;
+      Log::Warn << "Covariance " << i << ":" << std::endl;
+      Log::Warn << covariances_trial[i] << std::endl;
+    }
+    Log::Warn << "Weights: " << std::endl << weights_trial << std::endl;
 
-    l_old = -INFTY;
+    // Calculate the log likelihood of the model.
+    l = Loglikelihood(data, means_trial, covariances_trial, weights_trial);
 
-    // calculates the loglikelihood value
-    l = Loglikelihood(data_points, mu_temp, sigma_temp, omega_temp);
+    l_old = -DBL_MAX;
 
-    // added a check here to see if any
-    // significant change is being made
-    // at every iteration
-    while (l - l_old > TINY) {
-      // calculating the conditional probabilities
-      // of choosing a particular gaussian given
-      // the data and the present theta value
-      for (size_t j = 0; j < num_points; j++) {
-        x = data_points.col(j);
-        sum = 0;
-        for (size_t i = 0; i < num_gauss; i++) {
-          tmp = phi(x, mu_temp[i], sigma_temp[i]) * omega_temp[i];
-          cond_prob(i, j) = tmp;
-          sum += tmp;
+    // Iterate to update the model until no more improvement is found.
+    size_t max_iterations = 1000;
+    size_t iteration = 0;
+    while (std::abs(l - l_old) > TINY && iteration < max_iterations)
+    {
+      Log::Warn << "Iteration " << iteration << std::endl;
+      // Calculate the conditional probabilities of choosing a particular
+      // Gaussian given the data and the present theta value.
+      for (size_t j = 0; j < data.n_cols; j++)
+      {
+        for (size_t i = 0; i < gaussians; i++)
+        {
+          cond_prob(i, j) = phi(data.unsafe_col(j), means_trial[i],
+              covariances_trial[i]) * weights_trial[i];
         }
-        for (size_t i = 0; i < num_gauss; i++) {
-          tmp = cond_prob(i, j);
-          cond_prob(i, j) = tmp / sum;
-        }
+
+        // Normalize column to have sum probability of one.
+        cond_prob.col(j) /= arma::sum(cond_prob.col(j));
       }
 
-      // calculating the new value of the mu
-      // using the updated conditional probabilities
-      for (size_t i = 0; i < num_gauss; i++) {
-        sum = 0;
-        mu_temp[i].zeros();
-        for (size_t j = 0; j < num_points; j++) {
-          x = data_points.col(j);
-          mu_temp[i] = cond_prob(i, j) * x;
-          sum += cond_prob(i, j);
-        }
-        mu_temp[i] /= sum;
+      // Store the sums of each row because they are used multiple times.
+      arma::vec prob_row_sums = arma::sum(cond_prob, 1 /* row-wise */);
+
+      // Calculate the new value of the means using the updated conditional
+      // probabilities.
+      for (size_t i = 0; i < gaussians; i++)
+      {
+        means_trial[i].zeros();
+        for (size_t j = 0; j < data.n_cols; j++)
+          means_trial[i] += cond_prob(i, j) * data.col(j);
+
+        means_trial[i] /= prob_row_sums[i];
       }
 
-      // calculating the new value of the sig
-      // using the updated conditional probabilities
-      // and the updated mu
-      for (size_t i = 0; i < num_gauss; i++) {
-        sum = 0;
-        sigma_temp[i].zeros();
-        for (size_t j = 0; j < num_points; j++) {
-          arma::mat co, ro, c;
-          c.set_size(dim, dim);
-          x = data_points.col(j);
-          x -= mu_temp[i];
-          c = x * trans(x);
-          sigma_temp[i] += cond_prob(i, j) * c;
-          sum += cond_prob(i, j);
+      // Calculate the new value of the covariances using the updated
+      // conditional probabilities and the updated means.
+      for (size_t i = 0; i < gaussians; i++)
+      {
+        covariances_trial[i].zeros();
+        for (size_t j = 0; j < data.n_cols; j++)
+        {
+          arma::vec tmp = data.col(j) - means_trial[i];
+          covariances_trial[i] += cond_prob(i, j) * (tmp * trans(tmp));
         }
-        sigma_temp[i] /= sum;
+
+        covariances_trial[i] /= prob_row_sums[i];
       }
 
-      // calculating the new values for omega
-      // using the updated conditional probabilities
-      arma::vec identity_vector;
-      identity_vector.set_size(num_points);
-      identity_vector = (1.0 / num_points);
-      omega_temp = cond_prob * identity_vector;
+      // Calculate the new values for omega using the updated conditional
+      // probabilities.
+      weights_trial = prob_row_sums / data.n_cols;
+/*
+      Log::Warn << "Estimated weights:" << std::endl << weights_trial
+          << std::endl;
 
+      for (size_t i = 0; i < gaussians; i++)
+      {
+//        Log::Warn << "Estimated mean " << i << ":" << std::endl;
+//        Log::Warn << means_trial[i] << std::endl;
+        Log::Warn << "Estimated covariance " << i << ":" << std::endl;
+        Log::Warn << covariances_trial[i] << std::endl;
+      }
+*/
+
+      // Update values of l; calculate new log-likelihood.
       l_old = l;
-      l = Loglikelihood(data_points, mu_temp, sigma_temp, omega_temp);
+      l = Loglikelihood(data, means_trial, covariances_trial, weights_trial);
+
+      Log::Warn << "Improved log likelihood to " << l << std::endl;
+
+      iteration++;
     }
 
-    // putting a check to see if the best one is chosen
-    if (l > best_l) {
+    // The trial model is trained.  Is it better than our existing model?
+    if (l > best_l)
+    {
       best_l = l;
-      for (size_t i = 0; i < num_gauss; i++) {
-        mu[i] = mu_temp[i];
-        sigma[i] = sigma_temp[i];
-      }
-      omega = omega_temp;
+
+      means = means_trial;
+      covariances = covariances_trial;
+      weights = weights_trial;
     }
-    restarts++;
   }
 
-  for (size_t i = 0; i < num_gauss; i++) {
-    set_mu(i, mu[i]);
-    set_sigma(i, sigma[i]);
-  }
-  set_omega(omega);
-
   Log::Info << "Log likelihood value of the estimated model: " << best_l << "."
       << std::endl;
   return;
 }
 
-long double MoGEM::Loglikelihood(const arma::mat& data_points,
-                                 const std::vector<arma::vec>& means,
-                                 const std::vector<arma::mat>& covars,
-                                 const arma::vec& weights) {
-  size_t i, j;
-  arma::vec x;
-  long double likelihood, loglikelihood = 0;
+long double MoGEM::Loglikelihood(const arma::mat& data,
+                                 const std::vector<arma::vec>& means_l,
+                                 const std::vector<arma::mat>& covariances_l,
+                                 const arma::vec& weights_l) const
+{
+  long double loglikelihood = 0;
+  long double likelihood;
 
-  x.set_size(data_points.n_rows);
-
-  for (j = 0; j < data_points.n_cols; j++) {
-    x = data_points.col(j);
+  for (size_t j = 0; j < data.n_cols; j++)
+  {
     likelihood = 0;
-    for(i = 0; i < number_of_gaussians_; i++) {
-      likelihood += weights(i) * phi(x, means[i], covars[i]);
-    }
+    for(size_t i = 0; i < gaussians; i++)
+      likelihood += weights_l(i) * phi(data.unsafe_col(j), means_l[i],
+          covariances_l[i]);
+
     loglikelihood += log(likelihood);
   }
 

Modified: mlpack/trunk/src/mlpack/methods/mog/mog_em.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/mog/mog_em.hpp	2011-11-08 17:02:18 UTC (rev 10182)
+++ mlpack/trunk/src/mlpack/methods/mog/mog_em.hpp	2011-11-08 17:13:17 UTC (rev 10183)
@@ -40,85 +40,52 @@
  */
 class MoGEM {
  private:
-  // The parameters of the mixture model
-  std::vector<arma::vec> mu_;
-  std::vector<arma::mat> sigma_;
-  arma::vec omega_;
-  size_t number_of_gaussians_;
-  size_t dimension_;
+  //! The number of Gaussians in the model.
+  size_t gaussians;
+  //! The dimensionality of the model.
+  size_t dimension;
+  //! Vector of means; one for each Gaussian.
+  std::vector<arma::vec> means;
+  //! Vector of covariances; one for each Gaussian.
+  std::vector<arma::mat> covariances;
+  //! Vector of a priori weights for each Gaussian.
+  arma::vec weights;
 
  public:
+  /**
+   * Create a GMM with the given number of Gaussians, each of which have the
+   * specified dimensionality.
+   *
+   * @param gaussians Number of Gaussians in this GMM.
+   * @param dimension Dimensionality of each Gaussian.
+   */
+  MoGEM(size_t gaussians, size_t dimension) :
+      gaussians(gaussians),
+      dimension(dimension),
+      means(gaussians),
+      covariances(gaussians) { /* nothing to do */ }
 
-  MoGEM() { }
+  //! Return the number of gaussians in the model.
+  const size_t Gaussians() { return gaussians; }
 
-  ~MoGEM() { }
+  //! Return the dimensionality of the model.
+  const size_t Dimension() { return dimension; }
 
-  void Init(size_t num_gauss, size_t dimension) {
-    // Initialize the private variables
-    number_of_gaussians_ = num_gauss;
-    dimension_ = dimension;
+  //! Return a const reference to the vector of means (mu).
+  const std::vector<arma::vec>& Means() const { return means; }
+  //! Return a reference to the vector of means (mu).
+  std::vector<arma::vec>& Means() { return means; }
 
-    // Resize the ArrayList of Vectors and Matrices
-    mu_.resize(number_of_gaussians_);
-    sigma_.resize(number_of_gaussians_);
-  }
+  //! Return a const reference to the vector of covariance matrices (sigma).
+  const std::vector<arma::mat>& Covariances() const { return covariances; }
+  //! Return a reference to the vector of covariance matrices (sigma).
+  std::vector<arma::mat>& Covariances() { return covariances; }
 
-  std::vector<arma::vec>& mu() {
-    return mu_;
-  }
+  //! Return a const reference to the a priori weights of each Gaussian.
+  const arma::vec& Weights() const { return weights; }
+  //! Return a reference to the a priori weights of each Gaussian.
+  arma::vec& Weights() { return weights; }
 
-  std::vector<arma::mat>& sigma() {
-    return sigma_;
-  }
-
-  arma::vec& omega() {
-    return omega_;
-  }
-
-  size_t number_of_gaussians() {
-    return number_of_gaussians_;
-  }
-
-  size_t dimension() {
-    return dimension_;
-  }
-
-  arma::vec& mu(size_t i) {
-    return mu_[i] ;
-  }
-
-  arma::mat& sigma(size_t i) {
-    return sigma_[i];
-  }
-
-  double omega(size_t i) {
-    return omega_[i];
-  }
-
-  // The set functions
-
-  void set_mu(size_t i, arma::vec& mu) {
-    assert(i < number_of_gaussians_);
-    assert(mu.n_elem == dimension_);
-
-    mu_[i] = mu;
-  }
-
-  void set_sigma(size_t i, arma::mat& sigma) {
-    assert(i < number_of_gaussians_);
-    assert(sigma.n_rows == dimension_);
-    assert(sigma.n_cols == dimension_);
-
-    sigma_[i] = sigma;
-  }
-
-  void set_omega(arma::vec& omega) {
-    assert(omega.n_elem == number_of_gaussians());
-
-    omega_ = omega;
-    return;
-  }
-
   /**
    * This function outputs the parameters of the model
    * to an arraylist of doubles
@@ -128,65 +95,29 @@
    * mog.OutputResults(&results);
    * @endcode
    */
-  void OutputResults(std::vector<double>& results) {
-
+  void OutputResults(std::vector<double>& results)
+  {
     // Initialize the size of the output array
-    results.resize(number_of_gaussians_ * (1 + dimension_ * (1 + dimension_)));
+    results.resize(gaussians * (1 + dimension * (1 + dimension)));
 
     // Copy values to the array from the private variables of the class
-    for (size_t i = 0; i < number_of_gaussians_; i++) {
-      results[i] = omega_[i];
-      for (size_t j = 0; j < dimension_; j++) {
-        results[number_of_gaussians_ + (i * dimension_) + j] = (mu_[i])[j];
-        for (size_t k = 0; k < dimension_; k++) {
-          results[number_of_gaussians_ * (1 + dimension_) +
-              (i * dimension_ * dimension_) + (j * dimension_) + k] =
-              (sigma_[i])(j, k);
+    for (size_t i = 0; i < gaussians; i++)
+    {
+      results[i] = weights[i];
+      for (size_t j = 0; j < dimension; j++)
+      {
+        results[gaussians + (i * dimension) + j] = (means[i])[j];
+        for (size_t k = 0; k < dimension; k++)
+        {
+          results[gaussians * (1 + dimension) +
+              (i * dimension * dimension) + (j * dimension) + k] =
+              (covariances[i])(j, k);
         }
       }
     }
   }
 
   /**
-   * This function prints the parameters of the model
-   *
-   * @code
-   * mog.Display();
-   * @endcode
-   */
-  void Display() {
-    // Output the model parameters as the omega, mu and sigma
-    Log::Info << " Omega : [ ";
-    for (size_t i = 0; i < number_of_gaussians_; i++) {
-      Log::Info << omega_[i] << " ";
-    }
-    Log::Info << "]" << std::endl;
-
-    Log::Info << " Mu : " << std::endl << "[";
-    for (size_t i = 0; i < number_of_gaussians_; i++) {
-      for (size_t j = 0; j < dimension_ ; j++) {
-        Log::Info << (mu_[i])[j];
-      }
-      Log::Info << ";";
-      if (i == (number_of_gaussians_ - 1)) {
-        Log::Info << "\b]" << std::endl;
-      }
-    }
-    Log::Info << "Sigma : ";
-    for (size_t i = 0; i < number_of_gaussians_; i++) {
-      Log::Info << std::endl << "[";
-      for (size_t j = 0; j < dimension_ ; j++) {
-        for(size_t k = 0; k < dimension_ ; k++) {
-          Log::Info << (sigma_[i])(j, k);
-        }
-        Log::Info << ";";
-      }
-      Log::Info << "\b]";
-    }
-    Log::Info << std::endl;
-  }
-
-  /**
    * This function calculates the parameters of the model
    * using the Maximum Likelihood function via the
    * Expectation Maximization (EM) Algorithm.
@@ -209,10 +140,10 @@
   long double Loglikelihood(const arma::mat& data_points,
                             const std::vector<arma::vec>& means,
                             const std::vector<arma::mat>& covars,
-                            const arma::vec& weights);
+                            const arma::vec& weights) const;
 };
 
 }; // namespace gmm
 }; // namespace mlpack
 
-#endif // __MLPACK_METHODS_MOG_MOG_EM_HPP
+#endif

Modified: mlpack/trunk/src/mlpack/methods/mog/mog_l2e.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/mog/mog_l2e.cpp	2011-11-08 17:02:18 UTC (rev 10182)
+++ mlpack/trunk/src/mlpack/methods/mog/mog_l2e.cpp	2011-11-08 17:13:17 UTC (rev 10183)
@@ -2,9 +2,7 @@
  * @author Parikshit Ram (pram at cc.gatech.edu)
  * @file mog_l2e.cpp
  *
- * Implementation for L2 loss function, and
- * also some initial points generator
- *
+ * Implementation for L2 loss function, and also some initial points generator.
  */
 #include "mog_l2e.hpp"
 #include "phi.hpp"
@@ -13,169 +11,151 @@
 using namespace mlpack;
 using namespace gmm;
 
-long double MoGL2E::L2Error(const arma::mat& data) {
-  long double reg, fit, l2e;
-
-  reg = RegularizationTerm_();
-  fit = GoodnessOfFitTerm_(data);
-  l2e = reg - (2 * fit) / data.n_cols;
-
-  return l2e;
+long double MoGL2E::L2Error(const arma::mat& data)
+{
+  return RegularizationTerm_() - (2 * GoodnessOfFitTerm_(data)) / data.n_colse;
 }
 
-long double MoGL2E::L2Error(const arma::mat& data, arma::vec& gradients) {
-
-  long double reg, fit, l2e;
-
+long double MoGL2E::L2Error(const arma::mat& data, arma::vec& gradients)
+{
   arma::vec g_reg, g_fit;
-  reg = RegularizationTerm_(g_reg);
-  fit = GoodnessOfFitTerm_(data, g_fit);
 
+  long double l2e = RegularizationTerm_(g_reg) -
+      (2 * GoodnessOfFitTerm_(data, g_fit)) / data.n_cols;
+
   gradients = g_reg - (2 * g_fit) / data.n_cols;
 
-  l2e = reg - (2 * fit) / data.n_cols;
-
   return l2e;
 }
 
-long double MoGL2E::RegularizationTerm_() {
-  arma::mat phi_mu, sum_covar;
-  arma::vec x;
-  long double reg, tmpVal;
+long double MoGL2E::RegularizationTerm_()
+{
+  arma::mat phi_mu(gaussians, gaussians)
 
-  phi_mu.set_size(number_of_gaussians_, number_of_gaussians_);
-  sum_covar.set_size(dimension_, dimension_);
-  x = omega_;
+  // Fill the phi_mu matrix (which is symmetric).  Each entry of the matrix is
+  // the phi() function evaluated on the difference between the means of each
+  // Gaussian using the sum of the covariances of the two mixtures.
+  for (size_t k = 1; k < gaussians; k++)
+  {
+    for (size_t j = 0; j < k; j++)
+    {
+      long double tmpVal = phi(means[k], means[j], covariances[k] +
+          covariances[j]);
 
-  for (size_t k = 1; k < number_of_gaussians_; k++) {
-    for (size_t j = 0; j < k; j++) {
-      sum_covar = sigma_[k] + sigma_[j];
-
-      tmpVal = phi(mu_[k], mu_[j], sum_covar);
       phi_mu(j, k) = tmpVal;
       phi_mu(k, j) = tmpVal;
     }
   }
 
-  for(size_t k = 0; k < number_of_gaussians_; k++) {
-    sum_covar = 2 * sigma_[k];
+  // Because the difference between the means is 0, save a little time by only
+  // doing the part of the calculation we need to (instead of calling phi()).
+  for(size_t k = 0; k < gaussians; k++)
+    phi_mu(k, k) = pow(2 * M_PI, (double) means[k].n_elem / -2.0)
+        * pow(det(2 * covariances[k]), -0.5);
 
-    phi_mu(k, k) = phi(mu_[k], mu_[k], sum_covar);
-  }
-
-  // Calculating the reg value
-  reg = dot(x, x * phi_mu);
-
-  return reg;
+  return dot(weights, weights * phi_mu);
 }
 
-long double MoGL2E::RegularizationTerm_(arma::vec& g_reg) {
-  arma::mat phi_mu, sum_covar;
+long double MoGL2E::RegularizationTerm_(arma::vec& g_reg)
+{
+  arma::mat phi_mu(gaussians, gaussians);
   arma::vec x, y;
   long double reg, tmpVal;
 
   arma::vec df_dw, g_omega;
-  std::vector<arma::vec> g_mu, g_sigma;
-  std::vector<std::vector<arma::vec> > dp_d_mu, dp_d_sigma;
 
-  phi_mu.set_size(number_of_gaussians_, number_of_gaussians_);
-  sum_covar.set_size(dimension_, dimension_);
-  x = omega_;
+  std::vector<arma::vec> g_mu(gaussians, arma::vec(dimension));
+  std::vector<arma::vec> g_sigma(gaussians, arma::vec((dimension * (dimension
+      + 1)) / 2));
 
-  g_mu.resize(number_of_gaussians_);
-  g_sigma.resize(number_of_gaussians_);
-  dp_d_mu.resize(number_of_gaussians_);
-  dp_d_sigma.resize(number_of_gaussians_);
-  for(size_t k = 0; k < number_of_gaussians_; k++){
-    dp_d_mu[k].resize(number_of_gaussians_);
-    dp_d_sigma[k].resize(number_of_gaussians_);
-  }
+  std::vector<std::vector<arma::vec> > dp_d_mu(gaussians,
+      std::vector<arma::vec>(gaussians));
+  std::vector<std::vector<arma::vec> > dp_d_sigma(gaussians,
+      std::vector<arma::vec>(gaussians));
 
-  for(size_t k = 1; k < number_of_gaussians_; k++) {
-    for(size_t j = 0; j < k; j++) {
-      sum_covar = sigma_[k] * sigma_[j];
+  x = weights;
 
-      std::vector<arma::mat> tmp_d_cov;
+  // Fill the phi_mu matrix (which is symmetric).  Each entry of the matrix is
+  // the phi() function evaluated on the difference between the means of each
+  // Gaussian using the sum of the covariances of the two mixtures.
+  for(size_t k = 1; k < gaussians; k++)
+  {
+    for(size_t j = 0; j < k; j++)
+    {
+      std::vector<arma::mat> tmp_d_cov(dimension * (dimension + 1));
       arma::vec tmp_dp_d_sigma;
 
-      tmp_d_cov.resize(dimension_ * (dimension_ + 1));
-
-      for(size_t i = 0; i < (dimension_ * (dimension_ + 1) / 2); i++) {
-        tmp_d_cov[i] = (d_sigma_[k])[i];
-        tmp_d_cov[(dimension_ * (dimension_ + 1) / 2) + i] = (d_sigma_[j])[i];
+      // We should find a way to avoid all this copying to set up for the call
+      // to phi().
+      for(size_t i = 0; i < (dimension * (dimension + 1) / 2); i++)
+      {
+        tmp_d_cov[i] = (covariancesGradients[k])[i];
+        tmp_d_cov[(dimension * (dimension + 1) / 2) + i] =
+            (covariancesGradients[j])[i];
       }
 
-      tmpVal = phi(mu_[k], mu_[j], sum_covar, tmp_d_cov, dp_d_mu[j][k],
-          tmp_dp_d_sigma);
+      tmpVal = phi(means[k], means[j], covariances[k] + covariances[j],
+          tmp_d_cov, dp_d_mu[j][k], tmp_dp_d_sigma);
 
       phi_mu(j, k) = tmpVal;
       phi_mu(k, j) = tmpVal;
 
       dp_d_mu[k][j] = -dp_d_mu[j][k];
 
-      arma::vec tmp_dp_1(tmp_dp_d_sigma.n_elem / 2);
-      arma::vec tmp_dp_2(tmp_dp_d_sigma.n_elem / 2);
-      for (size_t i = 0; i < tmp_dp_1.n_elem; i++) {
-        tmp_dp_1[i] = tmp_dp_d_sigma[i];
-        tmp_dp_2[i] = tmp_dp_d_sigma[(dimension_ * (dimension_ + 1) / 2) + i];
-      }
-
-      dp_d_sigma[j][k] = tmp_dp_1;
-      dp_d_sigma[k][j] = tmp_dp_2;
+      dp_d_sigma[j][k] = tmp_dp_d_sigma.rows(0,
+          (dimension * (dimension + 1) / 2) - 1);
+      dp_d_sigma[k][j] = tmp_dp_d_sigma.rows((dimension * (dimension + 1) / 2),
+          tmp_dp_d_sigma.n_rows - 1);
     }
   }
 
-  for (size_t k = 0; k < number_of_gaussians_; k++) {
-    sum_covar = 2 * sigma_[k];
+  // Fill the diagonal elements of the phi_mu matrix.
+  for (size_t k = 0; k < gaussians; k++)
+  {
+    arma::vec junk; // This result is not needed.
+    phi_mu(k, k) = phi(means[k], means[k], 2 * covariances[k],
+        covariancesGradients[k], junk, dp_d_sigma[k][k]);
 
-    arma::vec junk;
-    tmpVal = phi(mu_[k], mu_[k], sum_covar, d_sigma_[k], junk,
-        dp_d_sigma[k][k]);
-
-    phi_mu(k, k) = tmpVal;
-
-    dp_d_mu[k][k].zeros(dimension_);
+    dp_d_mu[k][k].zeros(dimension);
   }
 
-  // Calculating the reg value
-  reg = dot(x, x * phi_mu);
+  // Calculate the regularization term value.
+  arma::vec y = weights * phi_mu;
+  long double reg = dot(weights, y);
 
-  // Calculating the g_omega values - a vector of size K-1
+  // Calculate the g_omega value; a vector of size K - 1
   df_dw = 2.0 * y;
-  g_omega = d_omega_ * df_dw;
+  g_omega = weightsGradients * df_dw;
 
-  // Calculating the g_mu values - K vectors of size D
-  for (size_t k = 0; k < number_of_gaussians_; k++) {
-    g_mu[k].zeros(dimension_);
+  // Calculate the g_mu values; K vectors of size D
+  for (size_t k = 0; k < gaussians; k++)
+  {
+    for (size_t j = 0; j < gaussians; j++)
+      g_mu[k] = 2.0 * weights[k] * weights[j] * dp_d_mu[j][k];
 
-    for (size_t j = 0; j < number_of_gaussians_; j++) {
-      g_mu[k] += x[j] * dp_d_mu[j][k];
-      g_mu[k] *= 2.0 * x[k];
-    }
-
     // Calculating the g_sigma values - K vectors of size D(D+1)/2
-    for (size_t k = 0; k < number_of_gaussians_; k++) {
-      g_sigma[k].zeros((dimension_ * (dimension_ + 1)) / 2);
-      for (size_t j = 0; j < number_of_gaussians_; j++)
+    for (size_t k = 0; k < gaussians; k++)
+    {
+      for (size_t j = 0; j < gaussians; j++)
         g_sigma[k] += x[k] * dp_d_sigma[j][k];
       g_sigma[k] *= 2.0 * x[k];
     }
 
     // Making the single gradient vector of size K*(D+1)*(D+2)/2 - 1
-    arma::vec tmp_g_reg((number_of_gaussians_ * (dimension_ + 1) *
-        (dimension_ * 2) / 2) - 1);
+    arma::vec tmp_g_reg((gaussians * (dimension + 1) *
+        (dimension * 2) / 2) - 1);
     size_t j = 0;
     for (size_t k = 0; k < g_omega.n_elem; k++)
       tmp_g_reg[k] = g_omega[k];
     j = g_omega.n_elem;
 
-    for (size_t k = 0; k < number_of_gaussians_; k++) {
-      for (size_t i = 0; i < dimension_; i++)
-        tmp_g_reg[j + (k * dimension_) + i] = (g_mu[k])[i];
+    for (size_t k = 0; k < gaussians; k++) {
+      for (size_t i = 0; i < dimension; i++)
+        tmp_g_reg[j + (k * dimension) + i] = (g_mu[k])[i];
 
-      for(size_t i = 0; i < (dimension_ * (dimension_ + 1) / 2); i++) {
-        tmp_g_reg[j + (number_of_gaussians_ * dimension_)
-            + k * (dimension_ * (dimension_ + 1) / 2)
+      for(size_t i = 0; i < (dimension * (dimension + 1) / 2); i++) {
+        tmp_g_reg[j + (gaussians * dimension)
+            + k * (dimension * (dimension + 1) / 2)
             + i] = (g_sigma[k])[i];
       }
     }
@@ -188,16 +168,16 @@
 
 long double MoGL2E::GoodnessOfFitTerm_(const arma::mat& data) {
   long double fit;
-  arma::mat phi_x(number_of_gaussians_, data.n_cols);
+  arma::mat phi_x(gaussians, data.n_cols);
   arma::vec identity_vector;
 
   identity_vector.ones(data.n_cols);
 
-  for (size_t k = 0; k < number_of_gaussians_; k++)
+  for (size_t k = 0; k < gaussians; k++)
     for (size_t i = 0; i < data.n_cols; i++)
-      phi_x(k, i) = phi(data.unsafe_col(i), mu_[k], sigma_[k]);
+      phi_x(k, i) = phi(data.unsafe_col(i), means[k], covariances[k]);
 
-  fit = dot(omega_ * phi_x, identity_vector);
+  fit = dot(trans(weights) * phi_x, identity_vector); // (1 x K) * (K x N).
 
   return fit;
 }
@@ -205,55 +185,55 @@
 long double MoGL2E::GoodnessOfFitTerm_(const arma::mat& data,
                                        arma::vec& g_fit) {
   long double fit;
-  arma::mat phi_x(number_of_gaussians_, data.n_cols);
-  arma::vec weights, x, y, identity_vector;
-  arma::vec g_omega,tmp_g_omega;
+  arma::mat phi_x(gaussians, data.n_cols);
+  arma::vec weights_l, x, y, identity_vector;
+  arma::vec g_omega, tmp_g_omega;
   std::vector<arma::vec> g_mu, g_sigma;
 
-  weights = omega_;
+  weights_l = weights;
   x.set_size(data.n_rows);
   identity_vector.ones(data.n_cols);
 
-  g_mu.resize(number_of_gaussians_);
-  g_sigma.resize(number_of_gaussians_);
+  g_mu.resize(gaussians);
+  g_sigma.resize(gaussians);
 
-  for(size_t k = 0; k < number_of_gaussians_; k++) {
-    g_mu[k].zeros(dimension_);
-    g_sigma[k].zeros(dimension_ * (dimension_ + 1) / 2);
+  for(size_t k = 0; k < gaussians; k++) {
+    g_mu[k].zeros(dimension);
+    g_sigma[k].zeros(dimension * (dimension + 1) / 2);
 
     for (size_t i = 0; i < data.n_cols; i++) {
       arma::vec tmp_g_mu, tmp_g_sigma;
-      phi_x(k, i) = phi(data.unsafe_col(i), mu_[k], sigma_[k], d_sigma_[k],
-          tmp_g_mu, tmp_g_sigma);
+      phi_x(k, i) = phi(data.unsafe_col(i), means[k], covariances[k],
+          d_sigma_[k], tmp_g_mu, tmp_g_sigma);
 
       g_mu[k] += tmp_g_mu;
       g_sigma[k] = tmp_g_sigma;
     }
 
-    g_mu[k] *= weights[k];
-    g_sigma[k] *= weights[k];
+    g_mu[k] *= weights_l[k];
+    g_sigma[k] *= weights_l[k];
   }
 
-  fit = dot(weights * phi_x, identity_vector);
+  fit = dot(weights_l * phi_x, identity_vector);
 
   // Calculating the g_omega
   tmp_g_omega = phi_x * identity_vector;
   g_omega = d_omega_ * tmp_g_omega;
 
   // Making the single gradient vector of size K*(D+1)*(D+2)/2
-  arma::vec tmp_g_fit((number_of_gaussians_ * (dimension_ + 1) *
-      (dimension_ * 2) / 2) - 1);
+  arma::vec tmp_g_fit((gaussians * (dimension + 1) *
+      (dimension * 2) / 2) - 1);
   size_t j = 0;
   for (size_t k = 0; k < g_omega.n_elem; k++)
     tmp_g_fit[k] = g_omega[k];
   j = g_omega.n_elem;
-  for (size_t k = 0; k < number_of_gaussians_; k++) {
-    for (size_t i = 0; i < dimension_; i++)
-      tmp_g_fit[j + (k * dimension_) + i] = (g_mu[k])[i];
+  for (size_t k = 0; k < gaussians; k++) {
+    for (size_t i = 0; i < dimension; i++)
+      tmp_g_fit[j + (k * dimension) + i] = (g_mu[k])[i];
 
-    for (size_t i = 0; i < (dimension_ * (dimension_ + 1) / 2); i++)
-      tmp_g_fit[j + number_of_gaussians_ * dimension_
-        + k * (dimension_ * (dimension_ + 1) / 2) + i] = (g_sigma[k])[i];
+    for (size_t i = 0; i < (dimension * (dimension + 1) / 2); i++)
+      tmp_g_fit[j + gaussians * dimension
+        + k * (dimension * (dimension + 1) / 2) + i] = (g_sigma[k])[i];
   }
 
   g_fit = tmp_g_fit;
@@ -291,23 +271,23 @@
 void MoGL2E::InitialPointGenerator(arma::vec& theta,
                                    const arma::mat& data,
                                    size_t k_comp) {
-  std::vector<arma::vec> means;
+  std::vector<arma::vec> means_l;
   std::vector<arma::mat> covars;
-  arma::vec weights;
+  arma::vec weights_l;
   double noise;
 
-  weights.set_size(k_comp);
-  means.resize(k_comp);
+  weights_l.set_size(k_comp);
+  means_l.resize(k_comp);
   covars.resize(k_comp);
 
   theta.set_size(k_comp);
 
   for (size_t i = 0; i < k_comp; i++) {
-    means[i].set_size(data.n_rows);
+    means_l[i].set_size(data.n_rows);
     covars[i].set_size(data.n_rows, data.n_rows);
   }
 
-  KMeans(data, k_comp, means, covars, weights);
+  KMeans(data, k_comp, means_l, covars, weights_l);
 
   for (size_t k = 0; k < k_comp - 1; k++) {
     noise = (double) (rand() % 10000) / (double) 1000;
@@ -316,7 +296,7 @@
 
   for (size_t k = 0; k < k_comp; k++) {
     for (size_t j = 0; j < data.n_rows; j++)
-      theta[k_comp - 1 + k * data.n_rows + j] = (means[k])[j];
+      theta[k_comp - 1 + k * data.n_rows + j] = (means_l[k])[j];
 
     arma::mat u = chol(covars[k]);
     for(size_t j = 0; j < data.n_rows; j++)

Modified: mlpack/trunk/src/mlpack/methods/mog/mog_l2e.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/mog/mog_l2e.hpp	2011-11-08 17:02:18 UTC (rev 10182)
+++ mlpack/trunk/src/mlpack/methods/mog/mog_l2e.hpp	2011-11-08 17:13:17 UTC (rev 10183)
@@ -2,9 +2,7 @@
  * @author Parikshit Ram (pram at cc.gatech.edu)
  * @file mog_l2e.hpp
  *
- * Defines a Gaussian Mixture model and
- * estimates the parameters of the model
- *
+ * Defines a Gaussian Mixture model and estimates the parameters of the model.
  */
 #ifndef __MLPACK_METHODS_MOG_MOG_L2E_HPP
 #define __MLPACK_METHODS_MOG_MOG_L2E_HPP
@@ -24,17 +22,15 @@
 /**
  * A Gaussian mixture model class.
  *
- * This class uses L2 loss function to
- * estimate the parameters of a gaussian mixture
- * model on a given data.
+ * This class uses an L2 loss function to estimate the parameters of a Gaussian
+ * mixture model on the given data.
  *
- * The parameters are converted for optimization
- * to maintain the following facts:
- * - the weights sum to one
- *  - for this, the weights were parameterized using
- *    the logistic function
- * - the covariance matrix is always positive definite
- *  - for this, the Cholesky decomposition is used
+ * The parameters are transformed for optimization so that the following
+ * constraints always hold:
+ * - the weights sum to one; for this, the weights are parameterized using
+ *     the logistic function
+ * - the covariance matrix is always positive definite; for this, the Cholesky
+ *     decomposition is used
  *
  * Example use:
  *
@@ -51,44 +47,32 @@
 class MoGL2E {
  private:
   // The parameters of the Mixture model
-  std::vector<arma::vec> mu_;
-  std::vector<arma::mat> sigma_;
-  arma::vec omega_;
-  size_t number_of_gaussians_;
-  size_t dimension_;
+  size_t gaussians;
+  size_t dimension;
+  std::vector<arma::vec> means;
+  std::vector<arma::mat> covariances;
+  arma::vec weights;
 
   // The differential for the paramterization
   // for optimization
-  arma::mat d_omega_;
-  std::vector<std::vector<arma::mat> > d_sigma_;
+  arma::mat weightsGradients;
+  std::vector<std::vector<arma::mat> > covariancesGradients;
 
  public:
 
-  MoGL2E() { }
+  MoGL2E(size_t gaussians, size_t dimension) :
+    gaussians(gaussians),
+    dimension(dimension),
+    means(gaussians),
+    covariances(gaussians) { /* nothing to do */ }
 
-  ~MoGL2E() { }
-
-  void Init(size_t num_gauss, size_t dimension) {
-    // Destruct everything to initialize afresh
-    mu_.clear();
-    sigma_.clear();
-    d_sigma_.clear();
-
-    // Initialize the private variables
-    number_of_gaussians_ = num_gauss;
-    dimension_ = dimension;
-
-    // Resize the vector of vectors and matrices
-    mu_.resize(number_of_gaussians_);
-    sigma_.resize(number_of_gaussians_);
+  void Resize_d_sigma_() // Sizes covariancesGradients (formerly d_sigma_).
+  {
+    covariancesGradients.resize(gaussians);
+    for(size_t i = 0; i < gaussians; i++)
+      covariancesGradients[i].resize(dimension * (dimension + 1) / 2);
   }
 
-  void Resize_d_sigma_() {
-    d_sigma_.resize(number_of_gaussians_);
-    for(size_t i = 0; i < number_of_gaussians_; i++)
-      d_sigma_[i].resize(dimension_ * (dimension_ + 1) / 2);
-  }
-
   /**
    * This function uses the parameters used for optimization
    * and converts it into athe parameters of a Gaussian
@@ -109,7 +93,8 @@
     arma::mat lower_triangle_matrix;
     double sum, s_min = 0.01;
 
-    Init(num_mods, dimension);
+    gaussians = num_mods;
+    this->dimension = dimension;
     lower_triangle_matrix.set_size(dimension, dimension);
 
     // calculating the omega values
@@ -118,34 +103,34 @@
     sum = accu(temp_array);
 
     temp_array /= sum;
-    set_omega(temp_array);
+    weights = temp_array;
 
     // calculating the mu values
-    for(size_t k = 0; k < num_mods; k++) {
+    for(size_t k = 0; k < gaussians; k++) {
       for(size_t j = 0; j < dimension; j++) {
         temp_mu[j] = theta[num_mods + (k * dimension) + j - 1];
       }
-      set_mu(k, temp_mu);
+      means[k] = temp_mu;
     }
 
     // calculating the sigma values
     // using a lower triangular matrix and its transpose
     // to obtain a positive definite symmetric matrix
     arma::mat sigma_temp(dimension, dimension);
-    for(size_t k = 0; k < num_mods; k++) {
+    for(size_t k = 0; k < gaussians; k++) {
       lower_triangle_matrix.zeros();
       for(size_t j = 0; j < dimension; j++) {
         for(size_t i = 0; i < j; i++) {
-          lower_triangle_matrix(j, i) = theta[(num_mods - 1)
-              + (num_mods * dimension) + k * (dimension * (dimension + 1) / 2)
+          lower_triangle_matrix(j, i) = theta[(gaussians - 1)
+              + (gaussians * dimension) + k * (dimension * (dimension + 1) / 2)
               + (j * (j + 1) / 2) + i];
         }
-        lower_triangle_matrix(j, j) = theta[(num_mods - 1)
-            + (num_mods * dimension) + k * (dimension * (dimension + 1) / 2)
+        lower_triangle_matrix(j, j) = theta[(gaussians - 1)
+            + (gaussians * dimension) + k * (dimension * (dimension + 1) / 2)
             + (j * (j + 1) / 2) + j] + s_min;
       }
       sigma_temp = lower_triangle_matrix * trans(lower_triangle_matrix);
-      set_sigma(k, sigma_temp);
+      covariances[k] = sigma_temp;
     }
   }
 
@@ -170,7 +155,8 @@
     arma::mat lower_triangle_matrix;
     double sum, s_min = 0.01;
 
-    Init(num_mods, dimension);
+    gaussians = num_mods;
+    this->dimension = dimension;
     lower_triangle_matrix.set_size(dimension, dimension);
 
     // calculating the omega values
@@ -179,21 +165,21 @@
     sum = accu(temp_array);
 
     temp_array /= sum;
-    set_omega(temp_array);
+    weights = temp_array;
 
     // calculating the d_omega values
     arma::mat d_omega_temp(num_mods - 1, num_mods);
     d_omega_temp.zeros();
     for (size_t i = 0; i < num_mods - 1; i++) {
       for (size_t j = 0; j < i; j++) {
-        d_omega_temp(i, j) = -(omega_[i] * omega_[j]);
-        d_omega_temp(j, i) = -(omega_[i] * omega_[j]);
+        d_omega_temp(i, j) = -(weights[i] * weights[j]);
+        d_omega_temp(j, i) = -(weights[i] * weights[j]);
       }
-      d_omega_temp(i, i) = omega_[i] * (1 - omega_[i]);
+      d_omega_temp(i, i) = weights[i] * (1 - weights[i]);
     }
 
     for (size_t i = 0; i < num_mods - 1; i++)
-      d_omega_temp(i, num_mods - 1) = -(omega_[i] * omega_[num_mods - 1]);
+      d_omega_temp(i, num_mods - 1) = -(weights[i] * weights[num_mods - 1]);
 
     set_d_omega(d_omega_temp);
 
@@ -201,7 +187,7 @@
     for (size_t k = 0; k < num_mods; k++) {
       for (size_t j = 0; j < dimension; j++)
         temp_mu[j] = theta[num_mods + (k * dimension) + j - 1];
-      set_mu(k, temp_mu);
+      means[k] = temp_mu;
     }
     // d_mu is not computed because it is implicitly known
     // since no parameterization is applied on them
@@ -216,20 +202,20 @@
 
     // calculating the sigma values
     arma::mat sigma_temp(dimension, dimension);
-    for (size_t k = 0; k < num_mods; k++) {
+    for (size_t k = 0; k < gaussians; k++) {
       lower_triangle_matrix.zeros();
       for (size_t j = 0; j < dimension; j++) {
         for (size_t i = 0; i < j; i++) {
-          lower_triangle_matrix(j, i) = theta[(num_mods - 1)
-              + (num_mods * dimension) + k * (dimension * (dimension + 1) / 2)
+          lower_triangle_matrix(j, i) = theta[(gaussians - 1)
+              + (gaussians * dimension) + k * (dimension * (dimension + 1) / 2)
               + (j * (j + 1) / 2) + i];
         }
-        lower_triangle_matrix(j, j) = theta[(num_mods - 1)
-            + (num_mods * dimension) + k * (dimension * (dimension + 1) / 2)
+        lower_triangle_matrix(j, j) = theta[(gaussians - 1)
+            + (gaussians * dimension) + k * (dimension * (dimension + 1) / 2)
             + (j * (j + 1) / 2) + j] + s_min;
       }
       sigma_temp = lower_triangle_matrix * trans(lower_triangle_matrix);
-      set_sigma(k, sigma_temp);
+      covariances[k] = sigma_temp;
 
       // calculating the d_sigma values
       for (size_t i = 0; i < dimension; i++) {
@@ -246,39 +232,6 @@
     }
   }
 
-  ////// THE GET FUNCTCLINS //////
-  std::vector<arma::vec>& mu() {
-    return mu_;
-  }
-
-  std::vector<arma::mat>& sigma() {
-    return sigma_;
-  }
-
-  arma::vec& omega() {
-    return omega_;
-  }
-
-  size_t number_of_gaussians() {
-    return number_of_gaussians_;
-  }
-
-  size_t dimension() {
-    return dimension_;
-  }
-
-  arma::vec& mu(size_t i) {
-    return mu_[i] ;
-  }
-
-  arma::mat& sigma(size_t i) {
-    return sigma_[i];
-  }
-
-  double omega(size_t i) {
-    return omega_[i];
-  }
-
   arma::mat& d_omega() {
     return d_omega_;
   }
@@ -291,28 +244,6 @@
     return d_sigma_[i];
   }
 
-  ////// THE SET FUNCTCLINS //////
-
-  void set_mu(size_t i, const arma::vec& mu) {
-    assert(i < number_of_gaussians_);
-    assert(mu.n_elem == dimension_);
-
-    mu_[i] = mu;
-  }
-
-  void set_sigma(size_t i, const arma::mat& sigma) {
-    assert(i < number_of_gaussians_);
-    assert(sigma.n_rows == dimension_);
-    assert(sigma.n_cols == dimension_);
-
-    sigma_[i] = sigma;
-  }
-
-  void set_omega(const arma::vec& omega) {
-    assert(omega.n_elem == number_of_gaussians_);
-    omega_ = omega;
-  }
-
   void set_d_omega(const arma::mat& d_omega) {
     d_omega_ = d_omega;
   }
@@ -333,61 +264,23 @@
   void OutputResults(std::vector<double>& results) {
 
     // Initialize the size of the output array
-    results.resize(number_of_gaussians_ * (1 + dimension_ * (1 + dimension_)));
+    results.resize(gaussians * (1 + dimension * (1 + dimension)));
 
     // Copy values to the array from the private variables of the class
-    for (size_t i = 0; i < number_of_gaussians_; i++) {
-      results[i] = omega_[i];
-      for (size_t j = 0; j < dimension_; j++) {
-        results[number_of_gaussians_ + (i * dimension_) + j] = (mu_[i])[j];
-          for (size_t k = 0; k < dimension_; k++) {
-            results[number_of_gaussians_ * (1 + dimension_)
-               + (i * dimension_ * dimension_) + (j * dimension_)
-               + k] = (sigma_[i])(j, k);
+    for (size_t i = 0; i < gaussians; i++) {
+      results[i] = weights[i]; // Weights occupy [0, gaussians).
+      for (size_t j = 0; j < dimension; j++) {
+        results[gaussians + (i * dimension) + j] = (means[i])[j]; // Means.
+          for (size_t k = 0; k < dimension; k++) {
+            results[gaussians * (1 + dimension)
+               + (i * dimension * dimension) + (j * dimension)
+               + k] = (covariances[i])(j, k); // Covariances, row by row.
         }
       }
     }
   }
 
   /**
-   * This function prints the parameters of the model
-   *
-   * @code
-   * mog.Display();
-   * @endcode
-   */
-  void Display() {
-    // Output the model parameters as the omega, mu and sigma
-    Log::Info << " Omega : [ ";
-    for (size_t i = 0; i < number_of_gaussians_; i++) {
-      Log::Info << omega_[i];
-    }
-    Log::Info << "]" << std::endl;
-    Log::Info << " Mu : " << std::endl << "[";
-    for (size_t i = 0; i < number_of_gaussians_; i++) {
-      for (size_t j = 0; j < dimension_ ; j++) {
-        Log::Info << (mu_[i])[j];
-      }
-      Log::Info << ";";
-      if (i == (number_of_gaussians_ - 1)) {
-        Log::Info << "\b]" << std::endl;
-      }
-    }
-    Log::Info << "Sigma : ";
-    for (size_t i = 0; i < number_of_gaussians_; i++) {
-      Log::Info << std::endl << "[";
-      for (size_t j = 0; j < dimension_ ; j++) {
-        for (size_t k = 0; k < dimension_ ; k++) {
-          Log::Info << (sigma_[i])(j, k);
-        }
-        Log::Info << ";";
-      }
-      Log::Info << "\b]";
-    }
-    Log::Info << std::endl;
-  }
-
-  /**
    * This function calculates the L2 error and
    * the gradient of the error with respect to the
    * parameters given the data and the parameterized
@@ -475,10 +368,10 @@
    */
   static long double L2ErrorForOpt(const arma::vec& params,
                                    const arma::mat& data) {
-    MoGL2E model;
     size_t num_gauss = (params.n_elem + 1) * 2 /
         ((data.n_rows + 1) * (data.n_rows + 2));
 
+    MoGL2E model(num_gauss, data.n_rows);
     model.MakeModel(num_gauss, data.n_rows, params);
 
     return model.L2Error(data);
@@ -488,10 +381,10 @@
                                    const arma::mat& data,
                                    arma::vec& gradient) {
 
-    MoGL2E model;
     size_t num_gauss = (params.n_elem + 1) * 2 /
         ((data.n_rows + 1) * (data.n_rows + 2));
 
+    MoGL2E model(num_gauss, data.n_rows);
     model.MakeModelWithGradients(num_gauss, data.n_rows, params);
 
     return model.L2Error(data, gradient);

Modified: mlpack/trunk/src/mlpack/methods/mog/phi.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/mog/phi.hpp	2011-11-08 17:02:18 UTC (rev 10182)
+++ mlpack/trunk/src/mlpack/methods/mog/phi.hpp	2011-11-08 17:13:17 UTC (rev 10183)
@@ -5,67 +5,57 @@
  * This file computes the Gaussian probability
  * density function
  */
+#ifndef __MLPACK_METHODS_MOG_PHI_HPP
+#define __MLPACK_METHODS_MOG_PHI_HPP
+
 #include <mlpack/core.h>
 
 namespace mlpack {
 namespace gmm {
 
 /**
- * Calculates the multivariate Gaussian probability density function
+ * Calculates the univariate Gaussian probability density function
  *
  * Example use:
  * @code
- * Vector x, mean;
- * Matrix cov;
+ * double x, mean, var;
  * ....
- * long double f = phi(x, mean, cov);
+ * long double f = phi(x, mean, var);
  * @endcode
  */
-long double phi(const arma::vec& x,
-                const arma::vec& mean,
-                const arma::mat& cov) {
-  long double cdet, f;
-  double exponent;
-  size_t dim;
-  arma::mat cinv = inv(cov);
-  arma::vec diff, tmp;
-
-  dim = x.n_elem;
-  cdet = det(cov);
-
-  if (cdet < 0)
-    cdet = -cdet;
-
-  diff = mean - x;
-  tmp = cinv * diff;
-  exponent = dot(diff, tmp);
-
-  long double tmp1, tmp2;
-  tmp1 = 1 / pow(2 * M_PI, dim / 2);
-  tmp2 = 1 / sqrt(cdet);
-  f = (tmp1 * tmp2 * exp(-exponent / 2));
-
-  return f;
+inline long double phi(const double x, const double mean, const double var)
+{
+  return exp(-1.0 * ((x - mean) * (x - mean) / (2 * var)))
+      / sqrt(2 * M_PI * var);
 }
 
 /**
- * Calculates the univariate Gaussian probability density function
+ * Calculates the multivariate Gaussian probability density function at x.
  *
  * Example use:
  * @code
- * double x, mean, var;
+ * arma::vec x, mean;
+ * arma::mat cov;
  * ....
- * long double f = phi(x, mean, var);
+ * long double f = phi(x, mean, cov);
  * @endcode
  */
-long double phi(const double x, const double mean, const double var) {
-  return exp(-1.0 * ((x - mean) * (x - mean) / (2 * var)))
-      / sqrt(2 * M_PI * var);
+inline long double phi(const arma::vec& x,
+                       const arma::vec& mean,
+                       const arma::mat& cov)
+{
+  arma::vec diff = mean - x;
+
+  arma::vec exponent = -0.5 * trans(diff) * inv(cov) * diff; // 1x1 result.
+
+  // TODO: det(cov) <= 0 (cov not positive definite) makes pow() give inf/NaN.
+  return pow(2 * M_PI, (double) x.n_elem / -2.0) * pow(det(cov), -0.5) *
+      exp(exponent[0]);
 }
 
 /**
- * Calculates the multivariate Gaussian probability density function
- * and also the gradients with respect to the mean and the variance
+ * Calculates the multivariate Gaussian probability density function and also
+ * its gradients: g_mean with respect to x, and g_cov (one entry per d_cov).
  *
  * Example use:
  * @code
@@ -75,55 +65,40 @@
  * long double f = phi(x, mean, cov, d_cov, &g_mean, &g_cov);
  * @endcode
  */
-long double phi(const arma::vec& x,
-                const arma::vec& mean,
-                const arma::mat& cov,
-                const std::vector<arma::mat>& d_cov,
-                arma::vec& g_mean,
-                arma::vec& g_cov) {
-  long double cdet, f;
-  double exponent;
-  size_t dim;
+inline long double phi(const arma::vec& x,
+                       const arma::vec& mean,
+                       const arma::mat& cov,
+                       const std::vector<arma::mat>& d_cov,
+                       arma::vec& g_mean,
+                       arma::vec& g_cov)
+{
+  // We don't call out to another version of the function to avoid inverting the
+  // covariance matrix more than once.
   arma::mat cinv = inv(cov);
-  arma::vec diff, tmp;
 
-  dim = x.n_elem;
-  cdet = det(cov);
+  arma::vec diff = mean - x;
+  arma::vec invDiff = cinv * diff; // Shared by f and both gradients.
 
-  if (cdet < 0)
-    cdet = -cdet;
+  long double f = pow(2 * M_PI, (double) x.n_elem / -2.0) * pow(det(cov), -0.5)
+      * exp(-0.5 * dot(diff, invDiff));
 
-  diff = mean - x;
-  tmp = cinv * diff;
-  exponent = dot(diff, tmp);
+  // g_mean = f * cinv * (mean - x), i.e. the gradient of f with respect to x.
+  g_mean = f * invDiff;
+  g_cov.set_size(d_cov.size()); // One entry per matrix in d_cov.
 
-  long double tmp1, tmp2;
-  tmp1 = 1 / pow(2 * M_PI, dim / 2);
-  tmp2 = 1 / sqrt(cdet);
-  f = (tmp1 * tmp2 * exp(-exponent / 2));
+  // Assuming d_cov[i] = d(cov)/d(theta_i): g_cov[i] = df/d(theta_i).
+  for (size_t i = 0; i < d_cov.size(); i++)
+  {
+    arma::mat inv_d = cinv * d_cov[i];
 
-  // Calculating the g_mean values  which would be a (1 X dim) vector
-  g_mean = f * tmp;
-
-  // Calculating the g_cov values which would be a (1 X (dim*(dim+1)/2)) vector
-  arma::vec g_cov_tmp(d_cov.size());
-  for (size_t i = 0; i < d_cov.size(); i++) {
-    arma::vec tmp_d;
-    arma::mat inv_d;
-    long double tmp_d_cov_d_r;
-
-    tmp_d = d_cov[i] * tmp;
-    tmp_d_cov_d_r = dot(tmp_d,tmp);
-    inv_d = cinv * d_cov[i];
-    for (size_t j = 0; j < dim; j++)
-      tmp_d_cov_d_r += inv_d(j, j);
-    g_cov_tmp[i] = f * tmp_d_cov_d_r / 2;
+    g_cov[i] = 0.5 * f * (dot(d_cov[i] * invDiff, invDiff) -
+        trace(inv_d));
   }
 
-  g_cov = g_cov_tmp;
-
   return f;
 }
 
 }; // namespace gmm
 }; // namespace mlpack
+
+#endif




More information about the mlpack-svn mailing list