[mlpack-svn] r16426 - mlpack/trunk/src/mlpack/methods/naive_bayes

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Apr 14 16:11:16 EDT 2014


Author: rcurtin
Date: Mon Apr 14 16:11:16 2014
New Revision: 16426

Log:
Overhaul implementation; do not use gmm::phi().  This gives a serious speedup,
as high-dimensional matrix inverses are no longer calculated.  The previous
calls to gmm::phi() would invert a diagonal matrix without being able to assume
that the matrix was diagonal.  This explains the very poor benchmarking results
for nbc in mlpack.


Modified:
   mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp	Mon Apr 14 16:11:16 2014
@@ -1,10 +1,11 @@
 /**
- * @file simple_nbc_impl.hpp
+ * @file naive_bayes_classifier_impl.hpp
  * @author Parikshit Ram (pram at cc.gatech.edu)
  *
  * A Naive Bayes Classifier which parametrically estimates the distribution of
- * the features.  It is assumed that the features have been sampled from a
- * Gaussian PDF.
+ * the features.  This classifier makes its predictions based on the assumption
+ * that the features have been sampled from a set of Gaussians with diagonal
+ * covariance.
  */
 #ifndef __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_IMPL_HPP
 #define __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_IMPL_HPP
@@ -53,6 +54,11 @@
       means.col(i) /= probabilities[i];
       variances.col(i) /= (probabilities[i] - 1);
     }
+
+    // Make sure variance is invertible.
+    for (size_t j = 0; j < dimensionality; ++j)
+      if (variances(j, i) == 0.0)
+        variances(j, i) = 1e-50;
   }
 
   probabilities /= data.n_cols;
@@ -66,9 +72,12 @@
   // training data.
   Log::Assert(data.n_rows == means.n_rows);
 
-  arma::vec probs(means.n_cols);
+  arma::vec probs = arma::log(probabilities);
+  arma::mat invVar = 1.0 / variances;
+
+  arma::mat testProbs = arma::repmat(probs.t(), data.n_cols, 1);
 
-  results.zeros(data.n_cols);
+  results.set_size(data.n_cols); // No need to fill with anything yet.
 
   Log::Info << "Running Naive Bayes classifier on " << data.n_cols
       << " data points with " << data.n_rows << " features each." << std::endl;
@@ -76,28 +85,30 @@
   // Calculate the joint probability for each of the data points for each of the
   // means.n_cols.
 
-  // Loop over every test case.
-  for (size_t n = 0; n < data.n_cols; n++)
+  // Loop over every class.
+  for (size_t i = 0; i < means.n_cols; i++)
   {
-    // Loop over every class.
-    for (size_t i = 0; i < means.n_cols; i++)
-    {
-      // Use the log values to prevent floating point underflow.
-      probs(i) = log(probabilities(i));
+    // This is an adaptation of gmm::phi() for the case where the covariance is
+    // a diagonal matrix.
+    arma::mat diffs = data - arma::repmat(means.col(i), 1, data.n_cols);
+    arma::mat rhs = -0.5 * arma::diagmat(invVar.col(i)) * diffs;
+    arma::vec exponents(diffs.n_cols);
+    for (size_t j = 0; j < diffs.n_cols; ++j)
+      exponents(j) = std::exp(arma::accu(diffs.col(j) % rhs.unsafe_col(j)));
 
-      // Loop over every feature, but avoid inverting empty matrices.
-      if (probabilities[i] != 0)
-      {
-        probs(i) += log(gmm::phi(data.unsafe_col(n), means.unsafe_col(i),
-            diagmat(variances.unsafe_col(i))));
-      }
-    }
+    testProbs.col(i) += log(pow(2 * M_PI, (double) data.n_rows / -2.0) *
+        pow(det(arma::diagmat(invVar.col(i))), -0.5) * exponents);
+  }
 
-    // Find the index of the maximum value in tmp_vals.
+  // Now calculate the label.
+  for (size_t i = 0; i < data.n_cols; ++i)
+  {
+    // Find the index of the class with maximum probability for this point.
     arma::uword maxIndex = 0;
-    probs.max(maxIndex);
+    arma::vec pointProbs = testProbs.row(i).t();
+    pointProbs.max(maxIndex);
 
-    results[n] = maxIndex;
+    results[i] = maxIndex;
   }
 
   return;



More information about the mlpack-svn mailing list