[mlpack-svn] r16430 - mlpack/trunk/src/mlpack/methods/naive_bayes
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Apr 16 14:18:13 EDT 2014
Author: rcurtin
Date: Wed Apr 16 14:18:13 2014
New Revision: 16430
Log:
Change to two-pass algorithm suggested by Vahab in #344.
Modified:
mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp Wed Apr 16 14:18:13 2014
@@ -1,6 +1,7 @@
/**
* @file naive_bayes_classifier_impl.hpp
* @author Parikshit Ram (pram at cc.gatech.edu)
+ * @author Vahab Akbarzadeh (v.akbarzadeh at gmail.com)
*
* A Naive Bayes Classifier which parametrically estimates the distribution of
* the features. This classifier makes its predictions based on the assumption
@@ -59,25 +60,36 @@
}
else
{
- // Don't use incremental algorithm.
+ // Don't use incremental algorithm. This is a two-pass algorithm. It is
+ // possible to calculate the means and variances using a faster one-pass
+ // algorithm but there are some precision and stability issues. If this is
+ // too slow, it's an option to use the faster algorithm by default and then
+ // have this (and the incremental algorithm) be other options.
+
+ // Calculate the means.
for (size_t j = 0; j < data.n_cols; ++j)
{
const size_t label = labels[j];
++probabilities[label];
-
means.col(label) += data.col(j);
- variances.col(label) += square(data.col(j));
}
+ // Normalize means.
for (size_t i = 0; i < classes; ++i)
- {
- if (probabilities[i] != 0)
- {
- variances.col(i) -= (square(means.col(i)) / probabilities[i]);
+ if (probabilities[i] != 0.0)
means.col(i) /= probabilities[i];
- variances.col(i) /= (probabilities[i] - 1);
- }
+
+ // Calculate variances.
+ for (size_t j = 0; j < data.n_cols; ++j)
+ {
+ const size_t label = labels[j];
+ variances.col(label) += square(data.col(j) - means.col(label));
}
+
+ // Normalize variances.
+ for (size_t i = 0; i < classes; ++i)
+ if (probabilities[i] > 1)
+ variances.col(i) /= (probabilities[i] - 1);
}
// Ensure that the variances are invertible.
More information about the mlpack-svn mailing list