[mlpack-git] master: Refactor to add Train() and empty constructor. (4a469b2)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Sep 29 09:33:39 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/cbeb3ea17262b7c5115247dc217e316c529249b7...f85a9b22f3ce56143943a2488c05c2810d6b2bf3

>---------------------------------------------------------------

commit 4a469b2733c122870274ab813528182157b12211
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Sep 25 22:53:20 2015 +0000

    Refactor to add Train() and empty constructor.


>---------------------------------------------------------------

4a469b2733c122870274ab813528182157b12211
 .../methods/naive_bayes/naive_bayes_classifier.hpp | 71 +++++++++++++++-----
 .../naive_bayes/naive_bayes_classifier_impl.hpp    | 78 ++++++++++++++++++----
 2 files changed, 120 insertions(+), 29 deletions(-)

diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
index 21813a2..51fdf4f 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
@@ -41,22 +41,10 @@ namespace naive_bayes /** The Naive Bayes Classifier. */ {
 template<typename MatType = arma::mat>
 class NaiveBayesClassifier
 {
- private:
-  //! Sample mean for each class.
-  MatType means;
-
-  //! Sample variances for each class.
-  MatType variances;
-
-  //! Class probabilities.
-  arma::vec probabilities;
-
  public:
   /**
    * Initializes the classifier as per the input and then trains it by
-   * calculating the sample mean and variances.  The input data is expected to
-   * have integer labels as the last row (starting with 0 and not greater than
-   * the number of classes).
+   * calculating the sample mean and variances.
    *
    * Example use:
    * @code
@@ -73,11 +61,50 @@ class NaiveBayesClassifier
    *     cases, but will be somewhat slower to calculate.
    */
   NaiveBayesClassifier(const MatType& data,
-                       const arma::Col<size_t>& labels,
+                       const arma::Row<size_t>& labels,
                        const size_t classes,
                        const bool incrementalVariance = false);
 
   /**
+   * Initialize the Naive Bayes classifier without performing training.  All of
+   * the parameters of the model will be initialized to zero.  Be sure to use
+   * Train() before calling Classify(), otherwise the results may be
+   * meaningless.
+   */
+  NaiveBayesClassifier(const size_t dimensionality,
+                       const size_t classes);
+
+  /**
+   * Train the Naive Bayes classifier on the given dataset.  If the incremental
+   * algorithm is used, the current model is used as a starting point (this is
+   * the default).  If the incremental algorithm is not used, then the current
+   * model is ignored and the new model will be trained only on the given data.
+   * Note that even if the incremental algorithm is not used, the data must have
+   * the same dimensionality and number of classes that the model was
+   * initialized with.  If you want to change the dimensionality or number of
+   * classes, either re-initialize or call Means(), Variances(), and
+   * Probabilities() individually to set them to the right size.
+   *
+   * @param data The dataset to train on.
+   * @param labels Labels for each point in the dataset.
+   * @param incremental Whether to use the incremental training algorithm.
+   */
+  void Train(const MatType& data,
+             const arma::Row<size_t>& labels,
+             const bool incremental = true);
+
+  /**
+   * Train the Naive Bayes classifier on the given point.  This will use the
+   * incremental algorithm for updating the model parameters.  The point must
+   * have the same dimensionality as the existing model parameters.
+   *
+   * @param point Data point to train on.
+   * @param label Label of data point.
+   */
+  template<typename VecType>
+  void Train(const VecType& point, const size_t label);
+
+  /**
    * Given a bunch of data points, this function evaluates the class of each of
    * those data points, and puts it in the vector 'results'.
    *
@@ -91,7 +118,7 @@ class NaiveBayesClassifier
    * @param data List of data points.
    * @param results Vector that class predictions will be placed into.
    */
-  void Classify(const MatType& data, arma::Col<size_t>& results);
+  void Classify(const MatType& data, arma::Row<size_t>& results);
 
   //! Get the sample means for each class.
   const MatType& Means() const { return means; }
@@ -107,10 +134,20 @@ class NaiveBayesClassifier
   const arma::vec& Probabilities() const { return probabilities; }
   //! Modify the prior probabilities for each class.
   arma::vec& Probabilities() { return probabilities; }
+
+ private:
+  //! Sample mean for each class.
+  MatType means;
+  //! Sample variances for each class.
+  MatType variances;
+  //! Class probabilities.
+  arma::vec probabilities;
+  //! Number of training points seen so far.
+  size_t trainingPoints;
 };
 
-}; // namespace naive_bayes
-}; // namespace mlpack
+} // namespace naive_bayes
+} // namespace mlpack
 
 // Include implementation.
 #include "naive_bayes_classifier_impl.hpp"
diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
index 5cd4fb9..9dbd7c5 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
@@ -22,26 +22,55 @@ namespace naive_bayes {
 template<typename MatType>
 NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
     const MatType& data,
-    const arma::Col<size_t>& labels,
+    const arma::Row<size_t>& labels,
     const size_t classes,
-    const bool incrementalVariance)
+    const bool incremental) :
+    trainingPoints(0) // Set when we call Train().
 {
   const size_t dimensionality = data.n_rows;
 
-  // Update the variables according to the number of features and classes
-  // present in the data.
+  // Perform training, after initializing the model to 0 (that is, if Train()
+  // won't do that for us, which it won't if we're using the incremental
+  // algorithm).
+  if (incremental)
+  {
+    probabilities.zeros(classes);
+    means.zeros(dimensionality, classes);
+    variances.zeros(dimensionality, classes);
+  }
+  else
+  {
+    probabilities.set_size(classes);
+    means.set_size(dimensionality, classes);
+    variances.set_size(dimensionality, classes);
+  }
+  Train(data, labels, incremental);
+}
+
+template<typename MatType>
+NaiveBayesClassifier<MatType>::NaiveBayesClassifier(const size_t dimensionality,
+                                                    const size_t classes) :
+    trainingPoints(0)
+{
+  // Initialize model to 0.
   probabilities.zeros(classes);
   means.zeros(dimensionality, classes);
   variances.zeros(dimensionality, classes);
+}
 
-  Log::Info << "Training Naive Bayes classifier on " << data.n_cols
-      << " examples with " << dimensionality << " features each." << std::endl;
-
+template<typename MatType>
+void NaiveBayesClassifier<MatType>::Train(const MatType& data,
+                                          const arma::Row<size_t>& labels,
+                                          const bool incremental)
+{
   // Calculate the class probabilities as well as the sample mean and variance
   // for each of the features with respect to each of the labels.
-  if (incrementalVariance)
+  if (incremental)
   {
     // Use incremental algorithm.
+    // First, de-normalize probabilities.
+    probabilities *= trainingPoints;
+
     for (size_t j = 0; j < data.n_cols; ++j)
     {
       const size_t label = labels[j];
@@ -52,7 +81,7 @@ NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
       variances.col(label) += delta % (data.col(j) - means.col(label));
     }
 
-    for (size_t i = 0; i < classes; ++i)
+    for (size_t i = 0; i < probabilities.n_elem; ++i)
     {
       if (probabilities[i] > 2)
         variances.col(i) /= (probabilities[i] - 1);
@@ -60,6 +89,11 @@ NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
   }
   else
   {
+    // Set all parameters to zero.
+    probabilities.zeros();
+    means.zeros();
+    variances.zeros();
+
     // Don't use incremental algorithm.  This is a two-pass algorithm.  It is
     // possible to calculate the means and variances using a faster one-pass
     // algorithm but there are some precision and stability issues.  If this is
@@ -75,7 +109,7 @@ NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
     }
 
     // Normalize means.
-    for (size_t i = 0; i < classes; ++i)
+    for (size_t i = 0; i < probabilities.n_elem; ++i)
       if (probabilities[i] != 0.0)
         means.col(i) /= probabilities[i];
 
@@ -87,7 +121,7 @@ NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
     }
 
     // Normalize variances.
-    for (size_t i = 0; i < classes; ++i)
+    for (size_t i = 0; i < probabilities.n_elem; ++i)
       if (probabilities[i] > 1)
         variances.col(i) /= (probabilities[i] - 1);
   }
@@ -98,11 +132,31 @@ NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
       variances[i] = 1e-50;
 
   probabilities /= data.n_cols;
+  trainingPoints += data.n_cols;
+}
+
+template<typename MatType>
+template<typename VecType>
+void NaiveBayesClassifier<MatType>::Train(const VecType& point,
+                                          const size_t label)
+{
+  // We must use the incremental algorithm here.
+  probabilities *= trainingPoints;
+  probabilities[label]++;
+
+  arma::vec delta = point - means.col(label);
+  means.col(label) += delta / probabilities[label];
+  variances.col(label) *= (probabilities[label] - 1) / probabilities[label];
+  variances.col(label) += (1 / probabilities[label]) *
+      (delta % (point - means.col(label)));
+
+  trainingPoints++;
+  probabilities /= trainingPoints;
 }
 
 template<typename MatType>
 void NaiveBayesClassifier<MatType>::Classify(const MatType& data,
-                                             arma::Col<size_t>& results)
+                                             arma::Row<size_t>& results)
 {
   // Check that the number of features in the test data is same as in the
   // training data.



More information about the mlpack-git mailing list