[mlpack-git] master: Allow .tsv extension too. (ee732ca)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Sep 11 07:53:06 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/a4d2dc275f6bdc74898386405decc91f072b2465...a33bc45442b3ce8830ea1a3e930c89d05c6dc9c6
>---------------------------------------------------------------
commit ee732cae82019c47a48fb7b2fb11325e1ee99367
Author: ryan <ryan at ratml.org>
Date: Tue Sep 8 09:53:31 2015 -0400
Allow .tsv extension too.
>---------------------------------------------------------------
ee732cae82019c47a48fb7b2fb11325e1ee99367
src/mlpack/core/data/load_impl.hpp | 12 ++++++---
src/mlpack/tests/load_save_test.cpp | 52 +++++++++++++++++++++++++++++++++++++
2 files changed, 60 insertions(+), 4 deletions(-)
diff --git a/src/mlpack/core/data/load_impl.hpp b/src/mlpack/core/data/load_impl.hpp
index fbc1eaa..369f8bf 100644
--- a/src/mlpack/core/data/load_impl.hpp
+++ b/src/mlpack/core/data/load_impl.hpp
@@ -78,17 +78,21 @@ bool Load(const std::string& filename,
arma::file_type loadType;
std::string stringType;
- if (extension == "csv")
+ if (extension == "csv" || extension == "tsv")
{
loadType = arma::diskio::guess_file_type(stream);
if (loadType == arma::csv_ascii)
{
+ if (extension == "tsv")
+ Log::Warn << "'" << filename << "' is comma-separated, not "
+ "tab-separated!" << std::endl;
stringType = "CSV data";
}
- else if (loadType == arma::raw_ascii) // .csv file can be tsv
+ else if (loadType == arma::raw_ascii) // .csv file can be tsv.
{
- Log::Warn << "'" << filename << "' is not a standard csv file."
- << std::endl;
+ if (extension == "csv")
+ Log::Warn << "'" << filename << "' is not a standard csv file."
+ << std::endl;
stringType = "raw ASCII formatted data";
}
else
diff --git a/src/mlpack/tests/load_save_test.cpp b/src/mlpack/tests/load_save_test.cpp
index a0838c1..646cdd1 100644
--- a/src/mlpack/tests/load_save_test.cpp
+++ b/src/mlpack/tests/load_save_test.cpp
@@ -101,6 +101,32 @@ BOOST_AUTO_TEST_CASE(LoadTSVTest)
}
/**
+ * Test TSV loading with .tsv extension.
+ */
+BOOST_AUTO_TEST_CASE(LoadTSVExtensionTest) // exercises the new "tsv" extension branch in Load()
+{
+ std::fstream f;
+ f.open("test_file.tsv", std::fstream::out); // fixture is written with a .tsv extension
+
+ f << "1\t2\t3\t4" << std::endl; // 2 lines x 4 tab-separated values
+ f << "5\t6\t7\t8" << std::endl;
+
+ f.close();
+
+ arma::mat test;
+ BOOST_REQUIRE(data::Load("test_file.tsv", test) == true); // relies on Load()'s default arguments only
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4); // expects a transposed load: each file line becomes one column
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2); // NOTE(review): n_rows/n_cols are presumably unsigned (arma::uword); int literals may trip -Wsign-compare -- confirm
+
+ for (int i = 0; i < 8; i++) // NOTE(review): LoadTransposedTSVExtensionTest uses size_t / ++i; consider matching
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5); // linear index must see the values in file order
+
+ // Remove the file.
+ remove("test_file.tsv");
+}
+
+/**
* Make sure a CSV is saved correctly.
*/
BOOST_AUTO_TEST_CASE(SaveCSVTest)
@@ -179,6 +205,32 @@ BOOST_AUTO_TEST_CASE(LoadTransposedTSVTest)
}
/**
+ * Check transposed TSV loading with .tsv extension.
+ */
+BOOST_AUTO_TEST_CASE(LoadTransposedTSVExtensionTest) // same fixture as LoadTSVExtensionTest, with Load() arguments spelled out
+{
+ std::fstream f;
+ f.open("test_file.tsv", std::fstream::out);
+
+ f << "1\t2\t3\t4" << std::endl; // 2 lines x 4 tab-separated values
+ f << "5\t6\t7\t8" << std::endl;
+
+ f.close();
+
+ arma::mat test;
+ BOOST_REQUIRE(data::Load("test_file.tsv", test, false, true) == true); // NOTE(review): presumably (fatal = false, transpose = true); confirm against Load()'s signature
+
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2); // each file line becomes one column
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+
+ for (size_t i = 0; i < 8; ++i)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5); // linear index must see the values in file order
+
+ // Remove the file.
+ remove("test_file.tsv");
+}
+
+/**
* Make sure CSVs can be loaded in non-transposed form.
*/
BOOST_AUTO_TEST_CASE(LoadNonTransposedCSVTest)
More information about the mlpack-git mailing list