[mlpack-svn] r10763 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Dec 14 03:19:15 EST 2011
Author: rcurtin
Date: 2011-12-14 03:19:14 -0500 (Wed, 14 Dec 2011)
New Revision: 10763
Modified:
mlpack/trunk/src/mlpack/tests/nbc_test.cpp
Log:
Restructure test.
Modified: mlpack/trunk/src/mlpack/tests/nbc_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/nbc_test.cpp 2011-12-14 08:04:33 UTC (rev 10762)
+++ mlpack/trunk/src/mlpack/tests/nbc_test.cpp 2011-12-14 08:19:14 UTC (rev 10763)
@@ -4,7 +4,7 @@
* Test for the Naive Bayes classifier.
*/
#include <mlpack/core.hpp>
-#include <mlpack/methods/naive_bayes/simple_nbc.hpp>
+#include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>
#include <boost/test/unit_test.hpp>
@@ -13,53 +13,51 @@
BOOST_AUTO_TEST_SUITE(NBCTest);
-BOOST_AUTO_TEST_CASE(SimpleNBCTest)
+BOOST_AUTO_TEST_CASE(NaiveBayesClassifierTest)
{
- const char* filename_train_ = "trainSet.csv";
- const char* filename_test_ = "testSet.csv";
- const char* train_result_ = "trainRes.csv";
- const char* test_result_ = "testRes.csv";
- size_t number_of_classes_ = 2;
+ const char* trainFilename = "trainSet.csv";
+ const char* testFilename = "testSet.csv";
+ const char* trainResultFilename = "trainRes.csv";
+ const char* testResultFilename = "testRes.csv";
+ size_t classes = 2;
- arma::mat train_data, train_res, calc_mat;
- data::Load(filename_train_, train_data, true);
- data::Load(train_result_, train_res, true);
+ arma::mat trainData, trainRes, calcMat;
+ data::Load(trainFilename, trainData, true);
+ data::Load(trainResultFilename, trainRes, true);
- //Does not actually grab number_of_classes from the command line, as this
- //is a boost unit test.
- SimpleNaiveBayesClassifier nbc_test_(train_data, number_of_classes_);
+ NaiveBayesClassifier<> nbcTest(trainData, classes);
- size_t number_of_features = nbc_test_.means_.n_rows;
- calc_mat.zeros(2 * number_of_features + 1, number_of_classes_);
+ size_t dimension = nbcTest.Means().n_rows;
+ calcMat.zeros(2 * dimension + 1, classes);
- for (size_t i = 0; i < number_of_features; i++)
+ for (size_t i = 0; i < dimension; i++)
{
- for (size_t j = 0; j < number_of_classes_; j++)
+ for (size_t j = 0; j < classes; j++)
{
- calc_mat(i, j) = nbc_test_.means_(i, j);
- calc_mat(i + number_of_features, j) = nbc_test_.variances_(i, j);
+ calcMat(i, j) = nbcTest.Means()(i, j);
+ calcMat(i + dimension, j) = nbcTest.Variances()(i, j);
}
}
- for (size_t i = 0; i < number_of_classes_; i++)
- calc_mat(2 * number_of_features, i) = nbc_test_.class_probabilities_(i);
+ for (size_t i = 0; i < classes; i++)
+ calcMat(2 * dimension, i) = nbcTest.Probabilities()(i);
- for(size_t i = 0; i < calc_mat.n_rows; i++)
- for(size_t j = 0; j < number_of_classes_; j++)
- BOOST_REQUIRE_CLOSE(train_res(i, j) + .00001, calc_mat(i, j), .01);
+ for(size_t i = 0; i < calcMat.n_rows; i++)
+ for(size_t j = 0; j < classes; j++)
+ BOOST_REQUIRE_CLOSE(trainRes(i, j) + .00001, calcMat(i, j), .01);
- arma::mat test_data, test_res;
- arma::vec test_res_vec, calc_vec;
- data::Load(filename_test_, test_data, true);
- data::Load(test_result_, test_res, true);
+ arma::mat testData;
+ arma::Mat<size_t> testRes;
+ arma::Col<size_t> calcVec;
+ data::Load(testFilename, testData, true);
+ data::Load(testResultFilename, testRes, true);
- nbc_test_.Classify(test_data, calc_vec);
+ testData.shed_row(testData.n_rows - 1); // Remove the labels.
- size_t number_of_datum = test_data.n_cols;
- test_res_vec = test_res.col(0);
+ nbcTest.Classify(testData, calcVec);
- for(size_t i = 0; i < number_of_datum; i++)
- BOOST_REQUIRE_EQUAL(test_res_vec(i), calc_vec(i));
+ for(size_t i = 0; i < testData.n_cols; i++)
+ BOOST_REQUIRE_EQUAL(testRes(i), calcVec(i));
}
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn mailing list