[mlpack-git] master: Allow .tsv extension too. (ee732ca)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Sep 11 07:53:06 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/a4d2dc275f6bdc74898386405decc91f072b2465...a33bc45442b3ce8830ea1a3e930c89d05c6dc9c6

>---------------------------------------------------------------

commit ee732cae82019c47a48fb7b2fb11325e1ee99367
Author: ryan <ryan at ratml.org>
Date:   Tue Sep 8 09:53:31 2015 -0400

    Allow .tsv extension too.


>---------------------------------------------------------------

ee732cae82019c47a48fb7b2fb11325e1ee99367
 src/mlpack/core/data/load_impl.hpp  | 12 ++++++---
 src/mlpack/tests/load_save_test.cpp | 52 +++++++++++++++++++++++++++++++++++++
 2 files changed, 60 insertions(+), 4 deletions(-)

diff --git a/src/mlpack/core/data/load_impl.hpp b/src/mlpack/core/data/load_impl.hpp
index fbc1eaa..369f8bf 100644
--- a/src/mlpack/core/data/load_impl.hpp
+++ b/src/mlpack/core/data/load_impl.hpp
@@ -78,17 +78,21 @@ bool Load(const std::string& filename,
   arma::file_type loadType;
   std::string stringType;
 
-  if (extension == "csv")
+  if (extension == "csv" || extension == "tsv")
   {
     loadType = arma::diskio::guess_file_type(stream);
     if (loadType == arma::csv_ascii) 
     {
+      if (extension == "tsv")
+        Log::Warn << "'" << filename << "' is comma-separated, not "
+            "tab-separated!" << std::endl;
       stringType = "CSV data";
     }
-    else if (loadType == arma::raw_ascii) // .csv file can be tsv
+    else if (loadType == arma::raw_ascii) // .csv file can be tsv.
     {
-      Log::Warn << "'" << filename << "' is not a standard csv file."
-          << std::endl;
+      if (extension == "csv")
+        Log::Warn << "'" << filename << "' is not a standard csv file."
+            << std::endl;
       stringType = "raw ASCII formatted data";
     }
     else
diff --git a/src/mlpack/tests/load_save_test.cpp b/src/mlpack/tests/load_save_test.cpp
index a0838c1..646cdd1 100644
--- a/src/mlpack/tests/load_save_test.cpp
+++ b/src/mlpack/tests/load_save_test.cpp
@@ -101,6 +101,32 @@ BOOST_AUTO_TEST_CASE(LoadTSVTest)
 }
 
 /**
+ * Test TSV loading with .tsv extension.
+ */
+BOOST_AUTO_TEST_CASE(LoadTSVExtensionTest)
+{
+  std::fstream f;
+  f.open("test_file.tsv", std::fstream::out);
+
+  f << "1\t2\t3\t4" << std::endl;
+  f << "5\t6\t7\t8" << std::endl;
+
+  f.close();
+
+  arma::mat test;
+  BOOST_REQUIRE(data::Load("test_file.tsv", test) == true);
+
+  BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+  for (int i = 0; i < 8; i++)
+    BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+  // Remove the file.
+  remove("test_file.tsv");
+}
+
+/**
  * Make sure a CSV is saved correctly.
  */
 BOOST_AUTO_TEST_CASE(SaveCSVTest)
@@ -179,6 +205,32 @@ BOOST_AUTO_TEST_CASE(LoadTransposedTSVTest)
 }
 
 /**
+ * Check transposed TSV loading with .tsv extension.
+ */
+BOOST_AUTO_TEST_CASE(LoadTransposedTSVExtensionTest)
+{
+  std::fstream f;
+  f.open("test_file.tsv", std::fstream::out);
+
+  f << "1\t2\t3\t4" << std::endl;
+  f << "5\t6\t7\t8" << std::endl;
+
+  f.close();
+
+  arma::mat test;
+  BOOST_REQUIRE(data::Load("test_file.tsv", test, false, true) == true);
+
+  BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+  BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+
+  for (size_t i = 0; i < 8; ++i)
+    BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+  // Remove the file.
+  remove("test_file.tsv");
+}
+
+/**
  * Make sure CSVs can be loaded in non-transposed form.
  */
 BOOST_AUTO_TEST_CASE(LoadNonTransposedCSVTest)



More information about the mlpack-git mailing list