[mlpack-git] master: Refactor for new NBC API. Not yet tested. (e59d52d)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Sep 29 09:33:43 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/cbeb3ea17262b7c5115247dc217e316c529249b7...f85a9b22f3ce56143943a2488c05c2810d6b2bf3
>---------------------------------------------------------------
commit e59d52d95af74990d45d8238e3fa76eeed8d542b
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Sep 25 23:35:28 2015 +0000
Refactor for new NBC API. Not yet tested.
>---------------------------------------------------------------
e59d52d95af74990d45d8238e3fa76eeed8d542b
src/mlpack/methods/naive_bayes/nbc_main.cpp | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/src/mlpack/methods/naive_bayes/nbc_main.cpp b/src/mlpack/methods/naive_bayes/nbc_main.cpp
index 41c7533..eac078f 100644
--- a/src/mlpack/methods/naive_bayes/nbc_main.cpp
+++ b/src/mlpack/methods/naive_bayes/nbc_main.cpp
@@ -50,6 +50,7 @@ int main(int argc, char* argv[])
// Normalize labels.
Col<size_t> labels;
+ Row<size_t> labelst;
vec mappings;
// Did the user pass in labels?
@@ -76,6 +77,7 @@ int main(int argc, char* argv[])
// Remove the label row.
trainingData.shed_row(trainingData.n_rows - 1);
}
+ labelst = labels.t();
const string testingDataFilename = CLI::GetParam<std::string>("test_file");
mat testingData;
@@ -90,21 +92,21 @@ int main(int argc, char* argv[])
// Create and train the classifier.
Timer::Start("training");
- NaiveBayesClassifier<> nbc(trainingData, labels, mappings.n_elem,
+ NaiveBayesClassifier<> nbc(trainingData, labelst, mappings.n_elem,
incrementalVariance);
Timer::Stop("training");
// Time the running of the Naive Bayes Classifier.
- Col<size_t> results;
+ Row<size_t> results;
Timer::Start("testing");
nbc.Classify(testingData, results);
Timer::Stop("testing");
// Un-normalize labels to prepare output.
vec rawResults;
- data::RevertLabels(results, mappings, rawResults);
+ data::RevertLabels(results.t(), mappings, rawResults);
- // Output results. Don't transpose: one result per line.
+ // Output results. Transpose: one result per line.
const string outputFilename = CLI::GetParam<string>("output");
data::Save(outputFilename, rawResults, true, false);
}
More information about the mlpack-git
mailing list