[mlpack-git] master: Refactor nbc program to allow loading/saving models. (09cbc6e)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Dec 21 12:33:12 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/df229e45a5bd7842fe019e9d49ed32f13beb6aaa...09cbc6e13aa3cb8a7c4ea6d2e1612977a40c6be7

>---------------------------------------------------------------

commit 09cbc6e13aa3cb8a7c4ea6d2e1612977a40c6be7
Author: ryan <ryan at ratml.org>
Date:   Mon Dec 21 12:32:49 2015 -0500

    Refactor nbc program to allow loading/saving models.


>---------------------------------------------------------------

09cbc6e13aa3cb8a7c4ea6d2e1612977a40c6be7
 .../methods/naive_bayes/naive_bayes_classifier.hpp |   4 +-
 src/mlpack/methods/naive_bayes/nbc_main.cpp        | 192 ++++++++++++++-------
 2 files changed, 133 insertions(+), 63 deletions(-)

diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
index 9647e88..c79d836 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
@@ -71,8 +71,8 @@ class NaiveBayesClassifier
    * Train() before calling Classify(), otherwise the results may be
    * meaningless.
    */
-  NaiveBayesClassifier(const size_t dimensionality,
-                       const size_t classes);
+  NaiveBayesClassifier(const size_t dimensionality = 0,
+                       const size_t classes = 0);
 
   /**
    * Train the Naive Bayes classifier on the given dataset.  If the incremental
diff --git a/src/mlpack/methods/naive_bayes/nbc_main.cpp b/src/mlpack/methods/naive_bayes/nbc_main.cpp
index c573205..0b14fba 100644
--- a/src/mlpack/methods/naive_bayes/nbc_main.cpp
+++ b/src/mlpack/methods/naive_bayes/nbc_main.cpp
@@ -24,87 +24,157 @@ PROGRAM_INFO("Parametric Naive Bayes Classifier",
     "use an incremental algorithm for calculating variance.  This is slower, "
     "but can help avoid loss of precision in some cases.");
 
-PARAM_STRING_REQ("train_file", "A file containing the training set.", "t");
-PARAM_STRING_REQ("test_file", "A file containing the test set.", "T");
-
+// Model loading/saving.
+PARAM_STRING("input_model_file", "File containing input Naive Bayes model.",
+    "m", "");
+PARAM_STRING("output_model_file", "File to save trained Naive Bayes model to.",
+    "M", "");
+
+// Training parameters.
+PARAM_STRING("training_file", "A file containing the training set.", "t", "");
 PARAM_STRING("labels_file", "A file containing labels for the training set.",
     "l", "");
-PARAM_STRING("output_file", "The file in which the predicted labels for the "
-    "test set will be written.", "o", "output.csv");
 PARAM_FLAG("incremental_variance", "The variance of each class will be "
     "calculated incrementally.", "I");
 
+// Test parameters.
+PARAM_STRING("test_file", "A file containing the test set.", "T", "");
+PARAM_STRING("output_file", "The file in which the predicted labels for the "
+    "test set will be written.", "o", "");
+
 using namespace mlpack;
 using namespace mlpack::naive_bayes;
 using namespace std;
 using namespace arma;
 
-int main(int argc, char* argv[])
+// A struct for saving the model with mappings.
+struct NBCModel
 {
-  CLI::ParseCommandLine(argc, argv);
-
-  // Check input parameters.
-  const string trainingDataFilename = CLI::GetParam<string>("train_file");
-  mat trainingData;
-  data::Load(trainingDataFilename, trainingData, true);
-
-  // Normalize labels.
-  Row<size_t> labels;
+  //! The model itself.
+  NaiveBayesClassifier<> nbc;
+  //! The mappings for labels.
   Col<size_t> mappings;
 
-  // Did the user pass in labels?
-  const string labelsFilename = CLI::GetParam<string>("labels_file");
-  if (labelsFilename != "")
+  //! Serialize the model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
   {
-    // Load labels.
-    mat rawLabels;
-    data::Load(labelsFilename, rawLabels, true, false);
+    ar & data::CreateNVP(nbc, "nbc");
+    ar & data::CreateNVP(mappings, "mappings");
+  }
+};
 
-    // Do the labels need to be transposed?
-    if (rawLabels.n_cols == 1)
-      rawLabels = rawLabels.t();
+int main(int argc, char* argv[])
+{
+  CLI::ParseCommandLine(argc, argv);
 
-    data::NormalizeLabels(rawLabels.row(0), labels, mappings);
+  // Check input parameters.
+  if (CLI::HasParam("training_file") && CLI::HasParam("input_model_file"))
+    Log::Fatal << "Cannot specify both --training_file (-t) and "
+        << "--input_model_file (-m)!" << endl;
+
+  if (!CLI::HasParam("training_file") && !CLI::HasParam("input_model_file"))
+    Log::Fatal << "Neither --training_file (-t) nor --input_model_file (-m) are"
+        << " specified!" << endl;
+
+  if (!CLI::HasParam("training_file") && CLI::HasParam("labels_file"))
+    Log::Warn << "--labels_file (-l) ignored because --training_file (-t) is "
+        << "not specified." << endl;
+  if (!CLI::HasParam("training_file") && CLI::HasParam("incremental_variance"))
+    Log::Warn << "--incremental_variance (-I) ignored because --training_file "
+        << "(-t) is not specified." << endl;
+
+  if (!CLI::HasParam("output_file") && !CLI::HasParam("output_model_file"))
+    Log::Warn << "Neither --output_file (-o) nor --output_model_file (-M) "
+        << "specified; no output will be saved!" << endl;
+
+  if (CLI::HasParam("output_file") && !CLI::HasParam("test_file"))
+    Log::Warn << "--output_file (-o) ignored because no test file specified "
+        << "with --test_file (-T)." << endl;
+
+  if (!CLI::HasParam("output_file") && CLI::HasParam("test_file"))
+    Log::Warn << "--test_file (-T) specified, but classification results will "
+        << "not be saved because --output_file (-o) is not specified." << endl;
+
+  // Either we have to train a model, or load a model.
+  NBCModel model;
+  if (CLI::HasParam("training_file"))
+  {
+    const string trainingFile = CLI::GetParam<string>("training_file");
+    mat trainingData;
+    data::Load(trainingFile, trainingData, true);
+
+    Row<size_t> labels;
+
+    // Did the user pass in labels?
+    const string labelsFilename = CLI::GetParam<string>("labels_file");
+    if (labelsFilename != "")
+    {
+      // Load labels.
+      mat rawLabels;
+      data::Load(labelsFilename, rawLabels, true, false);
+
+      // Do the labels need to be transposed?
+      if (rawLabels.n_cols == 1)
+        rawLabels = rawLabels.t();
+
+      data::NormalizeLabels(rawLabels.row(0), labels, model.mappings);
+    }
+    else
+    {
+      // Use the last row of the training data as the labels.
+      Log::Info << "Using last dimension of training data as training labels."
+          << endl;
+      data::NormalizeLabels(trainingData.row(trainingData.n_rows - 1), labels,
+          model.mappings);
+      // Remove the label row.
+      trainingData.shed_row(trainingData.n_rows - 1);
+    }
+
+    const bool incrementalVariance = CLI::HasParam("incremental_variance");
+
+    Timer::Start("nbc_training");
+    model.nbc = NaiveBayesClassifier<>(trainingData, labels,
+        model.mappings.n_elem, incrementalVariance);
+    Timer::Stop("nbc_training");
   }
   else
   {
-    // Use the last row of the training data as the labels.
-    Log::Info << "Using last dimension of training data as training labels."
-        << endl;
-    data::NormalizeLabels(trainingData.row(trainingData.n_rows - 1), labels,
-        mappings);
-    // Remove the label row.
-    trainingData.shed_row(trainingData.n_rows - 1);
+    // Load the model from file.
+    data::Load(CLI::GetParam<string>("input_model_file"), "nbc_model", model);
   }
 
-  const string testingDataFilename = CLI::GetParam<std::string>("test_file");
-  mat testingData;
-  data::Load(testingDataFilename, testingData, true);
-
-  if (testingData.n_rows != trainingData.n_rows)
-    Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
-        << "must be the same as training data (" << trainingData.n_rows
-        << ")!" << std::endl;
-
-  const bool incrementalVariance = CLI::HasParam("incremental_variance");
-
-  // Create and train the classifier.
-  Timer::Start("training");
-  NaiveBayesClassifier<> nbc(trainingData, labels, mappings.n_elem,
-      incrementalVariance);
-  Timer::Stop("training");
-
-  // Time the running of the Naive Bayes Classifier.
-  Row<size_t> results;
-  Timer::Start("testing");
-  nbc.Classify(testingData, results);
-  Timer::Stop("testing");
-
-  // Un-normalize labels to prepare output.
-  Row<size_t> rawResults;
-  data::RevertLabels(results, mappings, rawResults);
+  // Do we need to do testing?
+  if (CLI::HasParam("test_file"))
+  {
+    const string testingDataFilename = CLI::GetParam<std::string>("test_file");
+    mat testingData;
+    data::Load(testingDataFilename, testingData, true);
+
+    if (testingData.n_rows != model.nbc.Means().n_rows)
+      Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
+          << "must be the same as training data (" << model.nbc.Means().n_rows
+          << ")!" << std::endl;
+
+    // Time the running of the Naive Bayes Classifier.
+    Row<size_t> results;
+    Timer::Start("nbc_testing");
+    model.nbc.Classify(testingData, results);
+    Timer::Stop("nbc_testing");
+
+    if (CLI::HasParam("output_file"))
+    {
+      // Un-normalize labels to prepare output.
+      Row<size_t> rawResults;
+      data::RevertLabels(results, model.mappings, rawResults);
+
+      // Output results.
+      const string outputFilename = CLI::GetParam<string>("output_file");
+      data::Save(outputFilename, rawResults, true);
+    }
+  }
 
-  // Output results.
-  const string outputFilename = CLI::GetParam<string>("output_file");
-  data::Save(outputFilename, rawResults, true);
+  if (CLI::HasParam("output_model_file"))
+    data::Save(CLI::GetParam<string>("output_model_file"), "nbc_model", model,
+        false);
 }



More information about the mlpack-git mailing list