[mlpack-git] master: Add Load() overload with DatasetInfo. (15faffe)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:41:42 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 15faffe7a369bc78000f63bdc53e3137f81c2a18
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Sep 10 19:13:15 2015 +0000
Add Load() overload with DatasetInfo.
Partially tested -- I am moving development systems, so I need to check it in.
>---------------------------------------------------------------
15faffe7a369bc78000f63bdc53e3137f81c2a18
src/mlpack/core.hpp | 1 -
src/mlpack/core/data/load.hpp | 1 +
src/mlpack/core/data/load_impl.hpp | 161 +++++++++++++++++++++++++-
src/mlpack/tests/load_save_test.cpp | 224 +++++++++++++++++++++++++++++-------
4 files changed, 339 insertions(+), 48 deletions(-)
diff --git a/src/mlpack/core.hpp b/src/mlpack/core.hpp
index df7657b..46dc1ec 100644
--- a/src/mlpack/core.hpp
+++ b/src/mlpack/core.hpp
@@ -159,7 +159,6 @@
#include <mlpack/core/data/load.hpp>
#include <mlpack/core/data/save.hpp>
#include <mlpack/core/data/normalize_labels.hpp>
-#include <mlpack/core/data/dataset_info.hpp>
#include <mlpack/core/math/clamp.hpp>
#include <mlpack/core/math/random.hpp>
#include <mlpack/core/math/lin_alg.hpp>
diff --git a/src/mlpack/core/data/load.hpp b/src/mlpack/core/data/load.hpp
index 950fd5a..3e04e6a 100644
--- a/src/mlpack/core/data/load.hpp
+++ b/src/mlpack/core/data/load.hpp
@@ -14,6 +14,7 @@
#include <string>
#include "format.hpp"
+#include "dataset_info.hpp"
namespace mlpack {
namespace data /** Functions to load and save matrices and models. */ {
diff --git a/src/mlpack/core/data/load_impl.hpp b/src/mlpack/core/data/load_impl.hpp
index 307a886..85badd4 100644
--- a/src/mlpack/core/data/load_impl.hpp
+++ b/src/mlpack/core/data/load_impl.hpp
@@ -18,6 +18,7 @@
#include <boost/archive/xml_iarchive.hpp>
#include <boost/archive/text_iarchive.hpp>
#include <boost/archive/binary_iarchive.hpp>
+#include <boost/tokenizer.hpp>
#include "serialization_shim.hpp"
@@ -293,20 +294,168 @@ bool Load(const std::string& filename,
return false;
}
- bool unknownType = false;
- arma::file_type loadType;
- std::string stringType;
-
- if (extension == "csv" || extension == "tsv")
+ if (extension == "csv" || extension == "tsv" || extension == "txt")
{
+ // True if we're looking for commas; if false, we're looking for spaces.
+ bool commas = (extension == "csv");
+
+ std::string type;
+ if (extension == "csv")
+ type = "CSV data";
+ else
+ type = "raw ASCII-formatted data";
+
+ Log::Info << "Loading '" << filename << "' as " << type << ". "
+ << std::flush;
+ std::string separators;
+ if (commas)
+ separators = ",";
+ else
+ separators = " \t";
+
+ // We'll load this as CSV (or CSV with spaces or tabs) according to
+ // RFC4180. So the first thing to do is determine the size of the matrix.
+ std::string buffer;
+ size_t cols = 0;
+
+ std::getline(stream, buffer, '\n');
+ // Count commas and whitespace in the line, ignoring anything inside
+ // quotes.
+ typedef boost::tokenizer<boost::escaped_list_separator<char>> Tokenizer;
+ boost::escaped_list_separator<char> sep("\\", separators, "\"");
+ Tokenizer tok(buffer, sep);
+ for (Tokenizer::iterator i = tok.begin(); i != tok.end(); ++i)
+ ++cols;
+
+ // Now count the number of lines in the file. We've already counted the
+ // first one.
+ size_t rows = 1;
+ stream.unsetf(std::ios_base::skipws);
+ rows += std::count(std::istream_iterator<char>(stream),
+ std::istream_iterator<char>(), '\n');
+
+ // Back up to see if the last character in the file is an empty line.
+ stream.unget();
+ std::cout << "last character is " << int(stream.peek()) << ".\n";
+ while (isspace(stream.peek()))
+ {
+ if (stream.peek() == '\n')
+ {
+ --rows;
+ break;
+ }
+ stream.unget();
+ }
+
+ // Now we have the size. So resize our matrix.
+ if (transpose)
+ matrix.set_size(cols, rows);
+ else
+ matrix.set_size(rows, cols);
+
+ stream.close();
+ stream.open(filename, std::fstream::in);
+ // Extract line by line.
+ std::stringstream token;
+ size_t row = 0;
+ while (!stream.bad() && !stream.fail() && !stream.eof())
+ {
+ std::getline(stream, buffer, '\n');
+
+ // Look at each token. Unfortunately we have to do this character by
+ // character, because things may be escaped in quotes.
+ Tokenizer lineTok(buffer, sep);
+ size_t col = 0;
+ for (Tokenizer::iterator it = lineTok.begin(); it != lineTok.end(); ++it)
+ {
+ // Attempt to extract as type eT. If that fails, we'll assume it's a
+ // string and map it (which may involve retroactively mapping everything
+ // we've seen so far).
+ token.clear();
+ token.str(*it);
+
+ eT val = eT(0);
+ token >> val;
+
+ if (token.fail())
+ {
+ std::cout << "conversion failed\n";
+ // Conversion failed; but it may be a NaN or inf. Armadillo has
+ // convenient functions to check.
+ if (!arma::diskio::convert_naninf(val, token.str()))
+ {
+ // We need to perform a mapping.
+ const size_t dim = (transpose) ? col : row;
+ if (info.Type(dim) == Datatype::numeric)
+ {
+ // We must map everything we have seen up to this point and change
+ // the values in the matrix.
+ if (transpose)
+ {
+ // Whatever we've seen so far has successfully mapped to an eT.
+ // So we need to print it back to a string. We'll use
+ // Armadillo's functionality for that.
+ for (size_t i = 0; i < row; ++i)
+ {
+ std::stringstream sstr;
+ arma::arma_ostream::print_elem(sstr, matrix.at(i, col),
+ false);
+ eT newVal = info.MapString(sstr.str(), col);
+ matrix.at(i, col) = newVal;
+ }
+ }
+ else
+ {
+ for (size_t i = 0; i < col; ++i)
+ {
+ std::stringstream sstr;
+ arma::arma_ostream::print_elem(sstr, matrix.at(row, i),
+ false);
+ eT newVal = info.MapString(sstr.str(), row);
+ matrix.at(row, i) = newVal;
+ }
+ }
+ }
+
+ val = info.MapString(token.str(), dim);
+ }
+ }
+
+ if (transpose)
+ matrix(col, row) = val;
+ else
+ matrix(row, col) = val;
+
+ ++col;
+ }
+
+ ++row;
+ }
+
+ if (stream.bad() || stream.fail())
+ Log::Warn << "Failure reading file '" << filename << "'." << std::endl;
}
- else if (extension == "txt")
+ else
{
+ // The type is unknown.
+ Timer::Stop("loading_data");
+ if (fatal)
+ Log::Fatal << "Unable to detect type of '" << filename << "'; "
+ << "incorrect extension?" << std::endl;
+ else
+ Log::Warn << "Unable to detect type of '" << filename << "'; load failed."
+ << " Incorrect extension?" << std::endl;
+ return false;
}
+ Log::Info << "Size is " << (transpose ? matrix.n_cols : matrix.n_rows)
+ << " x " << (transpose ? matrix.n_rows : matrix.n_cols) << ".\n";
+
Timer::Stop("loading_data");
+
+ return true;
}
// Load a model from file.
diff --git a/src/mlpack/tests/load_save_test.cpp b/src/mlpack/tests/load_save_test.cpp
index 27121ff..4047170 100644
--- a/src/mlpack/tests/load_save_test.cpp
+++ b/src/mlpack/tests/load_save_test.cpp
@@ -13,6 +13,7 @@
using namespace mlpack;
using namespace mlpack::data;
+using namespace std;
BOOST_AUTO_TEST_SUITE(LoadSaveTest);
@@ -54,11 +55,11 @@ BOOST_AUTO_TEST_CASE(NotExistLoad)
*/
BOOST_AUTO_TEST_CASE(LoadCSVTest)
{
- std::fstream f;
- f.open("test_file.csv", std::fstream::out);
+ fstream f;
+ f.open("test_file.csv", fstream::out);
- f << "1, 2, 3, 4" << std::endl;
- f << "5, 6, 7, 8" << std::endl;
+ f << "1, 2, 3, 4" << endl;
+ f << "5, 6, 7, 8" << endl;
f.close();
@@ -80,11 +81,11 @@ BOOST_AUTO_TEST_CASE(LoadCSVTest)
*/
BOOST_AUTO_TEST_CASE(LoadTSVTest)
{
- std::fstream f;
- f.open("test_file.csv", std::fstream::out);
+ fstream f;
+ f.open("test_file.csv", fstream::out);
- f << "1\t2\t3\t4" << std::endl;
- f << "5\t6\t7\t8" << std::endl;
+ f << "1\t2\t3\t4" << endl;
+ f << "5\t6\t7\t8" << endl;
f.close();
@@ -106,11 +107,11 @@ BOOST_AUTO_TEST_CASE(LoadTSVTest)
*/
BOOST_AUTO_TEST_CASE(LoadTSVExtensionTest)
{
- std::fstream f;
- f.open("test_file.tsv", std::fstream::out);
+ fstream f;
+ f.open("test_file.tsv", fstream::out);
- f << "1\t2\t3\t4" << std::endl;
- f << "5\t6\t7\t8" << std::endl;
+ f << "1\t2\t3\t4" << endl;
+ f << "5\t6\t7\t8" << endl;
f.close();
@@ -158,11 +159,11 @@ BOOST_AUTO_TEST_CASE(SaveCSVTest)
*/
BOOST_AUTO_TEST_CASE(LoadTransposedCSVTest)
{
- std::fstream f;
- f.open("test_file.csv", std::fstream::out);
+ fstream f;
+ f.open("test_file.csv", fstream::out);
- f << "1, 2, 3, 4" << std::endl;
- f << "5, 6, 7, 8" << std::endl;
+ f << "1, 2, 3, 4" << endl;
+ f << "5, 6, 7, 8" << endl;
f.close();
@@ -184,11 +185,11 @@ BOOST_AUTO_TEST_CASE(LoadTransposedCSVTest)
*/
BOOST_AUTO_TEST_CASE(LoadTransposedTSVTest)
{
- std::fstream f;
- f.open("test_file.csv", std::fstream::out);
+ fstream f;
+ f.open("test_file.csv", fstream::out);
- f << "1\t2\t3\t4" << std::endl;
- f << "5\t6\t7\t8" << std::endl;
+ f << "1\t2\t3\t4" << endl;
+ f << "5\t6\t7\t8" << endl;
f.close();
@@ -210,11 +211,11 @@ BOOST_AUTO_TEST_CASE(LoadTransposedTSVTest)
*/
BOOST_AUTO_TEST_CASE(LoadTransposedTSVExtensionTest)
{
- std::fstream f;
- f.open("test_file.tsv", std::fstream::out);
+ fstream f;
+ f.open("test_file.tsv", fstream::out);
- f << "1\t2\t3\t4" << std::endl;
- f << "5\t6\t7\t8" << std::endl;
+ f << "1\t2\t3\t4" << endl;
+ f << "5\t6\t7\t8" << endl;
f.close();
@@ -236,11 +237,11 @@ BOOST_AUTO_TEST_CASE(LoadTransposedTSVExtensionTest)
*/
BOOST_AUTO_TEST_CASE(LoadNonTransposedCSVTest)
{
- std::fstream f;
- f.open("test_file.csv", std::fstream::out);
+ fstream f;
+ f.open("test_file.csv", fstream::out);
- f << "1, 3, 5, 7" << std::endl;
- f << "2, 4, 6, 8" << std::endl;
+ f << "1, 3, 5, 7" << endl;
+ f << "2, 4, 6, 8" << endl;
f.close();
@@ -338,11 +339,11 @@ BOOST_AUTO_TEST_CASE(SaveArmaASCIITest)
*/
BOOST_AUTO_TEST_CASE(LoadRawASCIITest)
{
- std::fstream f;
- f.open("test_file.txt", std::fstream::out);
+ fstream f;
+ f.open("test_file.txt", fstream::out);
- f << "1 2 3 4" << std::endl;
- f << "5 6 7 8" << std::endl;
+ f << "1 2 3 4" << endl;
+ f << "5 6 7 8" << endl;
f.close();
@@ -364,11 +365,11 @@ BOOST_AUTO_TEST_CASE(LoadRawASCIITest)
*/
BOOST_AUTO_TEST_CASE(LoadCSVTxtTest)
{
- std::fstream f;
- f.open("test_file.txt", std::fstream::out);
+ fstream f;
+ f.open("test_file.txt", fstream::out);
- f << "1, 2, 3, 4" << std::endl;
- f << "5, 6, 7, 8" << std::endl;
+ f << "1, 2, 3, 4" << endl;
+ f << "5, 6, 7, 8" << endl;
f.close();
@@ -709,7 +710,7 @@ BOOST_AUTO_TEST_CASE(NormalizeLabelTest)
class TestInner
{
public:
- TestInner(char c, std::string s) : c(c), s(s) { }
+ TestInner(char c, string s) : c(c), s(s) { }
template<typename Archive>
void Serialize(Archive& ar, const unsigned int /* version */)
@@ -720,7 +721,7 @@ class TestInner
// Public members for testing.
char c;
- std::string s;
+ string s;
};
class Test
@@ -849,13 +850,154 @@ BOOST_AUTO_TEST_CASE(DatasetInfoTest)
}
// Get the mappings back.
- const std::string& strFirst = di.UnmapString(first, 3);
- const std::string& strSecond = di.UnmapString(second, 3);
- const std::string& strThird = di.UnmapString(third, 3);
+ const string& strFirst = di.UnmapString(first, 3);
+ const string& strSecond = di.UnmapString(second, 3);
+ const string& strThird = di.UnmapString(third, 3);
BOOST_REQUIRE_EQUAL(strFirst, "test_mapping_1");
BOOST_REQUIRE_EQUAL(strSecond, "test_mapping_2");
BOOST_REQUIRE_EQUAL(strThird, "test_mapping_3");
}
+/**
+ * Test loading regular CSV with DatasetInfo: the matrix must match the one from plain Load(), and everything should be numeric.
+ */
+BOOST_AUTO_TEST_CASE(RegularCSVDatasetInfoLoad)
+{
+ vector<string> testFiles;
+ //testFiles.push_back("fake.csv"); // Disabled in this commit; reason not stated.
+ //testFiles.push_back("german.csv"); // Disabled in this commit; reason not stated.
+ testFiles.push_back("iris.txt");
+ testFiles.push_back("vc2.txt");
+ testFiles.push_back("johnson-8-4-4.csv");
+ testFiles.push_back("lars_dependent_y.csv");
+ testFiles.push_back("vc2_test_labels.txt");
+
+ for (size_t i = 0; i < testFiles.size(); ++i)
+ {
+ arma::mat one, two;
+ DatasetInfo info;
+ data::Load(testFiles[i], one); // Reference result: the overload without DatasetInfo.
+ data::Load(testFiles[i], two, info); // Result under test: the DatasetInfo overload.
+
+ // Check that the matrices contain the same information (same shape, then element by element).
+ BOOST_REQUIRE_EQUAL(one.n_elem, two.n_elem);
+ BOOST_REQUIRE_EQUAL(one.n_rows, two.n_rows);
+ BOOST_REQUIRE_EQUAL(one.n_cols, two.n_cols);
+ for (size_t i = 0; i < one.n_elem; ++i) // NOTE(review): shadows the outer loop's i (the file index).
+ {
+ std::cout << "i " << i << ": one " << one[i] << " two " << two[i] << // NOTE(review): looks like leftover debug output -- confirm.
+".\n";
+ if (std::abs(one[i]) < 1e-8) // Near-zero reference values get an absolute check instead of a relative one.
+ BOOST_REQUIRE_SMALL(two[i], 1e-8);
+ else
+ BOOST_REQUIRE_CLOSE(one[i], two[i], 1e-8); // Boost.Test takes this tolerance as a percentage.
+ }
+
+ // Check that all dimensions are numeric (one dimension per row of the loaded matrix).
+ for (size_t i = 0; i < two.n_rows; ++i)
+ BOOST_REQUIRE_EQUAL((Datatype) info.Type(i), Datatype::numeric);
+ }
+}
+
+/**
+ * Test non-transposed loading of regular CSVs with DatasetInfo. The matrix
+ * must match the one from plain Load(), and everything should be numeric.
+ */
+BOOST_AUTO_TEST_CASE(NontransposedCSVDatasetInfoLoad)
+{
+ vector<string> testFiles;
+ testFiles.push_back("fake.csv");
+ testFiles.push_back("german.csv");
+ testFiles.push_back("iris.txt");
+ testFiles.push_back("vc2.txt");
+ testFiles.push_back("johnson-8-4-4.csv");
+ testFiles.push_back("lars_dependent_y.csv");
+ testFiles.push_back("vc2_test_labels.txt");
+
+ for (size_t i = 0; i < testFiles.size(); ++i)
+ {
+ arma::mat one, two;
+ DatasetInfo info;
+ data::Load(testFiles[i], one, true, false); // No transpose.
+ data::Load(testFiles[i], two, info, true, false); // Same trailing flags as the call above.
+
+ // Check that the matrices contain the same information (same shape, then element by element).
+ BOOST_REQUIRE_EQUAL(one.n_elem, two.n_elem);
+ BOOST_REQUIRE_EQUAL(one.n_rows, two.n_rows);
+ BOOST_REQUIRE_EQUAL(one.n_cols, two.n_cols);
+ for (size_t i = 0; i < one.n_elem; ++i) // NOTE(review): shadows the outer loop's i (the file index).
+ {
+ if (std::abs(one[i]) < 1e-8) // Near-zero reference values get an absolute check instead of a relative one.
+ BOOST_REQUIRE_SMALL(two[i], 1e-8);
+ else
+ BOOST_REQUIRE_CLOSE(one[i], two[i], 1e-8); // Boost.Test takes this tolerance as a percentage.
+ }
+
+ // Check that all dimensions are numeric (one dimension per row of the loaded matrix).
+ for (size_t i = 0; i < two.n_rows; ++i)
+ BOOST_REQUIRE_EQUAL((Datatype) info.Type(i), Datatype::numeric);
+ }
+}
+
+/**
+ * Create a file with a categorical string feature, then load it with a DatasetInfo and check the mapped values.
+ */
+BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest)
+{
+ fstream f("test.csv"); // NOTE(review): default mode is in|out, which fails to open a nonexistent file; fstream::out may be intended -- confirm.
+ f << "1, 2, hello" << endl;
+ f << "3, 4, goodbye" << endl;
+ f << "5, 6, coffee" << endl;
+ f << "7, 8, confusion" << endl;
+ f << "9, 10, hello" << endl;
+ f << "11, 12, confusion" << endl;
+ f << "13, 14, confusion" << endl;
+ f.close();
+
+ // Load the test CSV; the third field of every line is a string, not a number.
+ arma::umat matrix;
+ DatasetInfo info;
+ data::Load("test.csv", matrix, info);
+
+ BOOST_REQUIRE_EQUAL(matrix.n_cols, 7);
+ BOOST_REQUIRE_EQUAL(matrix.n_rows, 3); // NOTE(review): shape is asserted as 3 x 7, yet the checks below index up to (6, 2) -- indices look swapped; confirm.
+
+ BOOST_REQUIRE_EQUAL(matrix(0, 0), 1);
+ BOOST_REQUIRE_EQUAL(matrix(0, 1), 2);
+ BOOST_REQUIRE_EQUAL(matrix(0, 2), 0); // "hello"
+ BOOST_REQUIRE_EQUAL(matrix(1, 0), 3);
+ BOOST_REQUIRE_EQUAL(matrix(1, 1), 4);
+ BOOST_REQUIRE_EQUAL(matrix(1, 2), 1); // "goodbye"
+ BOOST_REQUIRE_EQUAL(matrix(2, 0), 5);
+ BOOST_REQUIRE_EQUAL(matrix(2, 1), 6);
+ BOOST_REQUIRE_EQUAL(matrix(2, 2), 2); // "coffee"
+ BOOST_REQUIRE_EQUAL(matrix(3, 0), 7);
+ BOOST_REQUIRE_EQUAL(matrix(3, 1), 8);
+ BOOST_REQUIRE_EQUAL(matrix(3, 2), 3); // "confusion"
+ BOOST_REQUIRE_EQUAL(matrix(4, 0), 9);
+ BOOST_REQUIRE_EQUAL(matrix(4, 1), 10);
+ BOOST_REQUIRE_EQUAL(matrix(4, 2), 0); // "hello" again, so the same value as the first line.
+ BOOST_REQUIRE_EQUAL(matrix(5, 0), 11);
+ BOOST_REQUIRE_EQUAL(matrix(5, 1), 12);
+ BOOST_REQUIRE_EQUAL(matrix(5, 2), 3); // "confusion"
+ BOOST_REQUIRE_EQUAL(matrix(6, 0), 13);
+ BOOST_REQUIRE_EQUAL(matrix(6, 1), 14);
+ BOOST_REQUIRE_EQUAL(matrix(6, 2), 3); // "confusion"
+
+ BOOST_REQUIRE_EQUAL((Datatype) info.Type(0), Datatype::numeric);
+ BOOST_REQUIRE_EQUAL((Datatype) info.Type(1), Datatype::numeric);
+ BOOST_REQUIRE_EQUAL((Datatype) info.Type(2), Datatype::categorical);
+
+ BOOST_REQUIRE_EQUAL(info.MapString("hello", 2), 0); // Expected values follow order of first appearance in the file.
+ BOOST_REQUIRE_EQUAL(info.MapString("goodbye", 2), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString("coffee", 2), 2);
+ BOOST_REQUIRE_EQUAL(info.MapString("confusion", 2), 3);
+
+ BOOST_REQUIRE_EQUAL(info.UnmapString(0, 2), "hello"); // The reverse mapping must round-trip.
+ BOOST_REQUIRE_EQUAL(info.UnmapString(1, 2), "goodbye");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(2, 2), "coffee");
+ BOOST_REQUIRE_EQUAL(info.UnmapString(3, 2), "confusion");
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list