[mlpack-svn] r15394 - mlpack/trunk/src/mlpack/methods/naive_bayes

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 3 13:51:05 EDT 2013


Author: rcurtin
Date: Wed Jul  3 13:51:05 2013
New Revision: 15394

Log:
Update NaiveBayesClassifier API to take labels as a separate vector.  Also,
normalize the labels in the main executable.


Modified:
   mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
   mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
   mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp

Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp	Wed Jul  3 13:51:05 2013
@@ -62,13 +62,17 @@
    * Example use:
    * @code
    * extern arma::mat training_data, testing_data;
-   * NaiveBayesClassifier nbc(training_data, 5);
+   * extern arma::Col<size_t> labels;
+   * NaiveBayesClassifier nbc(training_data, labels, 5);
    * @endcode
    *
-   * @param data Sample data points; the last row should be labels.
+   * @param data Training data points.
+   * @param labels Labels corresponding to training data points.
    * @param classes Number of classes in this classifier.
    */
-  NaiveBayesClassifier(const MatType& data, const size_t classes);
+  NaiveBayesClassifier(const MatType& data,
+                       const arma::Col<size_t>& labels,
+                       const size_t classes);
 
   /**
    * Given a bunch of data points, this function evaluates the class of each of

Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp	Wed Jul  3 13:51:05 2013
@@ -18,10 +18,12 @@
 namespace naive_bayes {
 
 template<typename MatType>
-NaiveBayesClassifier<MatType>::NaiveBayesClassifier(const MatType& data,
-                                                    const size_t classes)
+NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
+    const MatType& data,
+    const arma::Col<size_t>& labels,
+    const size_t classes)
 {
-  size_t dimensionality = data.n_rows - 1;
+  size_t dimensionality = data.n_rows;
 
   // Update the variables according to the number of features and classes
   // present in the data.
@@ -36,11 +38,11 @@
   // for each of the features with respect to each of the labels.
   for (size_t j = 0; j < data.n_cols; ++j)
   {
-    size_t label = (size_t) data(dimensionality, j);
+    const size_t label = labels[j];
     ++probabilities[label];
 
-    means.col(label) += data(arma::span(0, dimensionality - 1), j);
-    variances.col(label) += square(data(arma::span(0, dimensionality - 1), j));
+    means.col(label) += data.col(j);
+    variances.col(label) += square(data.col(j));
   }
 
   for (size_t i = 0; i < classes; ++i)

Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp	Wed Jul  3 13:51:05 2013
@@ -42,28 +42,36 @@
   mat trainingData;
   data::Load(trainingDataFilename.c_str(), trainingData, true);
 
+  // Normalize labels.
+  Col<size_t> labels;
+  vec mappings;
+
   // Did the user pass in labels?
   const string labelsFilename = CLI::GetParam<string>("labels_file");
   if (labelsFilename != "")
   {
     // Load labels.
-    arma::mat labels;
-    data::Load(labelsFilename.c_str(), labels, true);
+    mat rawLabels;
+    data::Load(labelsFilename, rawLabels, true);
 
-    // Not incredibly efficient...
-    if (labels.n_rows == 1)
-      trainingData.insert_rows(trainingData.n_rows, labels);
-    else if (labels.n_cols == 1)
-      trainingData.insert_rows(trainingData.n_rows, trans(labels));
-    else
-      Log::Fatal << "Labels must have only one column or row!" << endl;
+    data::NormalizeLabels(rawLabels.unsafe_col(0), labels, mappings);
+  }
+  else
+  {
+    // Use the last row of the training data as the labels.
+    Log::Info << "Using last dimension of training data as training labels."
+        << std::endl;
+    vec rawLabels = trans(trainingData.row(trainingData.n_rows - 1));
+    data::NormalizeLabels(rawLabels, labels, mappings);
+    // Remove the label row.
+    trainingData.shed_row(trainingData.n_rows - 1);
   }
 
   const string testingDataFilename = CLI::GetParam<std::string>("test_file");
   mat testingData;
   data::Load(testingDataFilename.c_str(), testingData, true);
 
-  if (testingData.n_rows != trainingData.n_rows - 1)
+  if (testingData.n_rows != trainingData.n_rows)
     Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
         << "must be the same as training data (" << trainingData.n_rows - 1
         << ")!" << std::endl;
@@ -73,18 +81,20 @@
 
   // Create and train the classifier.
   Timer::Start("training");
-  NaiveBayesClassifier<> nbc(trainingData, classes);
+  NaiveBayesClassifier<> nbc(trainingData, labels, classes);
   Timer::Stop("training");
 
-  // Timing the running of the Naive Bayes Classifier.
-  arma::Col<size_t> results;
+  // Time the running of the Naive Bayes Classifier.
+  Col<size_t> results;
   Timer::Start("testing");
   nbc.Classify(testingData, results);
   Timer::Stop("testing");
 
+  // Un-normalize labels to prepare output.
+  vec rawResults;
+  data::RevertLabels(results, mappings, rawResults);
+
   // Output results.
   const string outputFilename = CLI::GetParam<string>("output");
-  data::Save(outputFilename.c_str(), results, true);
-
-  return 0;
+  data::Save(outputFilename, results, true);
 }



More information about the mlpack-svn mailing list