[mlpack-svn] r10122 - in mlpack/trunk/src/mlpack/core: . data

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 2 20:12:30 EDT 2011


Author: rcurtin
Date: 2011-11-02 20:12:30 -0400 (Wed, 02 Nov 2011)
New Revision: 10122

Added:
   mlpack/trunk/src/mlpack/core/data/
   mlpack/trunk/src/mlpack/core/data/CMakeLists.txt
   mlpack/trunk/src/mlpack/core/data/load.hpp
   mlpack/trunk/src/mlpack/core/data/load_impl.hpp
   mlpack/trunk/src/mlpack/core/data/load_save_test.cpp
   mlpack/trunk/src/mlpack/core/data/save.hpp
   mlpack/trunk/src/mlpack/core/data/save_impl.hpp
Log:
Add data::Load() and data::Save() methods to wrap Armadillo functionality.
Sadly we can't seem to get away from that.


Added: mlpack/trunk/src/mlpack/core/data/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/core/data/CMakeLists.txt	                        (rev 0)
+++ mlpack/trunk/src/mlpack/core/data/CMakeLists.txt	2011-11-03 00:12:30 UTC (rev 10122)
@@ -0,0 +1,30 @@
+cmake_minimum_required(VERSION 2.8)
+
+# Files belonging to this directory.  Only the files named here are handed to
+# the parent scope as part of the MLPACK source list.
+set(SOURCES
+  load.hpp
+  load_impl.hpp
+  save.hpp
+  save_impl.hpp
+)
+
+# Prefix every entry with the absolute path of this directory.
+set(DIR_SRCS)
+foreach(src ${SOURCES})
+  list(APPEND DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${src})
+endforeach()
+
+# Publish the path-qualified sources by appending them to MLPACK_SRCS in the
+# parent scope.
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+
+# Unit test executable for data::Load() / data::Save().
+add_executable(load_save_test
+  load_save_test.cpp
+)
+# Libraries the test executable links against.
+target_link_libraries(load_save_test
+  mlpack
+  boost_unit_test_framework
+)

Added: mlpack/trunk/src/mlpack/core/data/load.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/data/load.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/core/data/load.hpp	2011-11-03 00:12:30 UTC (rev 10122)
@@ -0,0 +1,54 @@
+/**
+ * @file load.hpp
+ * @author Ryan Curtin
+ *
+ * Load an Armadillo matrix from file.  This is necessary because Armadillo does
+ * not transpose matrices on input, and it allows us to give better error
+ * output.
+ */
+#ifndef __MLPACK_CORE_DATA_LOAD_HPP
+#define __MLPACK_CORE_DATA_LOAD_HPP
+
+#include <mlpack/core/io/log.hpp>
+#include <mlpack/core/arma_extend/arma_extend.h> // Includes Armadillo.
+#include <string>
+
+namespace mlpack {
+namespace data /** Functions to load and save matrices. */ {
+
+/**
+ * Loads a matrix from file, guessing the filetype from the extension (and,
+ * for .txt and .bin, from the first bytes of the file).  The matrix is
+ * transposed after loading.  Failures are reported through the log.
+ *
+ * The file types recognized by the implementation are:
+ *
+ *  - CSV (csv_ascii), denoted by .csv, or optionally .txt
+ *  - ASCII (raw_ascii), denoted by .txt
+ *  - Armadillo ASCII (arma_ascii), also denoted by .txt
+ *  - PGM (pgm_binary), denoted by .pgm
+ *  - Raw binary (raw_binary), denoted by .bin
+ *  - Armadillo binary (arma_binary), denoted by .bin
+ *
+ * PPM (ppm_binary, .ppm) is not handled and is treated as an unknown type.
+ * If the file extension is not one of those types, an error will be given.
+ * This is preferable to Armadillo's default behavior of loading an unknown
+ * filetype as raw_binary, which can have very confusing effects.
+ *
+ * @param filename Name of file to load.
+ * @param matrix Matrix to load contents of file into.
+ * @param fatal If an error should be reported as fatal (default false).
+ * @return Boolean value indicating success or failure of load.
+ */
+template<typename eT>
+bool Load(const std::string& filename,
+          arma::Mat<eT>& matrix,
+          bool fatal = false);
+
+}; // namespace data
+}; // namespace mlpack
+
+// Include implementation.
+#include "load_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/core/data/load_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/data/load_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/core/data/load_impl.hpp	2011-11-03 00:12:30 UTC (rev 10122)
@@ -0,0 +1,173 @@
+/**
+ * @file load_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of templatized load() function defined in load.hpp.
+ */
+#ifndef __MLPACK_CORE_DATA_LOAD_HPP
+#error "Don't include this file directly; include mlpack/core/data/load.hpp."
+#endif
+
+#ifndef __MLPACK_CORE_DATA_LOAD_IMPL_HPP
+#define __MLPACK_CORE_DATA_LOAD_IMPL_HPP
+
+namespace mlpack {
+namespace data {
+
+// Load 'filename' into 'matrix' in the format implied by the extension (and,
+// for .txt/.bin, the header bytes), then transpose it.  Returns load success.
+template<typename eT>
+bool Load(const std::string& filename, arma::Mat<eT>& matrix, bool fatal)
+{
+  // First we will try to discriminate by file extension.
+  const size_t ext = filename.rfind('.');
+  if (ext == std::string::npos)
+  {
+    if (fatal)
+      Log::Fatal << "Cannot determine type of file '" << filename << "'; "
+          << "no extension is present." << std::endl;
+    else
+      Log::Warn << "Cannot determine type of file '" << filename << "'; "
+          << "no extension is present.  Load failed." << std::endl;
+
+    return false;
+  }
+
+  const std::string extension = filename.substr(ext + 1);
+
+  // Binary formats must be read in binary mode; otherwise platforms that
+  // translate line endings in text mode would alter the raw bytes.
+  std::fstream::openmode mode = std::fstream::in;
+  if (extension == "bin" || extension == "pgm")
+    mode |= std::fstream::binary;
+
+  // Catch nonexistent files by opening the stream ourselves.
+  std::fstream stream;
+  stream.open(filename.c_str(), mode);
+
+  if (!stream.is_open())
+  {
+    if (fatal)
+      Log::Fatal << "Cannot open file '" << filename << "'. " << std::endl;
+    else
+      Log::Warn << "Cannot open file '" << filename << "'; load failed."
+          << std::endl;
+
+    return false;
+  }
+
+  bool unknown_type = false;
+  arma::file_type load_type;
+  std::string string_type;
+
+  if (extension == "csv")
+  {
+    load_type = arma::csv_ascii;
+    string_type = "CSV data";
+  }
+  else if (extension == "txt")
+  {
+    // This could be raw ASCII or Armadillo ASCII (ASCII with size header).
+    // We'll let Armadillo do its guessing (although we have to check if it is
+    // arma_ascii ourselves) and see what we come up with.
+    // This is taken from load_auto_detect() in diskio_meat.hpp.  The buffer is
+    // zero-filled so a file shorter than the tag compares as a plain mismatch.
+    const char ARMA_MAT_TXT[] = "ARMA_MAT_TXT";
+    char raw_header[sizeof(ARMA_MAT_TXT)] = { '\0' };
+    std::streampos pos = stream.tellg();
+
+    stream.read(raw_header, std::streamsize(sizeof(ARMA_MAT_TXT) - 1));
+    stream.clear(); // Reading past the end of a short file sets error bits.
+    stream.seekg(pos); // Reset stream position after peeking.
+
+    if (std::string(raw_header) == ARMA_MAT_TXT)
+    {
+      load_type = arma::arma_ascii;
+      string_type = "Armadillo ASCII formatted data";
+    }
+    else // It's not arma_ascii.  Now we let Armadillo guess.
+    {
+      load_type = arma::diskio::guess_file_type(stream);
+
+      if (load_type == arma::raw_ascii) // Raw ASCII (space-separated).
+        string_type = "raw ASCII formatted data";
+      else if (load_type == arma::csv_ascii) // CSV can be .txt too.
+        string_type = "CSV data";
+      else // Unknown .txt... we will throw an error.
+        unknown_type = true;
+    }
+  }
+  else if (extension == "bin")
+  {
+    // This could be raw binary or Armadillo binary (binary with header).  We
+    // will check to see if it is Armadillo binary.
+    const char ARMA_MAT_BIN[] = "ARMA_MAT_BIN";
+    char raw_header[sizeof(ARMA_MAT_BIN)] = { '\0' }; // Zero-filled, as above.
+    std::streampos pos = stream.tellg();
+
+    stream.read(raw_header, std::streamsize(sizeof(ARMA_MAT_BIN) - 1));
+    stream.clear(); // Reading past the end of a short file sets error bits.
+    stream.seekg(pos); // Reset stream position after peeking.
+
+    if (std::string(raw_header) == ARMA_MAT_BIN)
+    {
+      string_type = "Armadillo binary formatted data";
+      load_type = arma::arma_binary;
+    }
+    else // We can only assume it's raw binary.
+    {
+      string_type = "raw binary formatted data";
+      load_type = arma::raw_binary;
+    }
+  }
+  else if (extension == "pgm")
+  {
+    load_type = arma::pgm_binary;
+    string_type = "PGM data";
+  }
+  else // Unknown extension...
+  {
+    unknown_type = true;
+  }
+
+  // Provide error if we don't know the type.
+  if (unknown_type)
+  {
+    if (fatal)
+      Log::Fatal << "Unable to detect type of '" << filename << "'; "
+          << "incorrect extension?" << std::endl;
+    else
+      Log::Warn << "Unable to detect type of '" << filename << "'; load failed."
+          << " Incorrect extension?" << std::endl;
+
+    return false;
+  }
+
+  // Try to load the file; but if it's raw_binary, it could be a problem.
+  if (load_type == arma::raw_binary)
+    Log::Warn << "Loading '" << filename << "' as " << string_type << "; "
+        << "but this may not be the actual filetype!" << std::endl;
+  else
+    Log::Info << "Loading '" << filename << "' as " << string_type << "."
+        << std::endl;
+
+  bool success = matrix.load(stream, load_type);
+
+  if (!success)
+  {
+    if (fatal)
+      Log::Fatal << "Loading from '" << filename << "' failed." << std::endl;
+    else
+      Log::Warn << "Loading from '" << filename << "' failed." << std::endl;
+  }
+
+  // Now transpose the matrix.
+  matrix = trans(matrix);
+
+  return success;
+}
+
+}; // namespace data
+}; // namespace mlpack
+
+#endif

Added: mlpack/trunk/src/mlpack/core/data/load_save_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/core/data/load_save_test.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/core/data/load_save_test.cpp	2011-11-03 00:12:30 UTC (rev 10122)
@@ -0,0 +1,311 @@
+/**
+ * @file load_save_test.cpp
+ * @author Ryan Curtin
+ *
+ * Tests for data::Load() and data::Save().
+ */
+#include <sstream>
+
+#include "load.hpp"
+#include "save.hpp"
+
+#define BOOST_TEST_MODULE LoadSaveTest
+#include <boost/test/unit_test.hpp>
+
+using namespace mlpack;
+
+/**
+ * Loading must be refused when the filename carries no extension.
+ */
+BOOST_AUTO_TEST_CASE(NoExtensionLoad) {
+  arma::mat sink;
+  BOOST_REQUIRE(!data::Load("noextension", sink));
+}
+
+/**
+ * Saving must be refused when the filename carries no extension.
+ */
+BOOST_AUTO_TEST_CASE(NoExtensionSave) {
+  arma::mat empty;
+  BOOST_REQUIRE(!data::Save("noextension", empty));
+}
+
+/**
+ * Loading a file that is not on disk must fail.
+ */
+BOOST_AUTO_TEST_CASE(NotExistLoad) {
+  arma::mat sink;
+  BOOST_REQUIRE(!data::Load("nonexistentfile_______________.csv", sink));
+}
+
+/**
+ * Write a two-line CSV by hand and check that it loads as a 4x2 matrix.
+ */
+BOOST_AUTO_TEST_CASE(LoadCSVTest) {
+  // Create the input file.
+  std::fstream out("test_file.csv", std::fstream::out);
+
+  out << "1, 2, 3, 4" << std::endl;
+  out << "5, 6, 7, 8" << std::endl;
+
+  out.close();
+
+  arma::mat loaded;
+  BOOST_REQUIRE(data::Load("test_file.csv", loaded));
+
+  BOOST_REQUIRE_EQUAL(loaded.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(loaded.n_cols, 2);
+
+  for (int k = 1; k <= 8; ++k)
+    BOOST_REQUIRE_CLOSE(loaded[k - 1], (double) k, 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.csv");
+}
+
+/**
+ * Save a matrix as CSV, then load it back and check it is unchanged.
+ */
+BOOST_AUTO_TEST_CASE(SaveCSVTest) {
+  arma::mat values = "1 5;"
+                     "2 6;"
+                     "3 7;"
+                     "4 8;";
+
+  BOOST_REQUIRE(data::Save("test_file.csv", values));
+
+  // Read the file back into the same matrix and compare element by element.
+  BOOST_REQUIRE(data::Load("test_file.csv", values));
+
+  BOOST_REQUIRE_EQUAL(values.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(values.n_cols, 2);
+
+  for (int k = 0; k < 8; ++k)
+    BOOST_REQUIRE_CLOSE(values[k], double(k + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.csv");
+}
+
+/**
+ * Make sure a file written by Armadillo as arma_ascii loads correctly.
+ */
+BOOST_AUTO_TEST_CASE(LoadArmaASCIITest) {
+  arma::mat values = "1 5;"
+                     "2 6;"
+                     "3 7;"
+                     "4 8;";
+
+  arma::mat transposed = trans(values);
+  BOOST_REQUIRE(transposed.save("test_file.txt", arma::arma_ascii));
+
+  BOOST_REQUIRE(data::Load("test_file.txt", values));
+
+  BOOST_REQUIRE_EQUAL(values.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(values.n_cols, 2);
+
+  for (int k = 0; k < 8; ++k)
+    BOOST_REQUIRE_CLOSE(values[k], double(k + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.txt");
+}
+
+/**
+ * Make sure a matrix saved to a .txt file can be loaded back unchanged.
+ */
+BOOST_AUTO_TEST_CASE(SaveArmaASCIITest) {
+  arma::mat values = "1 5;"
+                     "2 6;"
+                     "3 7;"
+                     "4 8;";
+
+  BOOST_REQUIRE(data::Save("test_file.txt", values));
+
+  // Read the file back into the same matrix and compare element by element.
+  BOOST_REQUIRE(data::Load("test_file.txt", values));
+
+  BOOST_REQUIRE_EQUAL(values.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(values.n_cols, 2);
+
+  for (int k = 0; k < 8; ++k)
+    BOOST_REQUIRE_CLOSE(values[k], double(k + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.txt");
+}
+
+/**
+ * Write space-separated text by hand and check it loads as a 4x2 matrix.
+ */
+BOOST_AUTO_TEST_CASE(LoadRawASCIITest) {
+  // Create the input file.
+  std::fstream out("test_file.txt", std::fstream::out);
+
+  out << "1 2 3 4" << std::endl;
+  out << "5 6 7 8" << std::endl;
+
+  out.close();
+
+  arma::mat loaded;
+  BOOST_REQUIRE(data::Load("test_file.txt", loaded));
+
+  BOOST_REQUIRE_EQUAL(loaded.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(loaded.n_cols, 2);
+
+  for (int k = 1; k <= 8; ++k)
+    BOOST_REQUIRE_CLOSE(loaded[k - 1], (double) k, 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.txt");
+}
+
+/**
+ * Comma-separated content in a .txt file must still load as a 4x2 matrix.
+ */
+BOOST_AUTO_TEST_CASE(LoadCSVTxtTest) {
+  // Create the input file.
+  std::fstream out("test_file.txt", std::fstream::out);
+
+  out << "1, 2, 3, 4" << std::endl;
+  out << "5, 6, 7, 8" << std::endl;
+
+  out.close();
+
+  arma::mat loaded;
+  BOOST_REQUIRE(data::Load("test_file.txt", loaded));
+
+  BOOST_REQUIRE_EQUAL(loaded.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(loaded.n_cols, 2);
+
+  for (int k = 1; k <= 8; ++k)
+    BOOST_REQUIRE_CLOSE(loaded[k - 1], (double) k, 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.txt");
+}
+
+/**
+ * Make sure a file written by Armadillo as arma_binary loads correctly.
+ */
+BOOST_AUTO_TEST_CASE(LoadArmaBinaryTest) {
+  arma::mat values = "1 5;"
+                     "2 6;"
+                     "3 7;"
+                     "4 8;";
+
+  arma::mat transposed = trans(values);
+  BOOST_REQUIRE(
+      transposed.quiet_save("test_file.bin", arma::arma_binary) == true);
+
+  // Read it back through data::Load().
+  BOOST_REQUIRE(data::Load("test_file.bin", values));
+
+  BOOST_REQUIRE_EQUAL(values.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(values.n_cols, 2);
+
+  for (int k = 0; k < 8; ++k)
+    BOOST_REQUIRE_CLOSE(values[k], double(k + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.bin");
+}
+
+/**
+ * Save a matrix to a .bin file, then load it back and check it is unchanged.
+ */
+BOOST_AUTO_TEST_CASE(SaveArmaBinaryTest) {
+  arma::mat values = "1 5;"
+                     "2 6;"
+                     "3 7;"
+                     "4 8;";
+
+  BOOST_REQUIRE(data::Save("test_file.bin", values));
+
+  BOOST_REQUIRE(data::Load("test_file.bin", values));
+
+  BOOST_REQUIRE_EQUAL(values.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(values.n_cols, 2);
+
+  for (int k = 0; k < 8; ++k)
+    BOOST_REQUIRE_CLOSE(values[k], double(k + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.bin");
+}
+
+/**
+ * A headerless raw_binary file must load; the result is checked as 1x8.
+ */
+BOOST_AUTO_TEST_CASE(LoadRawBinaryTest) {
+  arma::mat values = "1 2;"
+                     "3 4;"
+                     "5 6;"
+                     "7 8;";
+
+  arma::mat transposed = trans(values);
+  BOOST_REQUIRE(
+      transposed.quiet_save("test_file.bin", arma::raw_binary) == true);
+
+  // Read it back through data::Load().
+  BOOST_REQUIRE(data::Load("test_file.bin", values));
+
+  BOOST_REQUIRE_EQUAL(values.n_rows, 1);
+  BOOST_REQUIRE_EQUAL(values.n_cols, 8);
+
+  for (int k = 0; k < 8; ++k)
+    BOOST_REQUIRE_CLOSE(values[k], double(k + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.bin");
+}
+
+/**
+ * Make sure a file written by Armadillo as pgm_binary loads correctly.
+ */
+BOOST_AUTO_TEST_CASE(LoadPGMBinaryTest) {
+  arma::mat values = "1 5;"
+                     "2 6;"
+                     "3 7;"
+                     "4 8;";
+
+  arma::mat transposed = trans(values);
+  BOOST_REQUIRE(
+      transposed.quiet_save("test_file.pgm", arma::pgm_binary) == true);
+
+  // Read it back through data::Load().
+  BOOST_REQUIRE(data::Load("test_file.pgm", values));
+
+  BOOST_REQUIRE_EQUAL(values.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(values.n_cols, 2);
+
+  for (int k = 0; k < 8; ++k)
+    BOOST_REQUIRE_CLOSE(values[k], double(k + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.pgm");
+}
+
+/**
+ * Save a matrix to a .pgm file, then load it back and check it is unchanged.
+ */
+BOOST_AUTO_TEST_CASE(SavePGMBinaryTest) {
+  arma::mat values = "1 5;"
+                     "2 6;"
+                     "3 7;"
+                     "4 8;";
+
+  BOOST_REQUIRE(data::Save("test_file.pgm", values));
+
+  // Read it back through data::Load().
+  BOOST_REQUIRE(data::Load("test_file.pgm", values));
+
+  BOOST_REQUIRE_EQUAL(values.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(values.n_cols, 2);
+
+  for (int k = 0; k < 8; ++k)
+    BOOST_REQUIRE_CLOSE(values[k], double(k + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.pgm");
+}

Added: mlpack/trunk/src/mlpack/core/data/save.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/data/save.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/core/data/save.hpp	2011-11-03 00:12:30 UTC (rev 10122)
@@ -0,0 +1,52 @@
+/**
+ * @file save.hpp
+ * @author Ryan Curtin
+ *
+ * Save an Armadillo matrix to file.  This is necessary because Armadillo does
+ * not transpose matrices upon saving, and it allows us to give better error
+ * output.
+ */
+#ifndef __MLPACK_CORE_DATA_SAVE_HPP
+#define __MLPACK_CORE_DATA_SAVE_HPP
+
+#include <mlpack/core/io/log.hpp>
+#include <mlpack/core/arma_extend/arma_extend.h> // Includes Armadillo.
+#include <string>
+
+namespace mlpack {
+namespace data /** Functions to load and save matrices. */ {
+
+/**
+ * Saves a matrix to file, choosing the filetype from the extension.  The
+ * matrix is transposed before it is written; the caller's matrix is not
+ * modified.  If the filename has no extension, a failure is reported.
+ *
+ * Each extension maps to exactly one Armadillo file type:
+ *
+ *  - .csv is written as CSV (csv_ascii)
+ *  - .txt is written as ASCII (raw_ascii)
+ *  - .bin is written as Armadillo binary (arma_binary)
+ *  - .pgm is written as PGM (pgm_binary)
+ *
+ * Armadillo ASCII (arma_ascii), raw binary (raw_binary) and PPM (ppm_binary)
+ * output cannot be selected through this function.
+ *
+ * Any other extension is reported as an error (fatal if fatal is true).
+ *
+ * @param filename Name of file to save to.
+ * @param matrix Matrix to save into file.
+ * @param fatal If an error should be reported as fatal (default false).
+ * @return Boolean value indicating success or failure of save.
+ */
+template<typename eT>
+bool Save(const std::string& filename,
+          const arma::Mat<eT>& matrix,
+          bool fatal = false);
+
+}; // namespace data
+}; // namespace mlpack
+
+// Include implementation.
+#include "save_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/core/data/save_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/data/save_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/core/data/save_impl.hpp	2011-11-03 00:12:30 UTC (rev 10122)
@@ -0,0 +1,117 @@
+/**
+ * @file save_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of save functionality.
+ */
+#ifndef __MLPACK_CORE_DATA_SAVE_HPP
+#error "Don't include this file directly; include mlpack/core/data/save.hpp."
+#endif
+
+#ifndef __MLPACK_CORE_DATA_SAVE_IMPL_HPP
+#define __MLPACK_CORE_DATA_SAVE_IMPL_HPP
+
+namespace mlpack {
+namespace data {
+
+// Save the transpose of 'matrix' to 'filename' in the format implied by the
+// extension.  Returns false (after logging) on any failure.
+template<typename eT>
+bool Save(const std::string& filename, const arma::Mat<eT>& matrix, bool fatal)
+{
+  // First we will try to discriminate by file extension.
+  const size_t ext = filename.rfind('.');
+  if (ext == std::string::npos)
+  {
+    if (fatal)
+      Log::Fatal << "No extension given with filename '" << filename << "'; "
+          << "type unknown.  Save failed." << std::endl;
+    else
+      Log::Warn << "No extension given with filename '" << filename << "'; "
+          << "type unknown.  Save failed." << std::endl;
+
+    return false;
+  }
+
+  const std::string extension = filename.substr(ext + 1);
+
+  // Decide on the format before touching the filesystem, so that an unknown
+  // extension does not create (or truncate) the target file.
+  arma::file_type save_type;
+  std::string string_type;
+  std::fstream::openmode mode = std::fstream::out;
+
+  if (extension == "csv")
+  {
+    save_type = arma::csv_ascii;
+    string_type = "CSV data";
+  }
+  else if (extension == "txt")
+  {
+    save_type = arma::raw_ascii;
+    string_type = "raw ASCII formatted data";
+  }
+  else if (extension == "bin")
+  {
+    save_type = arma::arma_binary;
+    string_type = "Armadillo binary formatted data";
+    mode |= std::fstream::binary; // No newline translation of raw bytes.
+  }
+  else if (extension == "pgm")
+  {
+    save_type = arma::pgm_binary;
+    string_type = "PGM data";
+    mode |= std::fstream::binary; // No newline translation of raw bytes.
+  }
+  else
+  {
+    // Unknown extension: save_type was never set, so we must not go on.
+    if (fatal)
+      Log::Fatal << "Unable to determine format to save to from filename '"
+          << filename << "'.  Save failed." << std::endl;
+    else
+      Log::Warn << "Unable to determine format to save to from filename '"
+          << filename << "'.  Save failed." << std::endl;
+
+    return false;
+  }
+
+  // Catch errors opening the file.
+  std::fstream stream;
+  stream.open(filename.c_str(), mode);
+
+  if (!stream.is_open())
+  {
+    if (fatal)
+      Log::Fatal << "Cannot open file '" << filename << "' for writing. "
+          << "Save failed." << std::endl;
+    else
+      Log::Warn << "Cannot open file '" << filename << "' for writing; save "
+          << "failed." << std::endl;
+
+    return false;
+  }
+
+  Log::Info << "Saving " << string_type << " to '" << filename << "'."
+      << std::endl;
+
+  // Transpose the matrix.
+  arma::Mat<eT> tmp = trans(matrix);
+
+  if (!tmp.quiet_save(stream, save_type))
+  {
+    if (fatal)
+      Log::Fatal << "Save to '" << filename << "' failed." << std::endl;
+    else
+      Log::Warn << "Save to '" << filename << "' failed." << std::endl;
+
+    return false;
+  }
+
+  return true;
+}
+
+}; // namespace data
+}; // namespace mlpack
+
+#endif




More information about the mlpack-svn mailing list