[mlpack-git] master: Refactor for new NBC API. Not yet tested. (e59d52d)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Sep 29 09:33:43 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/cbeb3ea17262b7c5115247dc217e316c529249b7...f85a9b22f3ce56143943a2488c05c2810d6b2bf3

>---------------------------------------------------------------

commit e59d52d95af74990d45d8238e3fa76eeed8d542b
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Sep 25 23:35:28 2015 +0000

    Refactor for new NBC API. Not yet tested.


>---------------------------------------------------------------

e59d52d95af74990d45d8238e3fa76eeed8d542b
 src/mlpack/methods/naive_bayes/nbc_main.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/mlpack/methods/naive_bayes/nbc_main.cpp b/src/mlpack/methods/naive_bayes/nbc_main.cpp
index 41c7533..eac078f 100644
--- a/src/mlpack/methods/naive_bayes/nbc_main.cpp
+++ b/src/mlpack/methods/naive_bayes/nbc_main.cpp
@@ -50,6 +50,7 @@ int main(int argc, char* argv[])
 
   // Normalize labels.
   Col<size_t> labels;
+  Row<size_t> labelst;
   vec mappings;
 
   // Did the user pass in labels?
@@ -76,6 +77,7 @@ int main(int argc, char* argv[])
     // Remove the label row.
     trainingData.shed_row(trainingData.n_rows - 1);
   }
+  labelst = labels.t();
 
   const string testingDataFilename = CLI::GetParam<std::string>("test_file");
   mat testingData;
@@ -90,21 +92,21 @@ int main(int argc, char* argv[])
 
   // Create and train the classifier.
   Timer::Start("training");
-  NaiveBayesClassifier<> nbc(trainingData, labels, mappings.n_elem,
+  NaiveBayesClassifier<> nbc(trainingData, labelst, mappings.n_elem,
       incrementalVariance);
   Timer::Stop("training");
 
   // Time the running of the Naive Bayes Classifier.
-  Col<size_t> results;
+  Row<size_t> results;
   Timer::Start("testing");
   nbc.Classify(testingData, results);
   Timer::Stop("testing");
 
   // Un-normalize labels to prepare output.
   vec rawResults;
-  data::RevertLabels(results, mappings, rawResults);
+  data::RevertLabels(results.t(), mappings, rawResults);
 
-  // Output results.  Don't transpose: one result per line.
+  // Output results.  Transpose: one result per line.
   const string outputFilename = CLI::GetParam<string>("output");
   data::Save(outputFilename, rawResults, true, false);
 }



More information about the mlpack-git mailing list