[mlpack-svn] r10808 - mlpack/trunk/src/mlpack/methods/naive_bayes

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Dec 14 17:08:08 EST 2011


Author: rcurtin
Date: 2011-12-14 17:08:08 -0500 (Wed, 14 Dec 2011)
New Revision: 10808

Removed:
   mlpack/trunk/src/mlpack/methods/naive_bayes/README.txt
Modified:
   mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
   mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp
Log:
Refactor and comment NaiveBayes a little better.


Deleted: mlpack/trunk/src/mlpack/methods/naive_bayes/README.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/README.txt	2011-12-14 22:07:06 UTC (rev 10807)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/README.txt	2011-12-14 22:08:08 UTC (rev 10808)
@@ -1,28 +0,0 @@
-Author : Parikshit Ram (pram at cc.gatech.edu)
-
-The files are the following:
-1. nbc_main.cc - this is the main which creates an object of the class SimpleNaiveBayesClassifier, trains it, tests it and outputs the results. The executable formed is called "nbc".
- - the parameters taken in by main are the following:
-   --train : the file that contains the training data, the last column being the class of the data point
-   --nbc/classes : the number of classes the data provided has been classified into
-   --test : this file contains the testing data, this still contains its actual labels on the last column, but it is not used.
-   --output : the file into which you want the output to be written into, defaults to "output.csv"
-
-2. simple_nbc.h - this is the file that contains the definition of the class SimpleNaiveBayesClassifier. The rest of the details are present in the file itself.
-
-3. phi.h - this contains the functions that calculate the value of the univariate and multivariate Gaussian probability density function
-
-4. test_simple_nbc_main.cc - this file contains the class which tests the class SimpleNaiveBayesClassifier. The executable formed is "nbc_test".
- - this takes no parameters
-
-5. the .arff files, whose use has been described above.
-
--> An example run would the following:
-fl-build nbc_main
-./nbc_main --train=trainSet.arff  --nbc/classes=2 --test=testSet.arff --output=output_example.csv
-
--> An example run of the testing class would be the following:
-fl-build test_simple_nbc_main
-./nbc_test
-
-Note: You don't need to give any parameters for testing because it will use the defaults. It requires the files be in the same directory as you run the executable from.

Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp	2011-12-14 22:07:06 UTC (rev 10807)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp	2011-12-14 22:08:08 UTC (rev 10808)
@@ -13,36 +13,30 @@
 #include <mlpack/methods/gmm/phi.hpp>
 
 namespace mlpack {
-namespace naive_bayes {
+namespace naive_bayes /** The Naive Bayes Classifier. */ {
 
 /**
- * A classification class. The class labels are assumed
- * to be positive integers - 0,1,2,....
+ * The simple Naive Bayes classifier.  This class trains on the data by
+ * calculating the sample mean and variance of the features with respect to each
+ * of the labels, and also the class probabilities.  The class labels are
+ * assumed to be non-negative integers (starting with 0), and are expected to
+ * be the last row of the data input to the constructor.
  *
- * This class trains on the data by calculating the
- * sample mean and variance of the features with
- * respect to each of the labels, and also the class
- * probabilities.
+ * Mathematically, it computes P(X_i = x_i | Y = y_j) for each feature X_i for
+ * each of the labels y_j.  Along with this, it also computes the class
+ * probabilities P(Y = y_j).
  *
- * Mathematically, it computes P(X_i = x_i | Y = y_j)
- * for each feature X_i for each of the labels y_j.
- * Alongwith this, it also computes the classs probabilities
- * P( Y = y_j)
- *
- * For classifying a data point (x_1, x_2, ..., x_n),
- * it computes the following:
+ * For classifying a data point (x_1, x_2, ..., x_n), it computes the following:
  * arg max_y(P(Y = y)*P(X_1 = x_1 | Y = y) * ... * P(X_n = x_n | Y = y))
  *
  * Example use:
  *
  * @code
- * NaiveBayesClassifier nbc;
- * arma::mat training_data, testing_data;
- * datanode *nbc_module = fx_submodule(NULL,"nbc","nbc");
+ * extern arma::mat training_data, testing_data;
+ * NaiveBayesClassifier<> nbc(training_data, 5);
  * arma::vec results;
  *
- * nbc.InitTrain(training_data, nbc_module);
- * nbc.Classify(testing_data, &results);
+ * nbc.Classify(testing_data, results);
  * @endcode
  */
 template<typename MatType = arma::mat>
@@ -60,31 +54,35 @@
 
  public:
   /**
-   * Initializes the classifier as per the input and then trains it
-   * by calculating the sample mean and variances
+   * Initializes the classifier as per the input and then trains it by
+   * calculating the sample mean and variances.  The input data is expected to
+   * have integer labels as the last row (starting with 0 and less than the
+   * number of classes).
    *
    * Example use:
    * @code
-   * arma::mat training_data, testing_data;
-   * datanode nbc_module = fx_submodule(NULL,"nbc","nbc");
-   * ....
-   * NaiveBayesClassifier nbc(training_data, nbc_module);
+   * extern arma::mat training_data, testing_data;
+   * NaiveBayesClassifier<> nbc(training_data, 5);
    * @endcode
+   *
+   * @param data Sample data points; the last row should be labels.
+   * @param classes Number of classes in this classifier.
    */
   NaiveBayesClassifier(const MatType& data, const size_t classes);
 
-  ~NaiveBayesClassifier() { }
-
   /**
-   * Given a bunch of data points, this function evaluates the class
-   * of each of those data points, and puts it in the vector 'results'
+   * Given a bunch of data points, this function evaluates the class of each of
+   * those data points, and puts it in the vector 'results'.
    *
    * @code
    * arma::mat test_data; // each column is a test point
-   * arma::vec results;
+   * arma::Col<size_t> results;
    * ...
    * nbc.Classify(test_data, &results);
    * @endcode
+   *
+   * @param data List of data points.
+   * @param results Vector that class predictions will be placed into.
    */
   void Classify(const MatType& data, arma::Col<size_t>& results);
 

Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp	2011-12-14 22:07:06 UTC (rev 10807)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp	2011-12-14 22:08:08 UTC (rev 10808)
@@ -6,71 +6,80 @@
  *
  * This classifier does parametric naive bayes classification assuming that the
  * features are sampled from a Gaussian distribution.
- *
- * PARAMETERS TO BE INPUT:
- *
- * --train
- * This is the file that contains the training data.
- *
- * --classes
- * This is the number of classes present in the training data.
- *
- * --test
- * This file contains the data points which the trained classifier should
- * classify.
- *
- * --output
- * This file will contain the classes to which the corresponding data points in
- * the testing data.
  */
 #include <mlpack/core.hpp>
 
 #include "naive_bayes_classifier.hpp"
 
-PARAM_INT_REQ("classes", "The number of classes present in the data.", "c")
+PROGRAM_INFO("Parametric Naive Bayes Classifier",
+    "This program trains the Naive Bayes classifier on the given labeled "
+    "training set and then uses the trained classifier to classify the points "
+    "in the given test set.\n"
+    "\n"
+    "Labels are expected to be the last row of the training set (--train_file),"
+    " but labels can also be passed in separately as their own file "
+    "(--labels_file).");
 
-PARAM_STRING_REQ("train", "A file containing the training set", "t");
-PARAM_STRING_REQ("test", "A file containing the test set", "T");
+PARAM_STRING_REQ("train_file", "A file containing the training set.", "t");
+PARAM_STRING_REQ("test_file", "A file containing the test set.", "T");
+
+PARAM_STRING("labels_file", "A file containing labels for the training set.",
+    "l", "");
 PARAM_STRING("output", "The file in which the output of the test would "
     "be written, defaults to 'output.csv')", "o", "output.csv");
 
-PROGRAM_INFO("Parametric Naive Bayes", "This program test drives the Parametric"
-    " Naive Bayes Classifier assuming that the features are sampled from a "
-    "Gaussian distribution.");
-
 using namespace mlpack;
-using namespace naive_bayes;
+using namespace mlpack::naive_bayes;
+using namespace std;
+using namespace arma;
 
 int main(int argc, char* argv[])
 {
   CLI::ParseCommandLine(argc, argv);
 
-  const char *training_data_filename =
-      CLI::GetParam<std::string>("train").c_str();
-  arma::mat training_data;
-  data::Load(training_data_filename, training_data, true);
+  // Check input parameters.
+  const string trainingDataFilename = CLI::GetParam<string>("train_file");
+  mat trainingData;
+  data::Load(trainingDataFilename.c_str(), trainingData, true);
 
-  const char *testing_data_filename =
-      CLI::GetParam<std::string>("test").c_str();
-  arma::mat testing_data;
-  data::Load(testing_data_filename, testing_data, true);
+  // Did the user pass in labels?
+  const string labelsFilename = CLI::GetParam<string>("labels_file");
+  if (labelsFilename != "")
+  {
+    // Load labels.
+    arma::mat labels;
+    data::Load(labelsFilename.c_str(), labels, true);
 
-  size_t number_of_classes_ = CLI::GetParam<size_t>("classes");
+    // Not incredibly efficient...
+    if (labels.n_rows == 1)
+      trainingData.insert_rows(trainingData.n_rows, trans(labels));
+    else if (labels.n_cols == 1)
+      trainingData.insert_rows(trainingData.n_rows, labels);
+    else
+      Log::Fatal << "Labels must have only one column or row!" << endl;
+  }
 
+  const string testingDataFilename = CLI::GetParam<std::string>("test_file");
+  mat testingData;
+  data::Load(testingDataFilename.c_str(), testingData, true);
+
+  // Calculate number of classes.
+  size_t classes = (size_t) max(trainingData.row(trainingData.n_rows - 1));
+
   // Create and train the classifier.
   Timer::Start("training");
-  NaiveBayesClassifier<> nbc(training_data, number_of_classes_);
+  NaiveBayesClassifier<> nbc(trainingData, classes);
   Timer::Stop("training");
 
   // Timing the running of the Naive Bayes Classifier.
   arma::Col<size_t> results;
   Timer::Start("testing");
-  nbc.Classify(testing_data, results);
+  nbc.Classify(testingData, results);
   Timer::Stop("testing");
 
   // Output results.
-  std::string output_filename = CLI::GetParam<std::string>("output");
-  data::Save(output_filename.c_str(), results, true);
+  const string outputFilename = CLI::GetParam<string>("output");
+  data::Save(outputFilename.c_str(), results, true);
 
   return 0;
 }




More information about the mlpack-svn mailing list