[mlpack-svn] r15009 - mlpack/trunk/src/mlpack/methods/naive_bayes

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri May 3 18:54:34 EDT 2013


Author: rcurtin
Date: 2013-05-03 18:54:33 -0400 (Fri, 03 May 2013)
New Revision: 15009

Modified:
   mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
Log:
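Avoid inverting empty covariance matrices by skipping classes whose probability is zero.  Also fix a possible uninitialized memory issue by zero-initializing the probabilities vector instead of only setting its size.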


Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp	2013-05-03 22:35:21 UTC (rev 15008)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp	2013-05-03 22:54:33 UTC (rev 15009)
@@ -25,7 +25,7 @@
 
   // Update the variables according to the number of features and classes
   // present in the data.
-  probabilities.set_size(classes);
+  probabilities.zeros(classes);
   means.zeros(dimensionality, classes);
   variances.zeros(dimensionality, classes);
 
@@ -45,9 +45,12 @@
 
   for (size_t i = 0; i < classes; ++i)
   {
-    variances.col(i) -= (square(means.col(i)) / probabilities[i]);
-    means.col(i) /= probabilities[i];
-    variances.col(i) /= (probabilities[i] - 1);
+    if (probabilities[i] != 0)
+    {
+      variances.col(i) -= (square(means.col(i)) / probabilities[i]);
+      means.col(i) /= probabilities[i];
+      variances.col(i) /= (probabilities[i] - 1);
+    }
   }
 
   probabilities /= data.n_cols;
@@ -80,9 +83,12 @@
       // Use the log values to prevent floating point underflow.
       probs(i) = log(probabilities(i));
 
-      // Loop over every feature.
-      probs(i) += log(gmm::phi(data.unsafe_col(n), means.unsafe_col(i),
-          diagmat(variances.unsafe_col(i))));
+      // Loop over every feature, but avoid inverting empty matrices.
+      if (probabilities[i] != 0)
+      {
+        probs(i) += log(gmm::phi(data.unsafe_col(n), means.unsafe_col(i),
+            diagmat(variances.unsafe_col(i))));
+      }
     }
 
     // Find the index of the maximum value in tmp_vals.




More information about the mlpack-svn mailing list