[mlpack-svn] r15394 - mlpack/trunk/src/mlpack/methods/naive_bayes
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 3 13:51:05 EDT 2013
Author: rcurtin
Date: Wed Jul 3 13:51:05 2013
New Revision: 15394
Log:
Update NaiveBayesClassifier API to take labels as a separate vector. Also,
normalize the labels in the main executable.
Modified:
mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp
Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp Wed Jul 3 13:51:05 2013
@@ -62,13 +62,17 @@
* Example use:
* @code
* extern arma::mat training_data, testing_data;
- * NaiveBayesClassifier nbc(training_data, 5);
+ * extern arma::Col<size_t> labels;
+ * NaiveBayesClassifier nbc(training_data, labels, 5);
* @endcode
*
- * @param data Sample data points; the last row should be labels.
+ * @param data Training data points.
+ * @param labels Labels corresponding to training data points.
* @param classes Number of classes in this classifier.
*/
- NaiveBayesClassifier(const MatType& data, const size_t classes);
+ NaiveBayesClassifier(const MatType& data,
+ const arma::Col<size_t>& labels,
+ const size_t classes);
/**
* Given a bunch of data points, this function evaluates the class of each of
Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp Wed Jul 3 13:51:05 2013
@@ -18,10 +18,12 @@
namespace naive_bayes {
template<typename MatType>
-NaiveBayesClassifier<MatType>::NaiveBayesClassifier(const MatType& data,
- const size_t classes)
+NaiveBayesClassifier<MatType>::NaiveBayesClassifier(
+ const MatType& data,
+ const arma::Col<size_t>& labels,
+ const size_t classes)
{
- size_t dimensionality = data.n_rows - 1;
+ size_t dimensionality = data.n_rows;
// Update the variables according to the number of features and classes
// present in the data.
@@ -36,11 +38,11 @@
// for each of the features with respect to each of the labels.
for (size_t j = 0; j < data.n_cols; ++j)
{
- size_t label = (size_t) data(dimensionality, j);
+ const size_t label = labels[j];
++probabilities[label];
- means.col(label) += data(arma::span(0, dimensionality - 1), j);
- variances.col(label) += square(data(arma::span(0, dimensionality - 1), j));
+ means.col(label) += data.col(j);
+ variances.col(label) += square(data.col(j));
}
for (size_t i = 0; i < classes; ++i)
Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp (original)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp Wed Jul 3 13:51:05 2013
@@ -42,28 +42,36 @@
mat trainingData;
data::Load(trainingDataFilename.c_str(), trainingData, true);
+ // Normalize labels.
+ Col<size_t> labels;
+ vec mappings;
+
// Did the user pass in labels?
const string labelsFilename = CLI::GetParam<string>("labels_file");
if (labelsFilename != "")
{
// Load labels.
- arma::mat labels;
- data::Load(labelsFilename.c_str(), labels, true);
+ mat rawLabels;
+ data::Load(labelsFilename, rawLabels, true);
- // Not incredibly efficient...
- if (labels.n_rows == 1)
- trainingData.insert_rows(trainingData.n_rows, labels);
- else if (labels.n_cols == 1)
- trainingData.insert_rows(trainingData.n_rows, trans(labels));
- else
- Log::Fatal << "Labels must have only one column or row!" << endl;
+ data::NormalizeLabels(rawLabels.unsafe_col(0), labels, mappings);
+ }
+ else
+ {
+ // Use the last row of the training data as the labels.
+ Log::Info << "Using last dimension of training data as training labels."
+ << std::endl;
+ vec rawLabels = trans(trainingData.row(trainingData.n_rows - 1));
+ data::NormalizeLabels(rawLabels, labels, mappings);
+ // Remove the label row.
+ trainingData.shed_row(trainingData.n_rows - 1);
}
const string testingDataFilename = CLI::GetParam<std::string>("test_file");
mat testingData;
data::Load(testingDataFilename.c_str(), testingData, true);
- if (testingData.n_rows != trainingData.n_rows - 1)
+ if (testingData.n_rows != trainingData.n_rows)
Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
<< "must be the same as training data (" << trainingData.n_rows - 1
<< ")!" << std::endl;
@@ -73,18 +81,20 @@
// Create and train the classifier.
Timer::Start("training");
- NaiveBayesClassifier<> nbc(trainingData, classes);
+ NaiveBayesClassifier<> nbc(trainingData, labels, classes);
Timer::Stop("training");
- // Timing the running of the Naive Bayes Classifier.
- arma::Col<size_t> results;
+ // Time the running of the Naive Bayes Classifier.
+ Col<size_t> results;
Timer::Start("testing");
nbc.Classify(testingData, results);
Timer::Stop("testing");
+ // Un-normalize labels to prepare output.
+ vec rawResults;
+ data::RevertLabels(results, mappings, rawResults);
+
// Output results.
const string outputFilename = CLI::GetParam<string>("output");
- data::Save(outputFilename.c_str(), results, true);
-
- return 0;
+ data::Save(outputFilename, results, true);
}
More information about the mlpack-svn
mailing list