[mlpack-git] master: Refactor nbc program to allow loading/saving models. (09cbc6e)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Dec 21 12:33:12 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/df229e45a5bd7842fe019e9d49ed32f13beb6aaa...09cbc6e13aa3cb8a7c4ea6d2e1612977a40c6be7
>---------------------------------------------------------------
commit 09cbc6e13aa3cb8a7c4ea6d2e1612977a40c6be7
Author: ryan <ryan at ratml.org>
Date: Mon Dec 21 12:32:49 2015 -0500
Refactor nbc program to allow loading/saving models.
>---------------------------------------------------------------
09cbc6e13aa3cb8a7c4ea6d2e1612977a40c6be7
.../methods/naive_bayes/naive_bayes_classifier.hpp | 4 +-
src/mlpack/methods/naive_bayes/nbc_main.cpp | 192 ++++++++++++++-------
2 files changed, 133 insertions(+), 63 deletions(-)
diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
index 9647e88..c79d836 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
@@ -71,8 +71,8 @@ class NaiveBayesClassifier
* Train() before calling Classify(), otherwise the results may be
* meaningless.
*/
- NaiveBayesClassifier(const size_t dimensionality,
- const size_t classes);
+ NaiveBayesClassifier(const size_t dimensionality = 0,
+ const size_t classes = 0);
/**
* Train the Naive Bayes classifier on the given dataset. If the incremental
diff --git a/src/mlpack/methods/naive_bayes/nbc_main.cpp b/src/mlpack/methods/naive_bayes/nbc_main.cpp
index c573205..0b14fba 100644
--- a/src/mlpack/methods/naive_bayes/nbc_main.cpp
+++ b/src/mlpack/methods/naive_bayes/nbc_main.cpp
@@ -24,87 +24,157 @@ PROGRAM_INFO("Parametric Naive Bayes Classifier",
"use an incremental algorithm for calculating variance. This is slower, "
"but can help avoid loss of precision in some cases.");
-PARAM_STRING_REQ("train_file", "A file containing the training set.", "t");
-PARAM_STRING_REQ("test_file", "A file containing the test set.", "T");
-
+// Model loading/saving.
+PARAM_STRING("input_model_file", "File containing input Naive Bayes model.",
+ "m", "");
+PARAM_STRING("output_model_file", "File to save trained Naive Bayes model to.",
+ "M", "");
+
+// Training parameters.
+PARAM_STRING("training_file", "A file containing the training set.", "t", "");
PARAM_STRING("labels_file", "A file containing labels for the training set.",
"l", "");
-PARAM_STRING("output_file", "The file in which the predicted labels for the "
- "test set will be written.", "o", "output.csv");
PARAM_FLAG("incremental_variance", "The variance of each class will be "
"calculated incrementally.", "I");
+// Test parameters.
+PARAM_STRING("test_file", "A file containing the test set.", "T", "");
+PARAM_STRING("output_file", "The file in which the predicted labels for the "
+ "test set will be written.", "o", "");
+
using namespace mlpack;
using namespace mlpack::naive_bayes;
using namespace std;
using namespace arma;
-int main(int argc, char* argv[])
+// A struct for saving the model with mappings.
+struct NBCModel
{
- CLI::ParseCommandLine(argc, argv);
-
- // Check input parameters.
- const string trainingDataFilename = CLI::GetParam<string>("train_file");
- mat trainingData;
- data::Load(trainingDataFilename, trainingData, true);
-
- // Normalize labels.
- Row<size_t> labels;
+ //! The model itself.
+ NaiveBayesClassifier<> nbc;
+ //! The mappings for labels.
Col<size_t> mappings;
- // Did the user pass in labels?
- const string labelsFilename = CLI::GetParam<string>("labels_file");
- if (labelsFilename != "")
+ //! Serialize the model.
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */)
{
- // Load labels.
- mat rawLabels;
- data::Load(labelsFilename, rawLabels, true, false);
+ ar & data::CreateNVP(nbc, "nbc");
+ ar & data::CreateNVP(mappings, "mappings");
+ }
+};
- // Do the labels need to be transposed?
- if (rawLabels.n_cols == 1)
- rawLabels = rawLabels.t();
+int main(int argc, char* argv[])
+{
+ CLI::ParseCommandLine(argc, argv);
- data::NormalizeLabels(rawLabels.row(0), labels, mappings);
+ // Check input parameters.
+ if (CLI::HasParam("training_file") && CLI::HasParam("input_model_file"))
+ Log::Fatal << "Cannot specify both --training_file (-t) and "
+ << "--input_model_file (-m)!" << endl;
+
+ if (!CLI::HasParam("training_file") && !CLI::HasParam("input_model_file"))
+ Log::Fatal << "Neither --training_file (-t) nor --input_model_file (-m) are"
+ << " specified!" << endl;
+
+ if (!CLI::HasParam("training_file") && CLI::HasParam("labels_file"))
+ Log::Warn << "--labels_file (-l) ignored because --training_file (-t) is "
+ << "not specified." << endl;
+ if (!CLI::HasParam("training_file") && CLI::HasParam("incremental_variance"))
+ Log::Warn << "--incremental_variance (-I) ignored because --training_file "
+ << "(-t) is not specified." << endl;
+
+ if (!CLI::HasParam("output_file") && !CLI::HasParam("output_model_file"))
+ Log::Warn << "Neither --output_file (-o) nor --output_model_file (-M) "
+ << "specified; no output will be saved!" << endl;
+
+ if (CLI::HasParam("output_file") && !CLI::HasParam("test_file"))
+ Log::Warn << "--output_file (-o) ignored because no test file specified "
+ << "with --test_file (-T)." << endl;
+
+ if (!CLI::HasParam("output_file") && CLI::HasParam("test_file"))
+ Log::Warn << "--test_file (-T) specified, but classification results will "
+ << "not be saved because --output_file (-o) is not specified." << endl;
+
+ // Either we have to train a model, or load a model.
+ NBCModel model;
+ if (CLI::HasParam("training_file"))
+ {
+ const string trainingFile = CLI::GetParam<string>("training_file");
+ mat trainingData;
+ data::Load(trainingFile, trainingData, true);
+
+ Row<size_t> labels;
+
+ // Did the user pass in labels?
+ const string labelsFilename = CLI::GetParam<string>("labels_file");
+ if (labelsFilename != "")
+ {
+ // Load labels.
+ mat rawLabels;
+ data::Load(labelsFilename, rawLabels, true, false);
+
+ // Do the labels need to be transposed?
+ if (rawLabels.n_cols == 1)
+ rawLabels = rawLabels.t();
+
+ data::NormalizeLabels(rawLabels.row(0), labels, model.mappings);
+ }
+ else
+ {
+ // Use the last row of the training data as the labels.
+ Log::Info << "Using last dimension of training data as training labels."
+ << endl;
+ data::NormalizeLabels(trainingData.row(trainingData.n_rows - 1), labels,
+ model.mappings);
+ // Remove the label row.
+ trainingData.shed_row(trainingData.n_rows - 1);
+ }
+
+ const bool incrementalVariance = CLI::HasParam("incremental_variance");
+
+ Timer::Start("nbc_training");
+ model.nbc = NaiveBayesClassifier<>(trainingData, labels,
+ model.mappings.n_elem, incrementalVariance);
+ Timer::Stop("nbc_training");
}
else
{
- // Use the last row of the training data as the labels.
- Log::Info << "Using last dimension of training data as training labels."
- << endl;
- data::NormalizeLabels(trainingData.row(trainingData.n_rows - 1), labels,
- mappings);
- // Remove the label row.
- trainingData.shed_row(trainingData.n_rows - 1);
+ // Load the model from file.
+ data::Load(CLI::GetParam<string>("input_model_file"), "nbc_model", model);
}
- const string testingDataFilename = CLI::GetParam<std::string>("test_file");
- mat testingData;
- data::Load(testingDataFilename, testingData, true);
-
- if (testingData.n_rows != trainingData.n_rows)
- Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
- << "must be the same as training data (" << trainingData.n_rows
- << ")!" << std::endl;
-
- const bool incrementalVariance = CLI::HasParam("incremental_variance");
-
- // Create and train the classifier.
- Timer::Start("training");
- NaiveBayesClassifier<> nbc(trainingData, labels, mappings.n_elem,
- incrementalVariance);
- Timer::Stop("training");
-
- // Time the running of the Naive Bayes Classifier.
- Row<size_t> results;
- Timer::Start("testing");
- nbc.Classify(testingData, results);
- Timer::Stop("testing");
-
- // Un-normalize labels to prepare output.
- Row<size_t> rawResults;
- data::RevertLabels(results, mappings, rawResults);
+ // Do we need to do testing?
+ if (CLI::HasParam("test_file"))
+ {
+ const string testingDataFilename = CLI::GetParam<std::string>("test_file");
+ mat testingData;
+ data::Load(testingDataFilename, testingData, true);
+
+ if (testingData.n_rows != model.nbc.Means().n_rows)
+ Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
+ << "must be the same as training data (" << model.nbc.Means().n_rows
+ << ")!" << std::endl;
+
+ // Time the running of the Naive Bayes Classifier.
+ Row<size_t> results;
+ Timer::Start("nbc_testing");
+ model.nbc.Classify(testingData, results);
+ Timer::Stop("nbc_testing");
+
+ if (CLI::HasParam("output_file"))
+ {
+ // Un-normalize labels to prepare output.
+ Row<size_t> rawResults;
+ data::RevertLabels(results, model.mappings, rawResults);
+
+ // Output results.
+ const string outputFilename = CLI::GetParam<string>("output_file");
+ data::Save(outputFilename, rawResults, true);
+ }
+ }
- // Output results.
- const string outputFilename = CLI::GetParam<string>("output_file");
- data::Save(outputFilename, rawResults, true);
+ if (CLI::HasParam("output_model_file"))
+ data::Save(CLI::GetParam<string>("output_model_file"), "nbc_model", model,
+ false);
}
More information about the mlpack-git mailing list