[mlpack-git] master, mlpack-1.0.x: Adapted patch from Vahab for #344: incremental algorithm for variance calculation in Naive Bayes classifier. This is optional and can be specified in the NaiveBayesClassifier constructor. (cf6bff8)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:46:23 EST 2015
Repository : https://github.com/mlpack/mlpack
On branches: master,mlpack-1.0.x
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit cf6bff87f85515109830c4743528fa37ff98eefa
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Apr 15 15:23:52 2014 +0000
Adapted patch from Vahab for #344: incremental algorithm for variance
calculation in Naive Bayes classifier. This is optional and can be specified in
the NaiveBayesClassifier constructor.
>---------------------------------------------------------------
cf6bff87f85515109830c4743528fa37ff98eefa
.../methods/naive_bayes/naive_bayes_classifier.hpp | 6 ++-
.../naive_bayes/naive_bayes_classifier_impl.hpp | 58 +++++++++++++++-------
src/mlpack/methods/naive_bayes/nbc_main.cpp | 17 +++++--
3 files changed, 59 insertions(+), 22 deletions(-)
diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
index 188ed75..9addfc6 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
@@ -69,10 +69,14 @@ class NaiveBayesClassifier
* @param data Training data points.
* @param labels Labels corresponding to training data points.
* @param classes Number of classes in this classifier.
+ * @param incrementalVariance If true, an incremental algorithm is used to
+ * calculate the variance; this can prevent loss of precision in some
+ * cases, but will be somewhat slower to calculate.
*/
NaiveBayesClassifier(const MatType& data,
const arma::Col<size_t>& labels,
- const size_t classes);
+ const size_t classes,
+ const bool incrementalVariance = false);
/**
* Given a bunch of data points, this function evaluates the class of each of
diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
index 08791ee..2fd92f4 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
@@ -22,9 +22,10 @@ template<typename MatType>
NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
const MatType& data,
const arma::Col<size_t>& labels,
- const size_t classes)
+ const size_t classes,
+ const bool incrementalVariance)
{
- size_t dimensionality = data.n_rows;
+ const size_t dimensionality = data.n_rows;
// Update the variables according to the number of features and classes
// present in the data.
@@ -37,30 +38,53 @@ NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
// Calculate the class probabilities as well as the sample mean and variance
// for each of the features with respect to each of the labels.
- for (size_t j = 0; j < data.n_cols; ++j)
+ if (incrementalVariance)
{
- const size_t label = labels[j];
- ++probabilities[label];
+ // Use incremental algorithm.
+ for (size_t j = 0; j < data.n_cols; ++j)
+ {
+ const size_t label = labels[j];
+ ++probabilities[label];
- means.col(label) += data.col(j);
- variances.col(label) += square(data.col(j));
- }
+ arma::vec delta = data.col(j) - means.col(label);
+ means.col(label) += delta / probabilities[label];
+ variances.col(label) += delta % (data.col(j) - means.col(label));
+ }
- for (size_t i = 0; i < classes; ++i)
+ for (size_t i = 0; i < classes; ++i)
+ {
+ if (probabilities[i] > 2)
+ variances.col(i) /= (probabilities[i] - 1);
+ }
+ }
+ else
{
- if (probabilities[i] != 0)
+ // Don't use incremental algorithm.
+ for (size_t j = 0; j < data.n_cols; ++j)
{
- variances.col(i) -= (square(means.col(i)) / probabilities[i]);
- means.col(i) /= probabilities[i];
- variances.col(i) /= (probabilities[i] - 1);
+ const size_t label = labels[j];
+ ++probabilities[label];
+
+ means.col(label) += data.col(j);
+ variances.col(label) += square(data.col(j));
}
- // Make sure variance is invertible.
- for (size_t j = 0; j < dimensionality; ++j)
- if (variances(j, i) == 0.0)
- variances(j, i) = 1e-50;
+ for (size_t i = 0; i < classes; ++i)
+ {
+ if (probabilities[i] != 0)
+ {
+ variances.col(i) -= (square(means.col(i)) / probabilities[i]);
+ means.col(i) /= probabilities[i];
+ variances.col(i) /= (probabilities[i] - 1);
+ }
+ }
}
+ // Ensure that the variances are invertible.
+ for (size_t i = 0; i < variances.n_elem; ++i)
+ if (variances[i] == 0.0)
+ variances[i] = 1e-50;
+
probabilities /= data.n_cols;
}
diff --git a/src/mlpack/methods/naive_bayes/nbc_main.cpp b/src/mlpack/methods/naive_bayes/nbc_main.cpp
index 423812a..391886c 100644
--- a/src/mlpack/methods/naive_bayes/nbc_main.cpp
+++ b/src/mlpack/methods/naive_bayes/nbc_main.cpp
@@ -14,11 +14,15 @@
PROGRAM_INFO("Parametric Naive Bayes Classifier",
"This program trains the Naive Bayes classifier on the given labeled "
"training set and then uses the trained classifier to classify the points "
- "in the given test set.\n"
- "\n"
+ "in the given test set."
+ "\n\n"
"Labels are expected to be the last row of the training set (--train_file),"
" but labels can also be passed in separately as their own file "
- "(--labels_file).");
+ "(--labels_file)."
+ "\n\n"
+ "The '--incremental_variance' option can be used to force the training to "
+ "use an incremental algorithm for calculating variance. This is slower, "
+ "but can help avoid loss of precision in some cases.");
PARAM_STRING_REQ("train_file", "A file containing the training set.", "t");
PARAM_STRING_REQ("test_file", "A file containing the test set.", "T");
@@ -27,6 +31,8 @@ PARAM_STRING("labels_file", "A file containing labels for the training set.",
"l", "");
PARAM_STRING("output", "The file in which the predicted labels for the test set"
" will be written.", "o", "output.csv");
+PARAM_FLAG("incremental_variance", "The variance of each class will be "
+ "calculated incrementally.", "I");
using namespace mlpack;
using namespace mlpack::naive_bayes;
@@ -80,9 +86,12 @@ int main(int argc, char* argv[])
<< "must be the same as training data (" << trainingData.n_rows - 1
<< ")!" << std::endl;
+ const bool incrementalVariance = CLI::HasParam("incremental_variance");
+
// Create and train the classifier.
Timer::Start("training");
- NaiveBayesClassifier<> nbc(trainingData, labels, mappings.n_elem);
+ NaiveBayesClassifier<> nbc(trainingData, labels, mappings.n_elem,
+ incrementalVariance);
Timer::Stop("training");
// Time the running of the Naive Bayes Classifier.
More information about the mlpack-git mailing list