[mlpack-svn] r16427 - mlpack/trunk/src/mlpack/methods/naive_bayes

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Apr 15 11:23:52 EDT 2014


Author: rcurtin
Date: Tue Apr 15 11:23:52 2014
New Revision: 16427

Log:
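Adapted patch from Vahab for #344: add an incremental algorithm for variance
calculation to the Naive Bayes classifier.  The algorithm is optional (off by
default); it is enabled through the new incrementalVariance parameter of the
NaiveBayesClassifier constructor, or with --incremental_variance in nbc_main.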


Modified:
   mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
   mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
   mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp

Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp	Tue Apr 15 11:23:52 2014
@@ -69,10 +69,14 @@
    * @param data Training data points.
    * @param labels Labels corresponding to training data points.
    * @param classes Number of classes in this classifier.
+   * @param incrementalVariance If true, an incremental algorithm is used to
+   *     calculate the variance; this can prevent loss of precision in some
+   *     cases, but will be somewhat slower to calculate.
    */
   NaiveBayesClassifier(const MatType& data,
                        const arma::Col<size_t>& labels,
-                       const size_t classes);
+                       const size_t classes,
+                       const bool incrementalVariance = false);
 
   /**
    * Given a bunch of data points, this function evaluates the class of each of

Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp	Tue Apr 15 11:23:52 2014
@@ -22,9 +22,10 @@
 NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
     const MatType& data,
     const arma::Col<size_t>& labels,
-    const size_t classes)
+    const size_t classes,
+    const bool incrementalVariance)
 {
-  size_t dimensionality = data.n_rows;
+  const size_t dimensionality = data.n_rows;
 
   // Update the variables according to the number of features and classes
   // present in the data.
@@ -37,30 +38,53 @@
 
   // Calculate the class probabilities as well as the sample mean and variance
   // for each of the features with respect to each of the labels.
-  for (size_t j = 0; j < data.n_cols; ++j)
+  if (incrementalVariance)
   {
-    const size_t label = labels[j];
-    ++probabilities[label];
+    // Use incremental algorithm.
+    for (size_t j = 0; j < data.n_cols; ++j)
+    {
+      const size_t label = labels[j];
+      ++probabilities[label];
 
-    means.col(label) += data.col(j);
-    variances.col(label) += square(data.col(j));
-  }
+      arma::vec delta = data.col(j) - means.col(label);
+      means.col(label) += delta / probabilities[label];
+      variances.col(label) += delta % (data.col(j) - means.col(label));
+    }
 
-  for (size_t i = 0; i < classes; ++i)
+    for (size_t i = 0; i < classes; ++i)
+    {
+      if (probabilities[i] > 2)
+        variances.col(i) /= (probabilities[i] - 1);
+    }
+  }
+  else
   {
-    if (probabilities[i] != 0)
+    // Don't use incremental algorithm.
+    for (size_t j = 0; j < data.n_cols; ++j)
     {
-      variances.col(i) -= (square(means.col(i)) / probabilities[i]);
-      means.col(i) /= probabilities[i];
-      variances.col(i) /= (probabilities[i] - 1);
+      const size_t label = labels[j];
+      ++probabilities[label];
+
+      means.col(label) += data.col(j);
+      variances.col(label) += square(data.col(j));
     }
 
-    // Make sure variance is invertible.
-    for (size_t j = 0; j < dimensionality; ++j)
-      if (variances(j, i) == 0.0)
-        variances(j, i) = 1e-50;
+    for (size_t i = 0; i < classes; ++i)
+    {
+      if (probabilities[i] != 0)
+      {
+        variances.col(i) -= (square(means.col(i)) / probabilities[i]);
+        means.col(i) /= probabilities[i];
+        variances.col(i) /= (probabilities[i] - 1);
+      }
+    }
   }
 
+  // Ensure that the variances are invertible.
+  for (size_t i = 0; i < variances.n_elem; ++i)
+    if (variances[i] == 0.0)
+      variances[i] = 1e-50;
+
   probabilities /= data.n_cols;
 }
 

Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp	Tue Apr 15 11:23:52 2014
@@ -14,11 +14,15 @@
 PROGRAM_INFO("Parametric Naive Bayes Classifier",
     "This program trains the Naive Bayes classifier on the given labeled "
     "training set and then uses the trained classifier to classify the points "
-    "in the given test set.\n"
-    "\n"
+    "in the given test set."
+    "\n\n"
     "Labels are expected to be the last row of the training set (--train_file),"
     " but labels can also be passed in separately as their own file "
-    "(--labels_file).");
+    "(--labels_file)."
+    "\n\n"
+    "The '--incremental_variance' option can be used to force the training to "
+    "use an incremental algorithm for calculating variance.  This is slower, "
+    "but can help avoid loss of precision in some cases.");
 
 PARAM_STRING_REQ("train_file", "A file containing the training set.", "t");
 PARAM_STRING_REQ("test_file", "A file containing the test set.", "T");
@@ -27,6 +31,8 @@
     "l", "");
 PARAM_STRING("output", "The file in which the predicted labels for the test set"
     " will be written.", "o", "output.csv");
+PARAM_FLAG("incremental_variance", "The variance of each class will be "
+    "calculated incrementally.", "I");
 
 using namespace mlpack;
 using namespace mlpack::naive_bayes;
@@ -80,9 +86,12 @@
         << "must be the same as training data (" << trainingData.n_rows - 1
         << ")!" << std::endl;
 
+  const bool incrementalVariance = CLI::HasParam("incremental_variance");
+
   // Create and train the classifier.
   Timer::Start("training");
-  NaiveBayesClassifier<> nbc(trainingData, labels, mappings.n_elem);
+  NaiveBayesClassifier<> nbc(trainingData, labels, mappings.n_elem,
+      incrementalVariance);
   Timer::Stop("training");
 
   // Time the running of the Naive Bayes Classifier.



More information about the mlpack-svn mailing list