[mlpack-git] master: Flesh out tests for DatasetInfo Load() functions. (ae7f38a)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:41:47 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit ae7f38a1fd690cc649006010db89f2d38d37b8f0
Author: ryan <ryan at ratml.org>
Date: Thu Sep 10 17:08:20 2015 -0400
Flesh out tests for DatasetInfo Load() functions.
>---------------------------------------------------------------
ae7f38a1fd690cc649006010db89f2d38d37b8f0
src/mlpack/tests/load_save_test.cpp | 151 +++++++++++++++++++++++++++++-------
1 file changed, 122 insertions(+), 29 deletions(-)
diff --git a/src/mlpack/tests/load_save_test.cpp b/src/mlpack/tests/load_save_test.cpp
index 4047170..9a97b72 100644
--- a/src/mlpack/tests/load_save_test.cpp
+++ b/src/mlpack/tests/load_save_test.cpp
@@ -865,11 +865,11 @@ BOOST_AUTO_TEST_CASE(DatasetInfoTest)
BOOST_AUTO_TEST_CASE(RegularCSVDatasetInfoLoad)
{
vector<string> testFiles;
- //testFiles.push_back("fake.csv");
- //testFiles.push_back("german.csv");
- testFiles.push_back("iris.txt");
- testFiles.push_back("vc2.txt");
- testFiles.push_back("johnson-8-4-4.csv");
+ testFiles.push_back("fake.csv");
+ testFiles.push_back("german.csv");
+ testFiles.push_back("iris.csv");
+ testFiles.push_back("vc2.csv");
+ testFiles.push_back("johnson8-4-4.csv");
testFiles.push_back("lars_dependent_y.csv");
testFiles.push_back("vc2_test_labels.txt");
@@ -886,8 +886,6 @@ BOOST_AUTO_TEST_CASE(RegularCSVDatasetInfoLoad)
BOOST_REQUIRE_EQUAL(one.n_cols, two.n_cols);
for (size_t i = 0; i < one.n_elem; ++i)
{
- std::cout << "i " << i << ": one " << one[i] << " two " << two[i] <<
-".\n";
if (std::abs(one[i]) < 1e-8)
BOOST_REQUIRE_SMALL(two[i], 1e-8);
else
@@ -909,9 +907,9 @@ BOOST_AUTO_TEST_CASE(NontransposedCSVDatasetInfoLoad)
vector<string> testFiles;
testFiles.push_back("fake.csv");
testFiles.push_back("german.csv");
- testFiles.push_back("iris.txt");
- testFiles.push_back("vc2.txt");
- testFiles.push_back("johnson-8-4-4.csv");
+ testFiles.push_back("iris.csv");
+ testFiles.push_back("vc2.csv");
+ testFiles.push_back("johnson8-4-4.csv");
testFiles.push_back("lars_dependent_y.csv");
testFiles.push_back("vc2_test_labels.txt");
@@ -945,7 +943,8 @@ BOOST_AUTO_TEST_CASE(NontransposedCSVDatasetInfoLoad)
*/
BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest)
{
- fstream f("test.csv");
+ fstream f;
+ f.open("test.csv", fstream::out);
f << "1, 2, hello" << endl;
f << "3, 4, goodbye" << endl;
f << "5, 6, coffee" << endl;
@@ -964,26 +963,26 @@ BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest)
BOOST_REQUIRE_EQUAL(matrix.n_rows, 3);
BOOST_REQUIRE_EQUAL(matrix(0, 0), 1);
- BOOST_REQUIRE_EQUAL(matrix(0, 1), 2);
- BOOST_REQUIRE_EQUAL(matrix(0, 2), 0);
- BOOST_REQUIRE_EQUAL(matrix(1, 0), 3);
+ BOOST_REQUIRE_EQUAL(matrix(1, 0), 2);
+ BOOST_REQUIRE_EQUAL(matrix(2, 0), 0);
+ BOOST_REQUIRE_EQUAL(matrix(0, 1), 3);
BOOST_REQUIRE_EQUAL(matrix(1, 1), 4);
- BOOST_REQUIRE_EQUAL(matrix(1, 2), 1);
- BOOST_REQUIRE_EQUAL(matrix(2, 0), 5);
- BOOST_REQUIRE_EQUAL(matrix(2, 1), 6);
+ BOOST_REQUIRE_EQUAL(matrix(2, 1), 1);
+ BOOST_REQUIRE_EQUAL(matrix(0, 2), 5);
+ BOOST_REQUIRE_EQUAL(matrix(1, 2), 6);
BOOST_REQUIRE_EQUAL(matrix(2, 2), 2);
- BOOST_REQUIRE_EQUAL(matrix(3, 0), 7);
- BOOST_REQUIRE_EQUAL(matrix(3, 1), 8);
- BOOST_REQUIRE_EQUAL(matrix(3, 2), 3);
- BOOST_REQUIRE_EQUAL(matrix(4, 0), 9);
- BOOST_REQUIRE_EQUAL(matrix(4, 1), 10);
- BOOST_REQUIRE_EQUAL(matrix(4, 2), 0);
- BOOST_REQUIRE_EQUAL(matrix(5, 0), 11);
- BOOST_REQUIRE_EQUAL(matrix(5, 1), 12);
- BOOST_REQUIRE_EQUAL(matrix(5, 2), 3);
- BOOST_REQUIRE_EQUAL(matrix(6, 0), 13);
- BOOST_REQUIRE_EQUAL(matrix(6, 1), 14);
- BOOST_REQUIRE_EQUAL(matrix(6, 2), 3);
+ BOOST_REQUIRE_EQUAL(matrix(0, 3), 7);
+ BOOST_REQUIRE_EQUAL(matrix(1, 3), 8);
+ BOOST_REQUIRE_EQUAL(matrix(2, 3), 3);
+ BOOST_REQUIRE_EQUAL(matrix(0, 4), 9);
+ BOOST_REQUIRE_EQUAL(matrix(1, 4), 10);
+ BOOST_REQUIRE_EQUAL(matrix(2, 4), 0);
+ BOOST_REQUIRE_EQUAL(matrix(0, 5), 11);
+ BOOST_REQUIRE_EQUAL(matrix(1, 5), 12);
+ BOOST_REQUIRE_EQUAL(matrix(2, 5), 3);
+ BOOST_REQUIRE_EQUAL(matrix(0, 6), 13);
+ BOOST_REQUIRE_EQUAL(matrix(1, 6), 14);
+ BOOST_REQUIRE_EQUAL(matrix(2, 6), 3);
BOOST_REQUIRE_EQUAL((Datatype) info.Type(0), Datatype::numeric);
BOOST_REQUIRE_EQUAL((Datatype) info.Type(1), Datatype::numeric);
@@ -998,6 +997,100 @@ BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest)
BOOST_REQUIRE_EQUAL(info.UnmapString(1, 2), "goodbye");
BOOST_REQUIRE_EQUAL(info.UnmapString(2, 2), "coffee");
BOOST_REQUIRE_EQUAL(info.UnmapString(3, 2), "confusion");
+
+ remove("test.csv");
+}
+
+BOOST_AUTO_TEST_CASE(CategoricalNontransposedCSVLoadTest)
+{
+ fstream f;
+ f.open("test.csv", fstream::out); // Write mode creates the file.
+ f << "1, 2, hello" << endl;
+ f << "3, 4, goodbye" << endl;
+ f << "5, 6, coffee" << endl;
+ f << "7, 8, confusion" << endl;
+ f << "9, 10, hello" << endl;
+ f << "11, 12, 15" << endl;
+ f << "13, 14, confusion" << endl;
+ f.close();
+
+ // Load the test CSV without transposing: each line becomes one dimension.
+ arma::umat matrix;
+ DatasetInfo info;
+ data::Load("test.csv", matrix, info, true, false); // No transpose.
+
+ BOOST_REQUIRE_EQUAL(matrix.n_cols, 3); // One column per CSV field.
+ BOOST_REQUIRE_EQUAL(matrix.n_rows, 7); // One row per CSV line.
+
+ BOOST_REQUIRE_EQUAL(matrix(0, 0), 0); // "1, 2, hello" -> codes 0, 1, 2.
+ BOOST_REQUIRE_EQUAL(matrix(0, 1), 1);
+ BOOST_REQUIRE_EQUAL(matrix(0, 2), 2);
+ BOOST_REQUIRE_EQUAL(matrix(1, 0), 0);
+ BOOST_REQUIRE_EQUAL(matrix(1, 1), 1);
+ BOOST_REQUIRE_EQUAL(matrix(1, 2), 2);
+ BOOST_REQUIRE_EQUAL(matrix(2, 0), 0);
+ BOOST_REQUIRE_EQUAL(matrix(2, 1), 1);
+ BOOST_REQUIRE_EQUAL(matrix(2, 2), 2);
+ BOOST_REQUIRE_EQUAL(matrix(3, 0), 0);
+ BOOST_REQUIRE_EQUAL(matrix(3, 1), 1);
+ BOOST_REQUIRE_EQUAL(matrix(3, 2), 2);
+ BOOST_REQUIRE_EQUAL(matrix(4, 0), 0);
+ BOOST_REQUIRE_EQUAL(matrix(4, 1), 1);
+ BOOST_REQUIRE_EQUAL(matrix(4, 2), 2);
+ BOOST_REQUIRE_EQUAL(matrix(5, 0), 11); // Numeric row: raw values.
+ BOOST_REQUIRE_EQUAL(matrix(5, 1), 12);
+ BOOST_REQUIRE_EQUAL(matrix(5, 2), 15);
+ BOOST_REQUIRE_EQUAL(matrix(6, 0), 0);
+ BOOST_REQUIRE_EQUAL(matrix(6, 1), 1);
+ BOOST_REQUIRE_EQUAL(matrix(6, 2), 2);
+
+ BOOST_REQUIRE_EQUAL((Datatype) info.Type(0), Datatype::categorical);
+ BOOST_REQUIRE_EQUAL((Datatype) info.Type(1), Datatype::categorical);
+ BOOST_REQUIRE_EQUAL((Datatype) info.Type(2), Datatype::categorical);
+ BOOST_REQUIRE_EQUAL((Datatype) info.Type(3), Datatype::categorical);
+ BOOST_REQUIRE_EQUAL((Datatype) info.Type(4), Datatype::categorical);
+ BOOST_REQUIRE_EQUAL((Datatype) info.Type(5), Datatype::numeric);
+ BOOST_REQUIRE_EQUAL((Datatype) info.Type(6), Datatype::categorical);
+
+ BOOST_REQUIRE_EQUAL(info.MapString("1", 0), 0); // (string, dimension)
+ BOOST_REQUIRE_EQUAL(info.MapString("2", 0), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString("hello", 0), 2);
+ BOOST_REQUIRE_EQUAL(info.MapString("3", 1), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString("4", 1), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString("goodbye", 1), 2);
+ BOOST_REQUIRE_EQUAL(info.MapString("5", 2), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString("6", 2), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString("coffee", 2), 2);
+ BOOST_REQUIRE_EQUAL(info.MapString("7", 3), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString("8", 3), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString("confusion", 3), 2);
+ BOOST_REQUIRE_EQUAL(info.MapString("9", 4), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString("10", 4), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString("hello", 4), 2);
+ BOOST_REQUIRE_EQUAL(info.MapString("13", 6), 0); // Dim 5 is numeric.
+ BOOST_REQUIRE_EQUAL(info.MapString("14", 6), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString("confusion", 6), 2);
+
+ BOOST_REQUIRE_EQUAL(info.UnmapString(0, 0), "1"); // (value, dimension)
+ BOOST_REQUIRE_EQUAL(info.UnmapString(1, 0), "2");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(2, 0), "hello");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(0, 1), "3");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(1, 1), "4");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(2, 1), "goodbye");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(0, 2), "5");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(1, 2), "6");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(2, 2), "coffee");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(0, 3), "7");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(1, 3), "8");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(2, 3), "confusion");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(0, 4), "9");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(1, 4), "10");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(2, 4), "hello");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(0, 6), "13");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(1, 6), "14");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(2, 6), "confusion");
+
+ remove("test.csv"); // NOTE(review): skipped if a REQUIRE above fails.
}
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list