[mlpack-git] master, mlpack-1.0.x: Adapted patch from Vahab for #344: incremental algorithm for variance calculation in Naive Bayes classifier. This is optional and can be specified in the NaiveBayesClassifier constructor. (cf6bff8)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:46:23 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit cf6bff87f85515109830c4743528fa37ff98eefa
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Apr 15 15:23:52 2014 +0000

    Adapted patch from Vahab for #344: incremental algorithm for variance
    calculation in Naive Bayes classifier.  This is optional and can be specified in
    the NaiveBayesClassifier constructor.


>---------------------------------------------------------------

cf6bff87f85515109830c4743528fa37ff98eefa
 .../methods/naive_bayes/naive_bayes_classifier.hpp |  6 ++-
 .../naive_bayes/naive_bayes_classifier_impl.hpp    | 58 +++++++++++++++-------
 src/mlpack/methods/naive_bayes/nbc_main.cpp        | 17 +++++--
 3 files changed, 59 insertions(+), 22 deletions(-)

diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
index 188ed75..9addfc6 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
@@ -69,10 +69,14 @@ class NaiveBayesClassifier
    * @param data Training data points.
    * @param labels Labels corresponding to training data points.
    * @param classes Number of classes in this classifier.
+   * @param incrementalVariance If true, an incremental algorithm is used to
+   *     calculate the variance; this can prevent loss of precision in some
+   *     cases, but will be somewhat slower to calculate.
    */
   NaiveBayesClassifier(const MatType& data,
                        const arma::Col<size_t>& labels,
-                       const size_t classes);
+                       const size_t classes,
+                       const bool incrementalVariance = false);
 
   /**
    * Given a bunch of data points, this function evaluates the class of each of
diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
index 08791ee..2fd92f4 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
@@ -22,9 +22,10 @@ template<typename MatType>
 NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
     const MatType& data,
     const arma::Col<size_t>& labels,
-    const size_t classes)
+    const size_t classes,
+    const bool incrementalVariance)
 {
-  size_t dimensionality = data.n_rows;
+  const size_t dimensionality = data.n_rows;
 
   // Update the variables according to the number of features and classes
   // present in the data.
@@ -37,30 +38,53 @@ NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
 
   // Calculate the class probabilities as well as the sample mean and variance
   // for each of the features with respect to each of the labels.
-  for (size_t j = 0; j < data.n_cols; ++j)
+  if (incrementalVariance)
   {
-    const size_t label = labels[j];
-    ++probabilities[label];
+    // Use incremental algorithm.
+    for (size_t j = 0; j < data.n_cols; ++j)
+    {
+      const size_t label = labels[j];
+      ++probabilities[label];
 
-    means.col(label) += data.col(j);
-    variances.col(label) += square(data.col(j));
-  }
+      arma::vec delta = data.col(j) - means.col(label);
+      means.col(label) += delta / probabilities[label];
+      variances.col(label) += delta % (data.col(j) - means.col(label));
+    }
 
-  for (size_t i = 0; i < classes; ++i)
+    for (size_t i = 0; i < classes; ++i)
+    {
+      if (probabilities[i] > 2)
+        variances.col(i) /= (probabilities[i] - 1);
+    }
+  }
+  else
   {
-    if (probabilities[i] != 0)
+    // Don't use incremental algorithm.
+    for (size_t j = 0; j < data.n_cols; ++j)
     {
-      variances.col(i) -= (square(means.col(i)) / probabilities[i]);
-      means.col(i) /= probabilities[i];
-      variances.col(i) /= (probabilities[i] - 1);
+      const size_t label = labels[j];
+      ++probabilities[label];
+
+      means.col(label) += data.col(j);
+      variances.col(label) += square(data.col(j));
     }
 
-    // Make sure variance is invertible.
-    for (size_t j = 0; j < dimensionality; ++j)
-      if (variances(j, i) == 0.0)
-        variances(j, i) = 1e-50;
+    for (size_t i = 0; i < classes; ++i)
+    {
+      if (probabilities[i] != 0)
+      {
+        variances.col(i) -= (square(means.col(i)) / probabilities[i]);
+        means.col(i) /= probabilities[i];
+        variances.col(i) /= (probabilities[i] - 1);
+      }
+    }
   }
 
+  // Ensure that the variances are invertible.
+  for (size_t i = 0; i < variances.n_elem; ++i)
+    if (variances[i] == 0.0)
+      variances[i] = 1e-50;
+
   probabilities /= data.n_cols;
 }
 
diff --git a/src/mlpack/methods/naive_bayes/nbc_main.cpp b/src/mlpack/methods/naive_bayes/nbc_main.cpp
index 423812a..391886c 100644
--- a/src/mlpack/methods/naive_bayes/nbc_main.cpp
+++ b/src/mlpack/methods/naive_bayes/nbc_main.cpp
@@ -14,11 +14,15 @@
 PROGRAM_INFO("Parametric Naive Bayes Classifier",
     "This program trains the Naive Bayes classifier on the given labeled "
     "training set and then uses the trained classifier to classify the points "
-    "in the given test set.\n"
-    "\n"
+    "in the given test set."
+    "\n\n"
     "Labels are expected to be the last row of the training set (--train_file),"
     " but labels can also be passed in separately as their own file "
-    "(--labels_file).");
+    "(--labels_file)."
+    "\n\n"
+    "The '--incremental_variance' option can be used to force the training to "
+    "use an incremental algorithm for calculating variance.  This is slower, "
+    "but can help avoid loss of precision in some cases.");
 
 PARAM_STRING_REQ("train_file", "A file containing the training set.", "t");
 PARAM_STRING_REQ("test_file", "A file containing the test set.", "T");
@@ -27,6 +31,8 @@ PARAM_STRING("labels_file", "A file containing labels for the training set.",
     "l", "");
 PARAM_STRING("output", "The file in which the predicted labels for the test set"
     " will be written.", "o", "output.csv");
+PARAM_FLAG("incremental_variance", "The variance of each class will be "
+    "calculated incrementally.", "I");
 
 using namespace mlpack;
 using namespace mlpack::naive_bayes;
@@ -80,9 +86,12 @@ int main(int argc, char* argv[])
         << "must be the same as training data (" << trainingData.n_rows - 1
         << ")!" << std::endl;
 
+  const bool incrementalVariance = CLI::HasParam("incremental_variance");
+
   // Create and train the classifier.
   Timer::Start("training");
-  NaiveBayesClassifier<> nbc(trainingData, labels, mappings.n_elem);
+  NaiveBayesClassifier<> nbc(trainingData, labels, mappings.n_elem,
+      incrementalVariance);
   Timer::Stop("training");
 
   // Time the running of the Naive Bayes Classifier.



More information about the mlpack-git mailing list