[mlpack-git] master: Flesh out tests for DatasetInfo Load() functions. (ae7f38a)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:41:47 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit ae7f38a1fd690cc649006010db89f2d38d37b8f0
Author: ryan <ryan at ratml.org>
Date:   Thu Sep 10 17:08:20 2015 -0400

    Flesh out tests for DatasetInfo Load() functions.


>---------------------------------------------------------------

ae7f38a1fd690cc649006010db89f2d38d37b8f0
 src/mlpack/tests/load_save_test.cpp | 151 +++++++++++++++++++++++++++++-------
 1 file changed, 122 insertions(+), 29 deletions(-)

diff --git a/src/mlpack/tests/load_save_test.cpp b/src/mlpack/tests/load_save_test.cpp
index 4047170..9a97b72 100644
--- a/src/mlpack/tests/load_save_test.cpp
+++ b/src/mlpack/tests/load_save_test.cpp
@@ -865,11 +865,11 @@ BOOST_AUTO_TEST_CASE(DatasetInfoTest)
 BOOST_AUTO_TEST_CASE(RegularCSVDatasetInfoLoad)
 {
   vector<string> testFiles;
-  //testFiles.push_back("fake.csv");
-  //testFiles.push_back("german.csv");
-  testFiles.push_back("iris.txt");
-  testFiles.push_back("vc2.txt");
-  testFiles.push_back("johnson-8-4-4.csv");
+  testFiles.push_back("fake.csv");
+  testFiles.push_back("german.csv");
+  testFiles.push_back("iris.csv");
+  testFiles.push_back("vc2.csv");
+  testFiles.push_back("johnson8-4-4.csv");
   testFiles.push_back("lars_dependent_y.csv");
   testFiles.push_back("vc2_test_labels.txt");
 
@@ -886,8 +886,6 @@ BOOST_AUTO_TEST_CASE(RegularCSVDatasetInfoLoad)
     BOOST_REQUIRE_EQUAL(one.n_cols, two.n_cols);
     for (size_t i = 0; i < one.n_elem; ++i)
     {
-      std::cout << "i " << i << ": one " << one[i] << " two " << two[i] <<
-".\n";
       if (std::abs(one[i]) < 1e-8)
         BOOST_REQUIRE_SMALL(two[i], 1e-8);
       else
@@ -909,9 +907,9 @@ BOOST_AUTO_TEST_CASE(NontransposedCSVDatasetInfoLoad)
   vector<string> testFiles;
   testFiles.push_back("fake.csv");
   testFiles.push_back("german.csv");
-  testFiles.push_back("iris.txt");
-  testFiles.push_back("vc2.txt");
-  testFiles.push_back("johnson-8-4-4.csv");
+  testFiles.push_back("iris.csv");
+  testFiles.push_back("vc2.csv");
+  testFiles.push_back("johnson8-4-4.csv");
   testFiles.push_back("lars_dependent_y.csv");
   testFiles.push_back("vc2_test_labels.txt");
 
@@ -945,7 +943,8 @@ BOOST_AUTO_TEST_CASE(NontransposedCSVDatasetInfoLoad)
  */
 BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest)
 {
-  fstream f("test.csv");
+  fstream f;
+  f.open("test.csv", fstream::out);
   f << "1, 2, hello" << endl;
   f << "3, 4, goodbye" << endl;
   f << "5, 6, coffee" << endl;
@@ -964,26 +963,26 @@ BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest)
   BOOST_REQUIRE_EQUAL(matrix.n_rows, 3);
 
   BOOST_REQUIRE_EQUAL(matrix(0, 0), 1);
-  BOOST_REQUIRE_EQUAL(matrix(0, 1), 2);
-  BOOST_REQUIRE_EQUAL(matrix(0, 2), 0);
-  BOOST_REQUIRE_EQUAL(matrix(1, 0), 3);
+  BOOST_REQUIRE_EQUAL(matrix(1, 0), 2);
+  BOOST_REQUIRE_EQUAL(matrix(2, 0), 0);
+  BOOST_REQUIRE_EQUAL(matrix(0, 1), 3);
   BOOST_REQUIRE_EQUAL(matrix(1, 1), 4);
-  BOOST_REQUIRE_EQUAL(matrix(1, 2), 1);
-  BOOST_REQUIRE_EQUAL(matrix(2, 0), 5);
-  BOOST_REQUIRE_EQUAL(matrix(2, 1), 6);
+  BOOST_REQUIRE_EQUAL(matrix(2, 1), 1);
+  BOOST_REQUIRE_EQUAL(matrix(0, 2), 5);
+  BOOST_REQUIRE_EQUAL(matrix(1, 2), 6);
   BOOST_REQUIRE_EQUAL(matrix(2, 2), 2);
-  BOOST_REQUIRE_EQUAL(matrix(3, 0), 7);
-  BOOST_REQUIRE_EQUAL(matrix(3, 1), 8);
-  BOOST_REQUIRE_EQUAL(matrix(3, 2), 3);
-  BOOST_REQUIRE_EQUAL(matrix(4, 0), 9);
-  BOOST_REQUIRE_EQUAL(matrix(4, 1), 10);
-  BOOST_REQUIRE_EQUAL(matrix(4, 2), 0);
-  BOOST_REQUIRE_EQUAL(matrix(5, 0), 11);
-  BOOST_REQUIRE_EQUAL(matrix(5, 1), 12);
-  BOOST_REQUIRE_EQUAL(matrix(5, 2), 3);
-  BOOST_REQUIRE_EQUAL(matrix(6, 0), 13);
-  BOOST_REQUIRE_EQUAL(matrix(6, 1), 14);
-  BOOST_REQUIRE_EQUAL(matrix(6, 2), 3);
+  BOOST_REQUIRE_EQUAL(matrix(0, 3), 7);
+  BOOST_REQUIRE_EQUAL(matrix(1, 3), 8);
+  BOOST_REQUIRE_EQUAL(matrix(2, 3), 3);
+  BOOST_REQUIRE_EQUAL(matrix(0, 4), 9);
+  BOOST_REQUIRE_EQUAL(matrix(1, 4), 10);
+  BOOST_REQUIRE_EQUAL(matrix(2, 4), 0);
+  BOOST_REQUIRE_EQUAL(matrix(0, 5), 11);
+  BOOST_REQUIRE_EQUAL(matrix(1, 5), 12);
+  BOOST_REQUIRE_EQUAL(matrix(2, 5), 3);
+  BOOST_REQUIRE_EQUAL(matrix(0, 6), 13);
+  BOOST_REQUIRE_EQUAL(matrix(1, 6), 14);
+  BOOST_REQUIRE_EQUAL(matrix(2, 6), 3);
 
   BOOST_REQUIRE_EQUAL((Datatype) info.Type(0), Datatype::numeric);
   BOOST_REQUIRE_EQUAL((Datatype) info.Type(1), Datatype::numeric);
@@ -998,6 +997,100 @@ BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest)
   BOOST_REQUIRE_EQUAL(info.UnmapString(1, 2), "goodbye");
   BOOST_REQUIRE_EQUAL(info.UnmapString(2, 2), "coffee");
   BOOST_REQUIRE_EQUAL(info.UnmapString(3, 2), "confusion");
+
+  remove("test.csv");
+}
+
+BOOST_AUTO_TEST_CASE(CategoricalNontransposedCSVLoadTest)
+{
+  fstream f;
+  f.open("test.csv", fstream::out);
+  f << "1, 2, hello" << endl;
+  f << "3, 4, goodbye" << endl;
+  f << "5, 6, coffee" << endl;
+  f << "7, 8, confusion" << endl;
+  f << "9, 10, hello" << endl;
+  f << "11, 12, 15" << endl;
+  f << "13, 14, confusion" << endl;
+  f.close();
+
+  // Load the test CSV.
+  arma::umat matrix;
+  DatasetInfo info;
+  data::Load("test.csv", matrix, info, true, false); // No transpose.
+
+  BOOST_REQUIRE_EQUAL(matrix.n_cols, 3);
+  BOOST_REQUIRE_EQUAL(matrix.n_rows, 7);
+
+  BOOST_REQUIRE_EQUAL(matrix(0, 0), 0);
+  BOOST_REQUIRE_EQUAL(matrix(0, 1), 1);
+  BOOST_REQUIRE_EQUAL(matrix(0, 2), 2);
+  BOOST_REQUIRE_EQUAL(matrix(1, 0), 0);
+  BOOST_REQUIRE_EQUAL(matrix(1, 1), 1);
+  BOOST_REQUIRE_EQUAL(matrix(1, 2), 2);
+  BOOST_REQUIRE_EQUAL(matrix(2, 0), 0);
+  BOOST_REQUIRE_EQUAL(matrix(2, 1), 1);
+  BOOST_REQUIRE_EQUAL(matrix(2, 2), 2);
+  BOOST_REQUIRE_EQUAL(matrix(3, 0), 0);
+  BOOST_REQUIRE_EQUAL(matrix(3, 1), 1);
+  BOOST_REQUIRE_EQUAL(matrix(3, 2), 2);
+  BOOST_REQUIRE_EQUAL(matrix(4, 0), 0);
+  BOOST_REQUIRE_EQUAL(matrix(4, 1), 1);
+  BOOST_REQUIRE_EQUAL(matrix(4, 2), 2);
+  BOOST_REQUIRE_EQUAL(matrix(5, 0), 11);
+  BOOST_REQUIRE_EQUAL(matrix(5, 1), 12);
+  BOOST_REQUIRE_EQUAL(matrix(5, 2), 15);
+  BOOST_REQUIRE_EQUAL(matrix(6, 0), 0);
+  BOOST_REQUIRE_EQUAL(matrix(6, 1), 1);
+  BOOST_REQUIRE_EQUAL(matrix(6, 2), 2);
+
+  BOOST_REQUIRE_EQUAL((Datatype) info.Type(0), Datatype::categorical);
+  BOOST_REQUIRE_EQUAL((Datatype) info.Type(1), Datatype::categorical);
+  BOOST_REQUIRE_EQUAL((Datatype) info.Type(2), Datatype::categorical);
+  BOOST_REQUIRE_EQUAL((Datatype) info.Type(3), Datatype::categorical);
+  BOOST_REQUIRE_EQUAL((Datatype) info.Type(4), Datatype::categorical);
+  BOOST_REQUIRE_EQUAL((Datatype) info.Type(5), Datatype::numeric);
+  BOOST_REQUIRE_EQUAL((Datatype) info.Type(6), Datatype::categorical);
+
+  BOOST_REQUIRE_EQUAL(info.MapString("1", 0), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString("2", 0), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString("hello", 0), 2);
+  BOOST_REQUIRE_EQUAL(info.MapString("3", 1), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString("4", 1), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString("goodbye", 1), 2);
+  BOOST_REQUIRE_EQUAL(info.MapString("5", 2), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString("6", 2), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString("coffee", 2), 2);
+  BOOST_REQUIRE_EQUAL(info.MapString("7", 3), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString("8", 3), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString("confusion", 3), 2);
+  BOOST_REQUIRE_EQUAL(info.MapString("9", 4), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString("10", 4), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString("hello", 4), 2);
+  BOOST_REQUIRE_EQUAL(info.MapString("13", 6), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString("14", 6), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString("confusion", 6), 2);
+
+  BOOST_REQUIRE_EQUAL(info.UnmapString(0, 0), "1");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(1, 0), "2");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(2, 0), "hello");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(0, 1), "3");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(1, 1), "4");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(2, 1), "goodbye");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(0, 2), "5");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(1, 2), "6");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(2, 2), "coffee");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(0, 3), "7");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(1, 3), "8");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(2, 3), "confusion");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(0, 4), "9");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(1, 4), "10");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(2, 4), "hello");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(0, 6), "13");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(1, 6), "14");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(2, 6), "confusion");
+
+  remove("test.csv");
 }
 
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list