[mlpack-svn] r16430 - mlpack/trunk/src/mlpack/methods/naive_bayes

fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Apr 16 14:18:13 EDT 2014


Author: rcurtin
Date: Wed Apr 16 14:18:13 2014
New Revision: 16430

Log:
Change to two-pass algorithm suggested by Vahab in #344.


Modified:
   mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp	Wed Apr 16 14:18:13 2014
@@ -1,6 +1,7 @@
 /**
  * @file naive_bayes_classifier_impl.hpp
  * @author Parikshit Ram (pram at cc.gatech.edu)
+ * @author Vahab Akbarzadeh (v.akbarzadeh at gmail.com)
  *
  * A Naive Bayes Classifier which parametrically estimates the distribution of
  * the features.  This classifier makes its predictions based on the assumption
@@ -59,25 +60,36 @@
   }
   else
   {
-    // Don't use incremental algorithm.
+    // Don't use incremental algorithm.  This is a two-pass algorithm.  It is
+    // possible to calculate the means and variances using a faster one-pass
+    // algorithm but there are some precision and stability issues.  If this is
+    // too slow, it's an option to use the faster algorithm by default and then
+    // have this (and the incremental algorithm) be other options.
+
+    // Calculate the means.
     for (size_t j = 0; j < data.n_cols; ++j)
     {
       const size_t label = labels[j];
       ++probabilities[label];
-
       means.col(label) += data.col(j);
-      variances.col(label) += square(data.col(j));
     }
 
+    // Normalize means.
     for (size_t i = 0; i < classes; ++i)
-    {
-      if (probabilities[i] != 0)
-      {
-        variances.col(i) -= (square(means.col(i)) / probabilities[i]);
+      if (probabilities[i] != 0.0)
         means.col(i) /= probabilities[i];
-        variances.col(i) /= (probabilities[i] - 1);
-      }
+
+    // Calculate variances.
+    for (size_t j = 0; j < data.n_cols; ++j)
+    {
+      const size_t label = labels[j];
+      variances.col(label) += square(data.col(j) - means.col(label));
     }
+
+    // Normalize variances.
+    for (size_t i = 0; i < classes; ++i)
+      if (probabilities[i] > 1)
+        variances.col(i) /= (probabilities[i] - 1);
   }
 
   // Ensure that the variances are invertible.
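
The precision issue mentioned in the new comment can be illustrated outside of mlpack. The following is a minimal sketch, not the committed code: it uses plain std::vector instead of Armadillo, handles a single feature of a single class, and the function names OnePassVariance and TwoPassVariance are hypothetical. The one-pass version mirrors the sum-of-squares form of the removed lines; the two-pass version mirrors the added lines, including the guard against dividing by (n - 1) when a class has fewer than two points.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// One-pass sample variance using running sums of x and x^2, as in the removed
// lines.  When the mean is large relative to the spread, the subtraction
// (sumSq - sum^2 / n) cancels most significant digits.
double OnePassVariance(const std::vector<double>& x)
{
  if (x.size() < 2)
    return 0.0; // Sample variance is undefined for fewer than two points.

  double sum = 0.0, sumSq = 0.0;
  for (const double v : x)
  {
    sum += v;
    sumSq += v * v;
  }

  const double n = static_cast<double>(x.size());
  return (sumSq - (sum * sum) / n) / (n - 1.0);
}

// Two-pass sample variance, as in the added lines: compute the mean first,
// then accumulate squared deviations from that mean.
double TwoPassVariance(const std::vector<double>& x)
{
  if (x.size() < 2)
    return 0.0; // Same guard as 'probabilities[i] > 1' in the diff.

  const double n = static_cast<double>(x.size());

  // First pass: mean.
  double mean = 0.0;
  for (const double v : x)
    mean += v;
  mean /= n;

  // Second pass: sum of squared deviations.
  double sumSqDev = 0.0;
  for (const double v : x)
  {
    const double d = v - mean;
    sumSqDev += d * d;
  }

  return sumSqDev / (n - 1.0);
}
```

For well-scaled data such as {1, 2, 3, 4} the two functions should agree closely. For data with a large offset, such as {1e9 + 4, 1e9 + 7, 1e9 + 13, 1e9 + 16}, the squared values are near 1e18, where adjacent doubles are far apart, so the one-pass result can drift from the true sample variance of 30 while the two-pass result does not.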
