[mlpack-svn] r10755 - mlpack/trunk/src/mlpack/methods/naive_bayes

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Dec 13 14:35:00 EST 2011


Author: rcurtin
Date: 2011-12-13 14:34:59 -0500 (Tue, 13 Dec 2011)
New Revision: 10755

Added:
   mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
   mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
Removed:
   mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.cpp
   mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/naive_bayes/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp
Log:
Abstract-ize NBC to allow sparse matrices.


Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/CMakeLists.txt	2011-12-13 19:08:34 UTC (rev 10754)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/CMakeLists.txt	2011-12-13 19:34:59 UTC (rev 10755)
@@ -3,8 +3,8 @@
 # Define the files we need to compile.
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
-  simple_nbc.hpp
-  simple_nbc.cpp
+  naive_bayes_classifier.hpp
+  naive_bayes_classifier_impl.hpp
 )
 
 # Add directory name to sources.

Copied: mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp (from rev 10737, mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.hpp)
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp	2011-12-13 19:34:59 UTC (rev 10755)
@@ -0,0 +1,113 @@
+/**
+ * @file naive_bayes_classifier.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * A Naive Bayes Classifier which parametrically estimates the distribution of
+ * the features.  It is assumed that the features have been sampled from a
+ * Gaussian PDF.
+ */
+#ifndef __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP
+#define __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/methods/gmm/phi.hpp>
+
+namespace mlpack {
+namespace naive_bayes {
+
+/**
+ * A Gaussian Naive Bayes classifier.  Class labels are assumed to be
+ * non-negative integers 0, 1, ..., (classes - 1).
+ *
+ * Training computes, for every class y_j, the sample mean and sample variance
+ * of each feature X_i over the training points labelled y_j (so that
+ * P(X_i = x_i | Y = y_j) is modelled as a Gaussian), along with the class
+ * prior probabilities P(Y = y_j).
+ *
+ * To classify a data point (x_1, x_2, ..., x_n), it computes
+ *
+ *   arg max_y (P(Y = y) * P(X_1 = x_1 | Y = y) * ... * P(X_n = x_n | Y = y)).
+ *
+ * Example use:
+ *
+ * @code
+ * arma::mat trainingData; // Each column is a point; last row is its label.
+ * arma::mat testData;     // Each column is a point (features only).
+ * arma::Col<size_t> results;
+ *
+ * NaiveBayesClassifier<> nbc(trainingData, 3); // Train with 3 classes.
+ * nbc.Classify(testData, results);
+ * @endcode
+ *
+ * @tparam MatType Type of the data matrix (defaults to arma::mat).
+ */
+template<typename MatType = arma::mat>
+class NaiveBayesClassifier
+{
+ private:
+  //! Sample mean of each feature for each class (one column per class).
+  MatType means;
+
+  //! Sample variance of each feature for each class (one column per class).
+  MatType variances;
+
+  //! Prior probability of each class.
+  arma::vec probabilities;
+
+ public:
+  /**
+   * Train the classifier on labelled data by computing, for each class, the
+   * sample mean and variance of every feature, and the class priors.
+   *
+   * @param data Training matrix.  Each column is one point; the last row holds
+   *     that point's integer class label and the rows above it are features.
+   * @param classes Number of classes; every label must be less than this.
+   */
+  NaiveBayesClassifier(const MatType& data, const size_t classes);
+
+  /**
+   * Classify each point in 'data', storing the index of the most probable
+   * class for column n of 'data' in results[n].
+   *
+   * @param data Points to classify, one per column.  They must have the same
+   *     number of feature rows as the training data (no label row).
+   * @param results Output vector of predicted class indices; it is resized to
+   *     data.n_cols.
+   */
+  void Classify(const MatType& data, arma::Col<size_t>& results);
+
+  //! Get the sample means for each class.
+  const MatType& Means() const { return means; }
+  //! Modify the sample means for each class.
+  MatType& Means() { return means; }
+
+  //! Get the sample variances for each class.
+  const MatType& Variances() const { return variances; }
+  //! Modify the sample variances for each class.
+  MatType& Variances() { return variances; }
+
+  //! Get the prior probabilities for each class.
+  const arma::vec& Probabilities() const { return probabilities; }
+  //! Modify the prior probabilities for each class.
+  arma::vec& Probabilities() { return probabilities; }
+};
+
+}; // namespace naive_bayes
+}; // namespace mlpack
+
+// Include implementation.
+#include "naive_bayes_classifier_impl.hpp"
+
+#endif

Copied: mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp (from rev 10737, mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.cpp)
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp	2011-12-13 19:34:59 UTC (rev 10755)
@@ -0,0 +1,101 @@
+/**
+ * @file naive_bayes_classifier_impl.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * A Naive Bayes Classifier which parametrically estimates the distribution of
+ * the features.  It is assumed that the features have been sampled from a
+ * Gaussian PDF.
+ */
+#ifndef __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_IMPL_HPP
+#define __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_IMPL_HPP
+
+#include <mlpack/core.hpp>
+
+// In case it hasn't been included already.
+#include "naive_bayes_classifier.hpp"
+
+namespace mlpack {
+namespace naive_bayes {
+
+/**
+ * Train the classifier.  The last row of 'data' holds each point's class
+ * label and the rows above it are the features.  For every class this
+ * accumulates the per-feature sum and sum of squares, converts them into the
+ * sample mean and unbiased sample variance, and finally normalizes the class
+ * counts into prior probabilities.
+ *
+ * NOTE(review): a class with fewer than two training points gets a
+ * non-finite variance, because the unbiased estimator divides by
+ * (count - 1); a class with no points also gets a non-finite mean.
+ */
+template<typename MatType>
+NaiveBayesClassifier<MatType>::NaiveBayesClassifier(const MatType& data,
+                                                    const size_t classes)
+{
+  // At least one feature row is required besides the label row; otherwise
+  // 'data.n_rows - 1' and the span below would underflow.
+  Log::Assert(data.n_rows >= 2);
+  const size_t dimensionality = data.n_rows - 1;
+
+  // Size the model according to the number of features and classes.  The
+  // class counts must start at zero because they are accumulated below
+  // (set_size() would leave them uninitialized).
+  probabilities.zeros(classes);
+  means.zeros(dimensionality, classes);
+  variances.zeros(dimensionality, classes);
+
+  Log::Info << "Training Naive Bayes classifier on " << data.n_cols
+      << " examples with " << dimensionality << " features each." << std::endl;
+
+  // Accumulate, per class, the number of points, the per-feature sum, and the
+  // per-feature sum of squares.
+  for (size_t j = 0; j < data.n_cols; ++j)
+  {
+    const size_t label = static_cast<size_t>(data(dimensionality, j));
+
+    // operator() is bounds-checked by Armadillo (unless ARMA_NO_DEBUG is
+    // defined), unlike operator[], so an out-of-range label is caught here
+    // instead of writing past the end of 'probabilities'.
+    ++probabilities(label);
+
+    means.col(label) += data(arma::span(0, dimensionality - 1), j);
+    variances.col(label) += square(data(arma::span(0, dimensionality - 1), j));
+  }
+
+  // Convert the sums into the sample mean and unbiased sample variance:
+  //   var = (sum(x^2) - sum(x)^2 / n) / (n - 1).
+  // 'means' still holds sum(x) when the variance correction is applied.
+  for (size_t i = 0; i < classes; ++i)
+  {
+    variances.col(i) -= (square(means.col(i)) / probabilities[i]);
+    means.col(i) /= probabilities[i];
+    variances.col(i) /= (probabilities[i] - 1);
+  }
+
+  // Turn the class counts into prior probabilities.
+  probabilities /= data.n_cols;
+}
+
+/**
+ * Classify each column of 'data', writing the index of the class with the
+ * highest score into the matching element of 'results'.
+ *
+ * Scores are computed in log space (log prior + log likelihood) to avoid
+ * floating-point underflow.  The likelihood of a point under class i is
+ * evaluated by gmm::phi with mean means.col(i) and diagonal covariance
+ * diagmat(variances.col(i)).
+ */
+template<typename MatType>
+void NaiveBayesClassifier<MatType>::Classify(const MatType& data,
+                                             arma::Col<size_t>& results)
+{
+  // Check that the number of features in the test data is the same as in the
+  // training data.
+  Log::Assert(data.n_rows == means.n_rows);
+
+  // The log priors do not depend on the test point, so compute them once.
+  const arma::vec logPriors = arma::log(probabilities);
+
+  // Per-class scores for the current point.
+  arma::vec probs(means.n_cols);
+
+  results.zeros(data.n_cols);
+
+  Log::Info << "Running Naive Bayes classifier on " << data.n_cols
+      << " data points with " << data.n_rows << " features each." << std::endl;
+
+  // Loop over every test point.
+  for (size_t n = 0; n < data.n_cols; n++)
+  {
+    // Score every class: log prior plus log likelihood of the point.
+    for (size_t i = 0; i < means.n_cols; i++)
+    {
+      probs(i) = logPriors(i) + log(gmm::phi(data.unsafe_col(n),
+          means.unsafe_col(i), diagmat(variances.unsafe_col(i))));
+    }
+
+    // Pick the highest-scoring class.  max() reports the index through an
+    // arma::uword, so use a variable of exactly that type instead of
+    // reinterpreting a size_t (whose width may differ from arma::uword).
+    arma::uword maxIndex = 0;
+    probs.max(maxIndex);
+
+    results[n] = maxIndex;
+  }
+}
+
+}; // namespace naive_bayes
+}; // namespace mlpack
+
+#endif

Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp	2011-12-13 19:08:34 UTC (rev 10754)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp	2011-12-13 19:34:59 UTC (rev 10755)
@@ -16,7 +16,7 @@
  * This is the number of classes present in the training data.
  *
  * --test
- * This file contains the data points which the trained classifier would
+ * This file contains the data points which the trained classifier should
  * classify.
  *
  * --output
@@ -25,14 +25,14 @@
  */
 #include <mlpack/core.hpp>
 
-#include "simple_nbc.hpp"
+#include "naive_bayes_classifier.hpp"
 
-PARAM_INT_REQ("classes", "The number of classes present in the data.", "C")
+PARAM_INT_REQ("classes", "The number of classes present in the data.", "c")
 
-PARAM_STRING_REQ("train", "A file containing the training set", "R");
+PARAM_STRING_REQ("train", "A file containing the training set", "t");
 PARAM_STRING_REQ("test", "A file containing the test set", "T");
 PARAM_STRING("output", "The file in which the output of the test would "
-    "be written, defaults to 'output.csv')", "O", "output.csv");
+    "be written, defaults to 'output.csv')", "o", "output.csv");
 
 PROGRAM_INFO("Parametric Naive Bayes", "This program test drives the Parametric"
     " Naive Bayes Classifier assuming that the features are sampled from a "
@@ -59,7 +59,7 @@
 
   // Create and train the classifier.
   Timers::StartTimer("training");
-  SimpleNaiveBayesClassifier nbc = SimpleNaiveBayesClassifier(training_data,
+  NaiveBayesClassifier nbc = NaiveBayesClassifier(training_data,
       number_of_classes_);
   Timers::StopTimer("training");
 

Deleted: mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.cpp	2011-12-13 19:08:34 UTC (rev 10754)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.cpp	2011-12-13 19:34:59 UTC (rev 10755)
@@ -1,123 +0,0 @@
-/**
- * @file simple_nbc.cpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- *
- * A Naive Bayes Classifier which parametrically estimates the distribution of
- * the features.  It is assumed that the features have been sampled from a
- * Gaussian PDF.
- */
-#include <mlpack/core.hpp>
-
-#include "simple_nbc.hpp"
-
-namespace mlpack {
-namespace naive_bayes {
-
-SimpleNaiveBayesClassifier::SimpleNaiveBayesClassifier(const arma::mat& data, 
-    size_t classes) : number_of_classes_(classes)
-{
-  size_t number_examples = data.n_cols;
-  size_t number_features = data.n_rows - 1;
-
-  arma::vec feature_sum, feature_sum_squared;
-  feature_sum.zeros(number_features);
-  feature_sum_squared.zeros(number_features);
-
-  // Update the variables, private and local, according to the number of
-  // features and classes present in the data.
-  class_probabilities_.set_size(number_of_classes_);
-  means_.set_size(number_features,number_of_classes_);
-  variances_.set_size(number_features,number_of_classes_);
-
-  Log::Info << number_examples << " examples with " << number_features
-      << " features each" << std::endl;
-
-  // Calculate the class probabilities as well as the sample mean and variance
-  // for each of the features with respect to each of the labels.
-  for (size_t i = 0; i < number_of_classes_; i++ )
-  {
-    size_t number_of_occurrences = 0;
-    for (size_t j = 0; j < number_examples; j++)
-    {
-      size_t flag = (size_t)  data(number_features, j);
-      if (i == flag)
-      {
-        ++number_of_occurrences;
-        for (size_t k = 0; k < number_features; k++)
-        {
-          double tmp = data(k, j);
-          feature_sum(k) += tmp;
-          feature_sum_squared(k) += tmp*tmp;
-        }
-      }
-    }
-
-    class_probabilities_[i] = (double) number_of_occurrences
-        / (double) number_examples;
-
-    for (size_t k = 0; k < number_features; k++)
-    {
-      double sum = feature_sum(k);
-      double sum_squared = feature_sum_squared(k);
-
-      means_(k, i) = (sum / number_of_occurrences);
-      variances_(k, i) = (sum_squared - (sum * sum / number_of_occurrences))
-          / (number_of_occurrences - 1);
-    }
-
-    // Reset the summations to zero for the next iteration
-    feature_sum.zeros(number_features);
-    feature_sum_squared.zeros(number_features);
-  }
-}
-
-void SimpleNaiveBayesClassifier::Classify(const arma::mat& test_data,
-                                          arma::vec& results)
-{
-  // Check that the number of features in the test data is same as in the
-  // training data.
-  Log::Assert(test_data.n_rows - 1 == means_.n_rows);
-
-  arma::vec tmp_vals(number_of_classes_);
-  size_t number_features = test_data.n_rows - 1;
-
-  results.zeros(test_data.n_cols);
-
-  Log::Info << test_data.n_cols << " test cases with " << number_features
-      << " features each" << std::endl;
-
-  // Calculate the joint probability for each of the data points for each of the
-  // classes.
-
-  // Loop over every test case.
-  for (size_t n = 0; n < test_data.n_cols; n++)
-  {
-    // Loop over every class.
-    for (size_t i = 0; i < number_of_classes_; i++)
-    {
-      // Use the log values to prevent floating point underflow.
-      tmp_vals(i) = log(class_probabilities_(i));
-
-      // Loop over every feature.
-      for (size_t j = 0; j < number_features; j++)
-      {
-        tmp_vals(i) += log(gmm::phi(test_data(j, n), means_(j, i),
-            variances_(j, i)));
-      }
-    }
-
-    // Find the index of the maximum value in tmp_vals.
-    size_t max = 0;
-    for (size_t k = 0; k < number_of_classes_; k++)
-    {
-      if (tmp_vals(max) < tmp_vals(k))
-        max = k;
-    }
-    results(n) = max;
-  }
-
-  return;
-}
-
-}; // namespace naive_bayes
-}; // namespace mlpack

Deleted: mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.hpp	2011-12-13 19:08:34 UTC (rev 10754)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.hpp	2011-12-13 19:34:59 UTC (rev 10755)
@@ -1,101 +0,0 @@
-/**
- * @file simple_nbc.hpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- *
- * A Naive Bayes Classifier which parametrically estimates the distribution of
- * the features.  It is assumed that the features have been sampled from a
- * Gaussian PDF.
- */
-#ifndef __MLPACK_METHODS_NBC_SIMPLE_NBC_HPP
-#define __MLPACK_METHODS_NBC_SIMPLE_NBC_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/methods/gmm/phi.hpp>
-
-namespace mlpack {
-namespace naive_bayes {
-
-/**
- * A classification class. The class labels are assumed
- * to be positive integers - 0,1,2,....
- *
- * This class trains on the data by calculating the
- * sample mean and variance of the features with
- * respect to each of the labels, and also the class
- * probabilities.
- *
- * Mathematically, it computes P(X_i = x_i | Y = y_j)
- * for each feature X_i for each of the labels y_j.
- * Alongwith this, it also computes the classs probabilities
- * P( Y = y_j)
- *
- * For classifying a data point (x_1, x_2, ..., x_n),
- * it computes the following:
- * arg max_y(P(Y = y)*P(X_1 = x_1 | Y = y) * ... * P(X_n = x_n | Y = y))
- *
- * Example use:
- *
- * @code
- * SimpleNaiveBayesClassifier nbc;
- * arma::mat training_data, testing_data;
- * datanode *nbc_module = fx_submodule(NULL,"nbc","nbc");
- * arma::vec results;
- *
- * nbc.InitTrain(training_data, nbc_module);
- * nbc.Classify(testing_data, &results);
- * @endcode
- */
-class SimpleNaiveBayesClassifier
-{
- public:
-  //! Sample mean for each class.
-  arma::mat means_;
-
-  //! Sample variances for each class.
-  arma::mat variances_;
-
-  //! Class probabilities.
-  arma::vec class_probabilities_;
-
-  //! The number of classes present.
-  size_t number_of_classes_;
-
-  /**
-   * Initializes the classifier as per the input and then trains it
-   * by calculating the sample mean and variances
-   *
-   * Example use:
-   * @code
-   * arma::mat training_data, testing_data;
-   * datanode nbc_module = fx_submodule(NULL,"nbc","nbc");
-   * ....
-   * SimpleNaiveBayesClassifier nbc(training_data, nbc_module);
-   * @endcode
-   */
-  SimpleNaiveBayesClassifier(const arma::mat& data, size_t classes);
-
-  /**
-   * Default constructor, you need to use the other one.
-   */
-  SimpleNaiveBayesClassifier();
-
-  ~SimpleNaiveBayesClassifier() { }
-
-  /**
-   * Given a bunch of data points, this function evaluates the class
-   * of each of those data points, and puts it in the vector 'results'
-   *
-   * @code
-   * arma::mat test_data; // each column is a test point
-   * arma::vec results;
-   * ...
-   * nbc.Classify(test_data, &results);
-   * @endcode
-   */
-  void Classify(const arma::mat& test_data, arma::vec& results);
-};
-
-}; // namespace naive_bayes
-}; // namespace mlpack
-
-#endif




More information about the mlpack-svn mailing list