[mlpack-git] master: fix #449 (d6a5f70)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Sep 11 07:52:57 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/a4d2dc275f6bdc74898386405decc91f072b2465...a33bc45442b3ce8830ea1a3e930c89d05c6dc9c6
>---------------------------------------------------------------
commit d6a5f7058a4afe4c7fcb2070060e470b2ce70521
Author: Qiang Kou <qkou at umail.iu.edu>
Date: Fri Sep 4 17:46:20 2015 -0400
fix #449
>---------------------------------------------------------------
d6a5f7058a4afe4c7fcb2070060e470b2ce70521
src/mlpack/core/data/load_impl.hpp | 19 ++++++++++++--
src/mlpack/tests/load_save_test.cpp | 52 +++++++++++++++++++++++++++++++++++++
2 files changed, 69 insertions(+), 2 deletions(-)
diff --git a/src/mlpack/core/data/load_impl.hpp b/src/mlpack/core/data/load_impl.hpp
index 53d68a3..fbc1eaa 100644
--- a/src/mlpack/core/data/load_impl.hpp
+++ b/src/mlpack/core/data/load_impl.hpp
@@ -80,8 +80,23 @@ bool Load(const std::string& filename,
if (extension == "csv")
{
- loadType = arma::csv_ascii;
- stringType = "CSV data";
+ loadType = arma::diskio::guess_file_type(stream);
+ if (loadType == arma::csv_ascii)
+ {
+ stringType = "CSV data";
+ }
+ else if (loadType == arma::raw_ascii) // .csv file can be tsv
+ {
+ Log::Warn << "'" << filename << "' is not a standard csv file."
+ << std::endl;
+ stringType = "raw ASCII formatted data";
+ }
+ else
+ {
+ unknownType = true;
+ loadType = arma::raw_binary; // Won't be used; prevent a warning.
+ stringType = "";
+ }
}
else if (extension == "txt")
{
diff --git a/src/mlpack/tests/load_save_test.cpp b/src/mlpack/tests/load_save_test.cpp
index 362ad35..a0838c1 100644
--- a/src/mlpack/tests/load_save_test.cpp
+++ b/src/mlpack/tests/load_save_test.cpp
@@ -75,6 +75,32 @@ BOOST_AUTO_TEST_CASE(LoadCSVTest)
}
/**
+ * Make sure a tab-separated file with a .csv extension loads (regression test for #449).
+ */
+BOOST_AUTO_TEST_CASE(LoadTSVTest)
+{
+ std::fstream f;
+ f.open("test_file.csv", std::fstream::out); // .csv name, but tab-separated contents
+
+ f << "1\t2\t3\t4" << std::endl; // 2 lines x 4 tab-separated values
+ f << "5\t6\t7\t8" << std::endl;
+
+ f.close();
+
+ arma::mat test;
+ BOOST_REQUIRE(data::Load("test_file.csv", test) == true); // default Load arguments
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4); // each file line is expected to become a column: 4 x 2
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; i++) // linear index walks column by column (Armadillo is column-major): 1..8 in file order
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Remove the file. NOTE: only reached if every BOOST_REQUIRE above passed.
+ remove("test_file.csv");
+}
+
+/**
* Make sure a CSV is saved correctly.
*/
BOOST_AUTO_TEST_CASE(SaveCSVTest)
@@ -127,6 +153,32 @@ BOOST_AUTO_TEST_CASE(LoadTransposedCSVTest)
}
/**
+ * Make sure a tab-separated file with a .csv extension can be loaded in transposed form (#449).
+ */
+BOOST_AUTO_TEST_CASE(LoadTransposedTSVTest)
+{
+ std::fstream f;
+ f.open("test_file.csv", std::fstream::out); // .csv name, but tab-separated contents
+
+ f << "1\t2\t3\t4" << std::endl; // 2 lines x 4 tab-separated values
+ f << "5\t6\t7\t8" << std::endl;
+
+ f.close();
+
+ arma::mat test;
+ BOOST_REQUIRE(data::Load("test_file.csv", test, false, true) == true); // NOTE(review): presumably (fatal = false, transpose = true); shape checks match LoadTSVTest, so this looks like the default — confirm
+
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2); // one column per file line
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4); // one row per value on a line
+
+ for (size_t i = 0; i < 8; ++i) // linear index walks column by column (Armadillo is column-major): 1..8 in file order
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Remove the file. NOTE: only reached if every BOOST_REQUIRE above passed.
+ remove("test_file.csv");
+}
+
+/**
* Make sure CSVs can be loaded in non-transposed form.
*/
BOOST_AUTO_TEST_CASE(LoadNonTransposedCSVTest)
More information about the mlpack-git
mailing list