[mlpack-git] master: Add Load() overload with DatasetInfo. (15faffe)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:41:42 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 15faffe7a369bc78000f63bdc53e3137f81c2a18
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Sep 10 19:13:15 2015 +0000

    Add Load() overload with DatasetInfo.
    
    Partially tested -- I am moving development systems, so I need to check it in.


>---------------------------------------------------------------

15faffe7a369bc78000f63bdc53e3137f81c2a18
 src/mlpack/core.hpp                 |   1 -
 src/mlpack/core/data/load.hpp       |   1 +
 src/mlpack/core/data/load_impl.hpp  | 161 +++++++++++++++++++++++++-
 src/mlpack/tests/load_save_test.cpp | 224 +++++++++++++++++++++++++++++-------
 4 files changed, 339 insertions(+), 48 deletions(-)

diff --git a/src/mlpack/core.hpp b/src/mlpack/core.hpp
index df7657b..46dc1ec 100644
--- a/src/mlpack/core.hpp
+++ b/src/mlpack/core.hpp
@@ -159,7 +159,6 @@
 #include <mlpack/core/data/load.hpp>
 #include <mlpack/core/data/save.hpp>
 #include <mlpack/core/data/normalize_labels.hpp>
-#include <mlpack/core/data/dataset_info.hpp>
 #include <mlpack/core/math/clamp.hpp>
 #include <mlpack/core/math/random.hpp>
 #include <mlpack/core/math/lin_alg.hpp>
diff --git a/src/mlpack/core/data/load.hpp b/src/mlpack/core/data/load.hpp
index 950fd5a..3e04e6a 100644
--- a/src/mlpack/core/data/load.hpp
+++ b/src/mlpack/core/data/load.hpp
@@ -14,6 +14,7 @@
 #include <string>
 
 #include "format.hpp"
+#include "dataset_info.hpp"
 
 namespace mlpack {
 namespace data /** Functions to load and save matrices and models. */ {
diff --git a/src/mlpack/core/data/load_impl.hpp b/src/mlpack/core/data/load_impl.hpp
index 307a886..85badd4 100644
--- a/src/mlpack/core/data/load_impl.hpp
+++ b/src/mlpack/core/data/load_impl.hpp
@@ -18,6 +18,7 @@
 #include <boost/archive/xml_iarchive.hpp>
 #include <boost/archive/text_iarchive.hpp>
 #include <boost/archive/binary_iarchive.hpp>
+#include <boost/tokenizer.hpp>
 
 #include "serialization_shim.hpp"
 
@@ -293,20 +294,168 @@ bool Load(const std::string& filename,
     return false;
   }
 
-  bool unknownType = false;
-  arma::file_type loadType;
-  std::string stringType;
-
-  if (extension == "csv" || extension == "tsv")
+  if (extension == "csv" || extension == "tsv" || extension == "txt")
   {
+    // True if we're looking for commas; if false, we're looking for spaces.
+    bool commas = (extension == "csv");
+
+    std::string type;
+    if (extension == "csv")
+      type = "CSV data";
+    else
+      type = "raw ASCII-formatted data";
+
+    Log::Info << "Loading '" << filename << "' as " << type << ".  "
+        << std::flush;
+    std::string separators;
+    if (commas)
+      separators = ",";
+    else
+      separators = " \t";
+
+    // We'll load this as CSV (or CSV with spaces or tabs) according to
+    // RFC4180.  So the first thing to do is determine the size of the matrix.
+    std::string buffer;
+    size_t cols = 0;
+
+    std::getline(stream, buffer, '\n');
+    // Count commas and whitespace in the line, ignoring anything inside
+    // quotes.
+    typedef boost::tokenizer<boost::escaped_list_separator<char>> Tokenizer;
+    boost::escaped_list_separator<char> sep("\\", separators, "\"");
+    Tokenizer tok(buffer, sep);
+    for (Tokenizer::iterator i = tok.begin(); i != tok.end(); ++i)
+      ++cols;
+
+    // Now count the number of lines in the file.  We've already counted the
+    // first one.
+    size_t rows = 1;
+    stream.unsetf(std::ios_base::skipws);
+    rows += std::count(std::istream_iterator<char>(stream),
+        std::istream_iterator<char>(), '\n');
+
+    // Back up to see if the last character in the file is an empty line.
+    stream.unget();
+    std::cout << "last character is " << int(stream.peek()) << ".\n";
+    while (isspace(stream.peek()))
+    {
+      if (stream.peek() == '\n')
+      {
+        --rows;
+        break;
+      }
+      stream.unget();
+    }
+
+    // Now we have the size.  So resize our matrix.
+    if (transpose)
+      matrix.set_size(cols, rows);
+    else
+      matrix.set_size(rows, cols);
+
+    stream.close();
+    stream.open(filename, std::fstream::in);
 
+    // Extract line by line.
+    std::stringstream token;
+    size_t row = 0;
+    while (!stream.bad() && !stream.fail() && !stream.eof())
+    {
+      std::getline(stream, buffer, '\n');
+
+      // Look at each token.  Unfortunately we have to do this character by
+      // character, because things may be escaped in quotes.
+      Tokenizer lineTok(buffer, sep);
+      size_t col = 0;
+      for (Tokenizer::iterator it = lineTok.begin(); it != lineTok.end(); ++it)
+      {
+        // Attempt to extract as type eT.  If that fails, we'll assume it's a
+        // string and map it (which may involve retroactively mapping everything
+        // we've seen so far).
+        token.clear();
+        token.str(*it);
+
+        eT val = eT(0);
+        token >> val;
+
+        if (token.fail())
+        {
+          std::cout << "conversion failed\n";
+          // Conversion failed; but it may be a NaN or inf.  Armadillo has
+          // convenient functions to check.
+          if (!arma::diskio::convert_naninf(val, token.str()))
+          {
+            // We need to perform a mapping.
+            const size_t dim = (transpose) ? col : row;
+            if (info.Type(dim) == Datatype::numeric)
+            {
+              // We must map everything we have seen up to this point and change
+              // the values in the matrix.
+              if (transpose)
+              {
+                // Whatever we've seen so far has successfully mapped to an eT.
+                // So we need to print it back to a string.  We'll use
+                // Armadillo's functionality for that.
+                for (size_t i = 0; i < row; ++i)
+                {
+                  std::stringstream sstr;
+                  arma::arma_ostream::print_elem(sstr, matrix.at(i, col),
+                      false);
+                  eT newVal = info.MapString(sstr.str(), col);
+                  matrix.at(i, col) = newVal;
+                }
+              }
+              else
+              {
+                for (size_t i = 0; i < col; ++i)
+                {
+                  std::stringstream sstr;
+                  arma::arma_ostream::print_elem(sstr, matrix.at(row, i),
+                      false);
+                  eT newVal = info.MapString(sstr.str(), row);
+                  matrix.at(row, i) = newVal;
+                }
+              }
+            }
+
+            val = info.MapString(token.str(), dim);
+          }
+        }
+
+        if (transpose)
+          matrix(col, row) = val;
+        else
+          matrix(row, col) = val;
+
+        ++col;
+      }
+
+      ++row;
+    }
+
+    if (stream.bad() || stream.fail())
+      Log::Warn << "Failure reading file '" << filename << "'." << std::endl;
   }
-  else if (extension == "txt")
+  else
   {
+    // The type is unknown.
+    Timer::Stop("loading_data");
+    if (fatal)
+      Log::Fatal << "Unable to detect type of '" << filename << "'; "
+          << "incorrect extension?" << std::endl;
+    else
+      Log::Warn << "Unable to detect type of '" << filename << "'; load failed."
+          << " Incorrect extension?" << std::endl;
 
+    return false;
   }
 
+  Log::Info << "Size is " << (transpose ? matrix.n_cols : matrix.n_rows)
+      << " x " << (transpose ? matrix.n_rows : matrix.n_cols) << ".\n";
+
   Timer::Stop("loading_data");
+
+  return true;
 }
 
 // Load a model from file.
diff --git a/src/mlpack/tests/load_save_test.cpp b/src/mlpack/tests/load_save_test.cpp
index 27121ff..4047170 100644
--- a/src/mlpack/tests/load_save_test.cpp
+++ b/src/mlpack/tests/load_save_test.cpp
@@ -13,6 +13,7 @@
 
 using namespace mlpack;
 using namespace mlpack::data;
+using namespace std;
 
 BOOST_AUTO_TEST_SUITE(LoadSaveTest);
 
@@ -54,11 +55,11 @@ BOOST_AUTO_TEST_CASE(NotExistLoad)
  */
 BOOST_AUTO_TEST_CASE(LoadCSVTest)
 {
-  std::fstream f;
-  f.open("test_file.csv", std::fstream::out);
+  fstream f;
+  f.open("test_file.csv", fstream::out);
 
-  f << "1, 2, 3, 4" << std::endl;
-  f << "5, 6, 7, 8" << std::endl;
+  f << "1, 2, 3, 4" << endl;
+  f << "5, 6, 7, 8" << endl;
 
   f.close();
 
@@ -80,11 +81,11 @@ BOOST_AUTO_TEST_CASE(LoadCSVTest)
  */
 BOOST_AUTO_TEST_CASE(LoadTSVTest)
 {
-  std::fstream f;
-  f.open("test_file.csv", std::fstream::out);
+  fstream f;
+  f.open("test_file.csv", fstream::out);
 
-  f << "1\t2\t3\t4" << std::endl;
-  f << "5\t6\t7\t8" << std::endl;
+  f << "1\t2\t3\t4" << endl;
+  f << "5\t6\t7\t8" << endl;
 
   f.close();
 
@@ -106,11 +107,11 @@ BOOST_AUTO_TEST_CASE(LoadTSVTest)
  */
 BOOST_AUTO_TEST_CASE(LoadTSVExtensionTest)
 {
-  std::fstream f;
-  f.open("test_file.tsv", std::fstream::out);
+  fstream f;
+  f.open("test_file.tsv", fstream::out);
 
-  f << "1\t2\t3\t4" << std::endl;
-  f << "5\t6\t7\t8" << std::endl;
+  f << "1\t2\t3\t4" << endl;
+  f << "5\t6\t7\t8" << endl;
 
   f.close();
 
@@ -158,11 +159,11 @@ BOOST_AUTO_TEST_CASE(SaveCSVTest)
  */
 BOOST_AUTO_TEST_CASE(LoadTransposedCSVTest)
 {
-  std::fstream f;
-  f.open("test_file.csv", std::fstream::out);
+  fstream f;
+  f.open("test_file.csv", fstream::out);
 
-  f << "1, 2, 3, 4" << std::endl;
-  f << "5, 6, 7, 8" << std::endl;
+  f << "1, 2, 3, 4" << endl;
+  f << "5, 6, 7, 8" << endl;
 
   f.close();
 
@@ -184,11 +185,11 @@ BOOST_AUTO_TEST_CASE(LoadTransposedCSVTest)
  */
 BOOST_AUTO_TEST_CASE(LoadTransposedTSVTest)
 {
-  std::fstream f;
-  f.open("test_file.csv", std::fstream::out);
+  fstream f;
+  f.open("test_file.csv", fstream::out);
 
-  f << "1\t2\t3\t4" << std::endl;
-  f << "5\t6\t7\t8" << std::endl;
+  f << "1\t2\t3\t4" << endl;
+  f << "5\t6\t7\t8" << endl;
 
   f.close();
 
@@ -210,11 +211,11 @@ BOOST_AUTO_TEST_CASE(LoadTransposedTSVTest)
  */
 BOOST_AUTO_TEST_CASE(LoadTransposedTSVExtensionTest)
 {
-  std::fstream f;
-  f.open("test_file.tsv", std::fstream::out);
+  fstream f;
+  f.open("test_file.tsv", fstream::out);
 
-  f << "1\t2\t3\t4" << std::endl;
-  f << "5\t6\t7\t8" << std::endl;
+  f << "1\t2\t3\t4" << endl;
+  f << "5\t6\t7\t8" << endl;
 
   f.close();
 
@@ -236,11 +237,11 @@ BOOST_AUTO_TEST_CASE(LoadTransposedTSVExtensionTest)
  */
 BOOST_AUTO_TEST_CASE(LoadNonTransposedCSVTest)
 {
-  std::fstream f;
-  f.open("test_file.csv", std::fstream::out);
+  fstream f;
+  f.open("test_file.csv", fstream::out);
 
-  f << "1, 3, 5, 7" << std::endl;
-  f << "2, 4, 6, 8" << std::endl;
+  f << "1, 3, 5, 7" << endl;
+  f << "2, 4, 6, 8" << endl;
 
   f.close();
 
@@ -338,11 +339,11 @@ BOOST_AUTO_TEST_CASE(SaveArmaASCIITest)
  */
 BOOST_AUTO_TEST_CASE(LoadRawASCIITest)
 {
-  std::fstream f;
-  f.open("test_file.txt", std::fstream::out);
+  fstream f;
+  f.open("test_file.txt", fstream::out);
 
-  f << "1 2 3 4" << std::endl;
-  f << "5 6 7 8" << std::endl;
+  f << "1 2 3 4" << endl;
+  f << "5 6 7 8" << endl;
 
   f.close();
 
@@ -364,11 +365,11 @@ BOOST_AUTO_TEST_CASE(LoadRawASCIITest)
  */
 BOOST_AUTO_TEST_CASE(LoadCSVTxtTest)
 {
-  std::fstream f;
-  f.open("test_file.txt", std::fstream::out);
+  fstream f;
+  f.open("test_file.txt", fstream::out);
 
-  f << "1, 2, 3, 4" << std::endl;
-  f << "5, 6, 7, 8" << std::endl;
+  f << "1, 2, 3, 4" << endl;
+  f << "5, 6, 7, 8" << endl;
 
   f.close();
 
@@ -709,7 +710,7 @@ BOOST_AUTO_TEST_CASE(NormalizeLabelTest)
 class TestInner
 {
  public:
-  TestInner(char c, std::string s) : c(c), s(s) { }
+  TestInner(char c, string s) : c(c), s(s) { }
 
   template<typename Archive>
   void Serialize(Archive& ar, const unsigned int /* version */)
@@ -720,7 +721,7 @@ class TestInner
 
   // Public members for testing.
   char c;
-  std::string s;
+  string s;
 };
 
 class Test
@@ -849,13 +850,154 @@ BOOST_AUTO_TEST_CASE(DatasetInfoTest)
   }
 
   // Get the mappings back.
-  const std::string& strFirst = di.UnmapString(first, 3);
-  const std::string& strSecond = di.UnmapString(second, 3);
-  const std::string& strThird = di.UnmapString(third, 3);
+  const string& strFirst = di.UnmapString(first, 3);
+  const string& strSecond = di.UnmapString(second, 3);
+  const string& strThird = di.UnmapString(third, 3);
 
   BOOST_REQUIRE_EQUAL(strFirst, "test_mapping_1");
   BOOST_REQUIRE_EQUAL(strSecond, "test_mapping_2");
   BOOST_REQUIRE_EQUAL(strThird, "test_mapping_3");
 }
 
+/**
+ * Test loading regular CSV with DatasetInfo.  Everything should be numeric.
+ */
+BOOST_AUTO_TEST_CASE(RegularCSVDatasetInfoLoad)
+{
+  vector<string> testFiles;
+  //testFiles.push_back("fake.csv");
+  //testFiles.push_back("german.csv");
+  testFiles.push_back("iris.txt");
+  testFiles.push_back("vc2.txt");
+  testFiles.push_back("johnson-8-4-4.csv");
+  testFiles.push_back("lars_dependent_y.csv");
+  testFiles.push_back("vc2_test_labels.txt");
+
+  for (size_t i = 0; i < testFiles.size(); ++i)
+  {
+    arma::mat one, two;
+    DatasetInfo info;
+    data::Load(testFiles[i], one);
+    data::Load(testFiles[i], two, info);
+
+    // Check that the matrices contain the same information.
+    BOOST_REQUIRE_EQUAL(one.n_elem, two.n_elem);
+    BOOST_REQUIRE_EQUAL(one.n_rows, two.n_rows);
+    BOOST_REQUIRE_EQUAL(one.n_cols, two.n_cols);
+    for (size_t i = 0; i < one.n_elem; ++i)
+    {
+      std::cout << "i " << i << ": one " << one[i] << " two " << two[i] <<
+".\n";
+      if (std::abs(one[i]) < 1e-8)
+        BOOST_REQUIRE_SMALL(two[i], 1e-8);
+      else
+        BOOST_REQUIRE_CLOSE(one[i], two[i], 1e-8);
+    }
+
+    // Check that all dimensions are numeric.
+    for (size_t i = 0; i < two.n_rows; ++i)
+      BOOST_REQUIRE_EQUAL((Datatype) info.Type(i), Datatype::numeric);
+  }
+}
+
+/**
+ * Test non-transposed loading of regular CSVs with DatasetInfo.  Everything
+ * should be numeric.
+ */
+BOOST_AUTO_TEST_CASE(NontransposedCSVDatasetInfoLoad)
+{
+  vector<string> testFiles;
+  testFiles.push_back("fake.csv");
+  testFiles.push_back("german.csv");
+  testFiles.push_back("iris.txt");
+  testFiles.push_back("vc2.txt");
+  testFiles.push_back("johnson-8-4-4.csv");
+  testFiles.push_back("lars_dependent_y.csv");
+  testFiles.push_back("vc2_test_labels.txt");
+
+  for (size_t i = 0; i < testFiles.size(); ++i)
+  {
+    arma::mat one, two;
+    DatasetInfo info;
+    data::Load(testFiles[i], one, true, false); // No transpose.
+    data::Load(testFiles[i], two, info, true, false);
+
+    // Check that the matrices contain the same information.
+    BOOST_REQUIRE_EQUAL(one.n_elem, two.n_elem);
+    BOOST_REQUIRE_EQUAL(one.n_rows, two.n_rows);
+    BOOST_REQUIRE_EQUAL(one.n_cols, two.n_cols);
+    for (size_t i = 0; i < one.n_elem; ++i)
+    {
+      if (std::abs(one[i]) < 1e-8)
+        BOOST_REQUIRE_SMALL(two[i], 1e-8);
+      else
+        BOOST_REQUIRE_CLOSE(one[i], two[i], 1e-8);
+    }
+
+    // Check that all dimensions are numeric.
+    for (size_t i = 0; i < two.n_rows; ++i)
+      BOOST_REQUIRE_EQUAL((Datatype) info.Type(i), Datatype::numeric);
+  }
+}
+
+/**
+ * Create a file with a categorical string feature, then load it.
+ */
+BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest)
+{
+  fstream f("test.csv");
+  f << "1, 2, hello" << endl;
+  f << "3, 4, goodbye" << endl;
+  f << "5, 6, coffee" << endl;
+  f << "7, 8, confusion" << endl;
+  f << "9, 10, hello" << endl;
+  f << "11, 12, confusion" << endl;
+  f << "13, 14, confusion" << endl;
+  f.close();
+
+  // Load the test CSV.
+  arma::umat matrix;
+  DatasetInfo info;
+  data::Load("test.csv", matrix, info);
+
+  BOOST_REQUIRE_EQUAL(matrix.n_cols, 7);
+  BOOST_REQUIRE_EQUAL(matrix.n_rows, 3);
+
+  BOOST_REQUIRE_EQUAL(matrix(0, 0), 1);
+  BOOST_REQUIRE_EQUAL(matrix(0, 1), 2);
+  BOOST_REQUIRE_EQUAL(matrix(0, 2), 0);
+  BOOST_REQUIRE_EQUAL(matrix(1, 0), 3);
+  BOOST_REQUIRE_EQUAL(matrix(1, 1), 4);
+  BOOST_REQUIRE_EQUAL(matrix(1, 2), 1);
+  BOOST_REQUIRE_EQUAL(matrix(2, 0), 5);
+  BOOST_REQUIRE_EQUAL(matrix(2, 1), 6);
+  BOOST_REQUIRE_EQUAL(matrix(2, 2), 2);
+  BOOST_REQUIRE_EQUAL(matrix(3, 0), 7);
+  BOOST_REQUIRE_EQUAL(matrix(3, 1), 8);
+  BOOST_REQUIRE_EQUAL(matrix(3, 2), 3);
+  BOOST_REQUIRE_EQUAL(matrix(4, 0), 9);
+  BOOST_REQUIRE_EQUAL(matrix(4, 1), 10);
+  BOOST_REQUIRE_EQUAL(matrix(4, 2), 0);
+  BOOST_REQUIRE_EQUAL(matrix(5, 0), 11);
+  BOOST_REQUIRE_EQUAL(matrix(5, 1), 12);
+  BOOST_REQUIRE_EQUAL(matrix(5, 2), 3);
+  BOOST_REQUIRE_EQUAL(matrix(6, 0), 13);
+  BOOST_REQUIRE_EQUAL(matrix(6, 1), 14);
+  BOOST_REQUIRE_EQUAL(matrix(6, 2), 3);
+
+  BOOST_REQUIRE_EQUAL((Datatype) info.Type(0), Datatype::numeric);
+  BOOST_REQUIRE_EQUAL((Datatype) info.Type(1), Datatype::numeric);
+  BOOST_REQUIRE_EQUAL((Datatype) info.Type(2), Datatype::categorical);
+
+  BOOST_REQUIRE_EQUAL(info.MapString("hello", 2), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString("goodbye", 2), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString("coffee", 2), 2);
+  BOOST_REQUIRE_EQUAL(info.MapString("confusion", 2), 3);
+
+  BOOST_REQUIRE_EQUAL(info.UnmapString(0, 2), "hello");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(1, 2), "goodbye");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(2, 2), "coffee");
+  BOOST_REQUIRE_EQUAL(info.UnmapString(3, 2), "confusion");
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list