[mlpack-git] master: fix #449 (d6a5f70)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Sep 11 07:52:57 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/a4d2dc275f6bdc74898386405decc91f072b2465...a33bc45442b3ce8830ea1a3e930c89d05c6dc9c6

>---------------------------------------------------------------

commit d6a5f7058a4afe4c7fcb2070060e470b2ce70521
Author: Qiang Kou <qkou at umail.iu.edu>
Date:   Fri Sep 4 17:46:20 2015 -0400

    fix #449


>---------------------------------------------------------------

d6a5f7058a4afe4c7fcb2070060e470b2ce70521
 src/mlpack/core/data/load_impl.hpp  | 19 ++++++++++++--
 src/mlpack/tests/load_save_test.cpp | 52 +++++++++++++++++++++++++++++++++++++
 2 files changed, 69 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/core/data/load_impl.hpp b/src/mlpack/core/data/load_impl.hpp
index 53d68a3..fbc1eaa 100644
--- a/src/mlpack/core/data/load_impl.hpp
+++ b/src/mlpack/core/data/load_impl.hpp
@@ -80,8 +80,23 @@ bool Load(const std::string& filename,
 
   if (extension == "csv")
   {
-    loadType = arma::csv_ascii;
-    stringType = "CSV data";
+    loadType = arma::diskio::guess_file_type(stream);
+    if (loadType == arma::csv_ascii) 
+    {
+      stringType = "CSV data";
+    }
+    else if (loadType == arma::raw_ascii) // .csv file can be tsv
+    {
+      Log::Warn << "'" << filename << "' is not a standard csv file."
+          << std::endl;
+      stringType = "raw ASCII formatted data";
+    }
+    else
+    {
+      unknownType = true;
+      loadType = arma::raw_binary; // Won't be used; prevent a warning.
+      stringType = "";
+    }
   }
   else if (extension == "txt")
   {
diff --git a/src/mlpack/tests/load_save_test.cpp b/src/mlpack/tests/load_save_test.cpp
index 362ad35..a0838c1 100644
--- a/src/mlpack/tests/load_save_test.cpp
+++ b/src/mlpack/tests/load_save_test.cpp
@@ -75,6 +75,32 @@ BOOST_AUTO_TEST_CASE(LoadCSVTest)
 }
 
 /**
+ * Make sure a tab-separated file with a .csv extension is loaded correctly.
+ */
+BOOST_AUTO_TEST_CASE(LoadTSVTest)
+{
+  std::fstream f;
+  f.open("test_file.csv", std::fstream::out);
+
+  f << "1\t2\t3\t4" << std::endl;
+  f << "5\t6\t7\t8" << std::endl;
+
+  f.close();
+
+  arma::mat test;
+  BOOST_REQUIRE(data::Load("test_file.csv", test) == true);
+
+  BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+  for (int i = 0; i < 8; i++)
+    BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+  // Remove the file.
+  remove("test_file.csv");
+}
+
+/**
  * Make sure a CSV is saved correctly.
  */
 BOOST_AUTO_TEST_CASE(SaveCSVTest)
@@ -127,6 +153,32 @@ BOOST_AUTO_TEST_CASE(LoadTransposedCSVTest)
 }
 
 /**
+ * Make sure a tab-separated .csv file can be loaded in transposed form.
+ */
+BOOST_AUTO_TEST_CASE(LoadTransposedTSVTest)
+{
+  std::fstream f;
+  f.open("test_file.csv", std::fstream::out);
+
+  f << "1\t2\t3\t4" << std::endl;
+  f << "5\t6\t7\t8" << std::endl;
+
+  f.close();
+
+  arma::mat test;
+  BOOST_REQUIRE(data::Load("test_file.csv", test, false, true) == true);
+
+  BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+  BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+
+  for (size_t i = 0; i < 8; ++i)
+    BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+  // Remove the file.
+  remove("test_file.csv");
+}
+
+/**
  * Make sure CSVs can be loaded in non-transposed form.
  */
 BOOST_AUTO_TEST_CASE(LoadNonTransposedCSVTest)



More information about the mlpack-git mailing list