[mlpack-svn] r10122 - in mlpack/trunk/src/mlpack/core: . data
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 2 20:12:30 EDT 2011
Author: rcurtin
Date: 2011-11-02 20:12:30 -0400 (Wed, 02 Nov 2011)
New Revision: 10122
Added:
mlpack/trunk/src/mlpack/core/data/
mlpack/trunk/src/mlpack/core/data/CMakeLists.txt
mlpack/trunk/src/mlpack/core/data/load.hpp
mlpack/trunk/src/mlpack/core/data/load_impl.hpp
mlpack/trunk/src/mlpack/core/data/load_save_test.cpp
mlpack/trunk/src/mlpack/core/data/save.hpp
mlpack/trunk/src/mlpack/core/data/save_impl.hpp
Log:
Add data::Load() and data::Save() methods to wrap Armadillo functionality.
Sadly we can't seem to get away from that.
Added: mlpack/trunk/src/mlpack/core/data/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/core/data/CMakeLists.txt (rev 0)
+++ mlpack/trunk/src/mlpack/core/data/CMakeLists.txt 2011-11-03 00:12:30 UTC (rev 10122)
@@ -0,0 +1,30 @@
+cmake_minimum_required(VERSION 2.8)
+
+# Sources belonging to the data loading/saving module. Files that are not
+# listed here are not compiled into MLPACK.
+set(SOURCES
+  load.hpp
+  load_impl.hpp
+  save.hpp
+  save_impl.hpp
+)
+
+# Prefix every source with the absolute path of this directory.
+set(DIR_SRCS)
+foreach(source_file ${SOURCES})
+  list(APPEND DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${source_file})
+endforeach()
+
+# Export the prefixed sources to the parent scope, where they are appended to
+# the list of all MLPACK sources.
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+
+# Unit-test executable for data::Load() and data::Save(); it links against
+# MLPACK itself and the Boost unit test framework.
+add_executable(load_save_test load_save_test.cpp)
+target_link_libraries(load_save_test mlpack boost_unit_test_framework)
Added: mlpack/trunk/src/mlpack/core/data/load.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/data/load.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/core/data/load.hpp 2011-11-03 00:12:30 UTC (rev 10122)
@@ -0,0 +1,54 @@
+/**
+ * @file load.hpp
+ * @author Ryan Curtin
+ *
+ * Load an Armadillo matrix from file. This is necessary because Armadillo does
+ * not transpose matrices on input, and it allows us to give better error
+ * output.
+ */
+#ifndef __MLPACK_CORE_DATA_LOAD_HPP
+#define __MLPACK_CORE_DATA_LOAD_HPP
+
+#include <mlpack/core/io/log.hpp>
+#include <mlpack/core/arma_extend/arma_extend.h> // Includes Armadillo.
+#include <string>
+
+namespace mlpack {
+namespace data /** Functions to load and save matrices. */ {
+
+/**
+ * Loads a matrix from file, guessing the filetype from the extension. This
+ * will transpose the matrix at load time. If the filetype cannot be
+ * determined, an error will be given.
+ *
+ * The extension (the text after the last '.' in the filename, matched
+ * case-sensitively) selects the format as follows:
+ *
+ *  - .csv: CSV (csv_ascii).
+ *  - .txt: Armadillo ASCII (arma_ascii) if the file starts with the
+ *    "ARMA_MAT_TXT" header. Otherwise Armadillo guesses the type; raw ASCII
+ *    (raw_ascii, space-separated) and CSV (csv_ascii) are accepted, and any
+ *    other guess is reported as an error.
+ *  - .bin: Armadillo binary (arma_binary) if the file starts with the
+ *    "ARMA_MAT_BIN" header; otherwise it is assumed to be raw binary
+ *    (raw_binary). Raw binary carries no size information, so a warning is
+ *    given and the result is a single row after transposition.
+ *  - .pgm: PGM (pgm_binary).
+ *
+ * Any other extension (including .ppm) results in an error. This is
+ * preferable to Armadillo's default behavior of loading an unknown filetype as
+ * raw_binary, which can have very confusing effects.
+ *
+ * @param filename Name of file to load.
+ * @param matrix Matrix to load contents of file into.
+ * @param fatal If true, errors are reported through Log::Fatal instead of
+ *     Log::Warn (default false).
+ * @return Boolean value indicating success or failure of load.
+ */
+template<typename eT>
+bool Load(const std::string& filename,
+          arma::Mat<eT>& matrix,
+          bool fatal = false);
+
+}; // namespace data
+}; // namespace mlpack
+
+// Include implementation.
+#include "load_impl.hpp"
+
+#endif
Added: mlpack/trunk/src/mlpack/core/data/load_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/data/load_impl.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/core/data/load_impl.hpp 2011-11-03 00:12:30 UTC (rev 10122)
@@ -0,0 +1,173 @@
+/**
+ * @file load_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of templatized Load() function defined in load.hpp.
+ */
+#ifndef __MLPACK_CORE_DATA_LOAD_HPP
+#error "Don't include this file directly; include mlpack/core/data/load.hpp."
+#endif
+
+#ifndef __MLPACK_CORE_DATA_LOAD_IMPL_HPP
+#define __MLPACK_CORE_DATA_LOAD_IMPL_HPP
+
+#include <algorithm>
+#include <fstream>
+#include <istream>
+#include <string>
+#include <vector>
+
+namespace mlpack {
+namespace data {
+
+/**
+ * Check whether the next bytes of a stream equal the given header, without
+ * consuming them: the error state is cleared and the read position restored
+ * before returning. The idea is taken from load_auto_detect() in Armadillo's
+ * diskio_meat.hpp.
+ *
+ * @param stream Stream to peek at.
+ * @param header Expected leading bytes.
+ * @return true if the next header.length() bytes of the stream equal header.
+ */
+inline bool StreamHasHeader(std::istream& stream, const std::string& header)
+{
+  if (header.empty())
+    return true;
+
+  // Zero-filled and owned by the vector, so a short read can never expose
+  // uninitialized bytes and nothing leaks if an exception is thrown.
+  std::vector<char> buffer(header.length(), '\0');
+  const std::streampos pos = stream.tellg();
+
+  stream.read(&buffer[0], std::streamsize(buffer.size()));
+  const bool fullRead = (stream.gcount() == std::streamsize(buffer.size()));
+
+  stream.clear();
+  stream.seekg(pos); // Reset stream position after peeking.
+
+  return fullRead && std::equal(buffer.begin(), buffer.end(), header.begin());
+}
+
+template<typename eT>
+bool Load(const std::string& filename, arma::Mat<eT>& matrix, bool fatal)
+{
+  // First we will try to discriminate by file extension.
+  const size_t ext = filename.rfind('.');
+  if (ext == std::string::npos)
+  {
+    if (fatal)
+      Log::Fatal << "Cannot determine type of file '" << filename << "'; "
+          << "no extension is present." << std::endl;
+    else
+      Log::Warn << "Cannot determine type of file '" << filename << "'; "
+          << "no extension is present. Load failed." << std::endl;
+
+    return false;
+  }
+
+  const std::string extension = filename.substr(ext + 1);
+
+  // Binary formats must be read in binary mode; in text mode, platforms that
+  // translate line endings (e.g. Windows) would alter the bytes read.
+  std::ios_base::openmode mode = std::ios_base::in;
+  if (extension == "bin" || extension == "pgm")
+    mode |= std::ios_base::binary;
+
+  // Catch nonexistent files by opening the stream ourselves.
+  std::fstream stream(filename.c_str(), mode);
+
+  if (!stream.is_open())
+  {
+    if (fatal)
+      Log::Fatal << "Cannot open file '" << filename << "'. " << std::endl;
+    else
+      Log::Warn << "Cannot open file '" << filename << "'; load failed."
+          << std::endl;
+
+    return false;
+  }
+
+  bool unknownType = false;
+  arma::file_type loadType = arma::auto_detect; // Placeholder; set below.
+  std::string stringType;
+
+  if (extension == "csv")
+  {
+    loadType = arma::csv_ascii;
+    stringType = "CSV data";
+  }
+  else if (extension == "txt")
+  {
+    // This could be raw ASCII or Armadillo ASCII (ASCII with size header).
+    // We check for the Armadillo ASCII header ourselves; otherwise we let
+    // Armadillo guess and see what we come up with.
+    if (StreamHasHeader(stream, "ARMA_MAT_TXT"))
+    {
+      loadType = arma::arma_ascii;
+      stringType = "Armadillo ASCII formatted data";
+    }
+    else // It's not arma_ascii. Now we let Armadillo guess.
+    {
+      loadType = arma::diskio::guess_file_type(stream);
+
+      if (loadType == arma::raw_ascii) // Raw ASCII (space-separated).
+        stringType = "raw ASCII formatted data";
+      else if (loadType == arma::csv_ascii) // CSV can be .txt too.
+        stringType = "CSV data";
+      else // Unknown .txt... we will report an error.
+        unknownType = true;
+    }
+  }
+  else if (extension == "bin")
+  {
+    // This could be raw binary or Armadillo binary (binary with header). We
+    // will check to see if it is Armadillo binary.
+    if (StreamHasHeader(stream, "ARMA_MAT_BIN"))
+    {
+      stringType = "Armadillo binary formatted data";
+      loadType = arma::arma_binary;
+    }
+    else // We can only assume it's raw binary.
+    {
+      stringType = "raw binary formatted data";
+      loadType = arma::raw_binary;
+    }
+  }
+  else if (extension == "pgm")
+  {
+    loadType = arma::pgm_binary;
+    stringType = "PGM data";
+  }
+  else // Unknown extension...
+  {
+    unknownType = true;
+  }
+
+  // Provide error if we don't know the type.
+  if (unknownType)
+  {
+    if (fatal)
+      Log::Fatal << "Unable to detect type of '" << filename << "'; "
+          << "incorrect extension?" << std::endl;
+    else
+      Log::Warn << "Unable to detect type of '" << filename << "'; load failed."
+          << " Incorrect extension?" << std::endl;
+
+    return false;
+  }
+
+  // Try to load the file; but if it's raw_binary, it could be a problem.
+  if (loadType == arma::raw_binary)
+    Log::Warn << "Loading '" << filename << "' as " << stringType << "; "
+        << "but this may not be the actual filetype!" << std::endl;
+  else
+    Log::Info << "Loading '" << filename << "' as " << stringType << "."
+        << std::endl;
+
+  if (!matrix.load(stream, loadType))
+  {
+    if (fatal)
+      Log::Fatal << "Loading from '" << filename << "' failed." << std::endl;
+    else
+      Log::Warn << "Loading from '" << filename << "' failed." << std::endl;
+
+    return false;
+  }
+
+  // Armadillo does not transpose on load, so transpose here.
+  matrix = trans(matrix);
+
+  return true;
+}
+
+}; // namespace data
+}; // namespace mlpack
+
+#endif
Added: mlpack/trunk/src/mlpack/core/data/load_save_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/core/data/load_save_test.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/core/data/load_save_test.cpp 2011-11-03 00:12:30 UTC (rev 10122)
@@ -0,0 +1,311 @@
+/**
+ * @file load_save_test.cpp
+ * @author Ryan Curtin
+ *
+ * Tests for data::Load() and data::Save().
+ */
+#include <sstream>
+
+#include "load.hpp"
+#include "save.hpp"
+
+#define BOOST_TEST_MODULE LoadSaveTest
+#include <boost/test/unit_test.hpp>
+
+using namespace mlpack;
+
+/**
+ * Load() must fail for a filename without an extension, since there is then
+ * no way to guess the file type.
+ */
+BOOST_AUTO_TEST_CASE(NoExtensionLoad)
+{
+  arma::mat loaded;
+  BOOST_REQUIRE(!data::Load("noextension", loaded));
+}
+
+/**
+ * Save() must fail for a filename without an extension, since there is then
+ * no way to choose an output format.
+ */
+BOOST_AUTO_TEST_CASE(NoExtensionSave)
+{
+  arma::mat empty;
+  BOOST_REQUIRE(!data::Save("noextension", empty));
+}
+
+/**
+ * Load() must fail when the requested file does not exist, even if its
+ * extension is a supported one.
+ */
+BOOST_AUTO_TEST_CASE(NotExistLoad)
+{
+  const std::string missing = "nonexistentfile_______________.csv";
+  arma::mat loaded;
+  BOOST_REQUIRE(!data::Load(missing, loaded));
+}
+
+/**
+ * A CSV file of two lines with four values each must load, after
+ * transposition, as a 4x2 matrix holding 1..8 in column-major order.
+ */
+BOOST_AUTO_TEST_CASE(LoadCSVTest)
+{
+  {
+    std::ofstream csvFile("test_file.csv");
+    csvFile << "1, 2, 3, 4" << std::endl
+            << "5, 6, 7, 8" << std::endl;
+  } // csvFile is flushed and closed here.
+
+  arma::mat loaded;
+  BOOST_REQUIRE(data::Load("test_file.csv", loaded));
+
+  BOOST_REQUIRE_EQUAL(loaded.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(loaded.n_cols, 2);
+
+  for (size_t idx = 0; idx < 8; ++idx)
+    BOOST_REQUIRE_CLOSE(loaded[idx], double(idx + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.csv");
+}
+
+/**
+ * A matrix written through Save() as CSV must load back through Load() with
+ * the same shape and contents.
+ */
+BOOST_AUTO_TEST_CASE(SaveCSVTest)
+{
+  arma::mat roundTrip = "1 5;"
+                        "2 6;"
+                        "3 7;"
+                        "4 8;";
+
+  BOOST_REQUIRE(data::Save("test_file.csv", roundTrip));
+
+  // Read it back into the same matrix and compare.
+  BOOST_REQUIRE(data::Load("test_file.csv", roundTrip));
+
+  BOOST_REQUIRE_EQUAL(roundTrip.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(roundTrip.n_cols, 2);
+
+  for (size_t idx = 0; idx < 8; ++idx)
+    BOOST_REQUIRE_CLOSE(roundTrip[idx], double(idx + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.csv");
+}
+
+/**
+ * A file written directly by Armadillo as arma_ascii (with the .txt
+ * extension) must be recognized by its header and loaded correctly.
+ */
+BOOST_AUTO_TEST_CASE(LoadArmaASCIITest)
+{
+  arma::mat expected = "1 5;"
+                       "2 6;"
+                       "3 7;"
+                       "4 8;";
+
+  // Store the transpose, since Load() transposes when reading.
+  arma::mat onDisk = trans(expected);
+  BOOST_REQUIRE(onDisk.save("test_file.txt", arma::arma_ascii));
+
+  BOOST_REQUIRE(data::Load("test_file.txt", expected));
+
+  BOOST_REQUIRE_EQUAL(expected.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(expected.n_cols, 2);
+
+  for (size_t idx = 0; idx < 8; ++idx)
+    BOOST_REQUIRE_CLOSE(expected[idx], double(idx + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.txt");
+}
+
+/**
+ * Make sure a .txt file is saved correctly. Save() writes .txt files as raw
+ * ASCII; this checks that such a file loads back with the same shape and
+ * contents.
+ */
+BOOST_AUTO_TEST_CASE(SaveArmaASCIITest)
+{
+  arma::mat test = "1 5;"
+                   "2 6;"
+                   "3 7;"
+                   "4 8;";
+
+  BOOST_REQUIRE(data::Save("test_file.txt", test));
+
+  // Load it in and make sure it is the same.
+  BOOST_REQUIRE(data::Load("test_file.txt", test));
+
+  BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+  for (int i = 0; i < 8; i++)
+    BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+  // Remove the file.
+  remove("test_file.txt");
+}
+
+/**
+ * A whitespace-separated .txt file without an Armadillo header must be
+ * detected as raw ASCII and loaded (transposed) as a 4x2 matrix.
+ */
+BOOST_AUTO_TEST_CASE(LoadRawASCIITest)
+{
+  {
+    std::ofstream rawFile("test_file.txt");
+    rawFile << "1 2 3 4" << std::endl
+            << "5 6 7 8" << std::endl;
+  } // rawFile is flushed and closed here.
+
+  arma::mat loaded;
+  BOOST_REQUIRE(data::Load("test_file.txt", loaded));
+
+  BOOST_REQUIRE_EQUAL(loaded.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(loaded.n_cols, 2);
+
+  for (size_t idx = 0; idx < 8; ++idx)
+    BOOST_REQUIRE_CLOSE(loaded[idx], double(idx + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.txt");
+}
+
+/**
+ * Comma-separated data stored under a .txt extension must still be detected
+ * as CSV and loaded (transposed) as a 4x2 matrix.
+ */
+BOOST_AUTO_TEST_CASE(LoadCSVTxtTest)
+{
+  {
+    std::ofstream csvAsTxt("test_file.txt");
+    csvAsTxt << "1, 2, 3, 4" << std::endl
+             << "5, 6, 7, 8" << std::endl;
+  } // csvAsTxt is flushed and closed here.
+
+  arma::mat loaded;
+  BOOST_REQUIRE(data::Load("test_file.txt", loaded));
+
+  BOOST_REQUIRE_EQUAL(loaded.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(loaded.n_cols, 2);
+
+  for (size_t idx = 0; idx < 8; ++idx)
+    BOOST_REQUIRE_CLOSE(loaded[idx], double(idx + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.txt");
+}
+
+/**
+ * A file written directly by Armadillo as arma_binary (with the .bin
+ * extension) must be recognized by its header and loaded correctly.
+ */
+BOOST_AUTO_TEST_CASE(LoadArmaBinaryTest)
+{
+  arma::mat expected = "1 5;"
+                       "2 6;"
+                       "3 7;"
+                       "4 8;";
+
+  // Store the transpose, since Load() transposes when reading.
+  arma::mat onDisk = trans(expected);
+  BOOST_REQUIRE(onDisk.quiet_save("test_file.bin", arma::arma_binary));
+
+  // Read it back through data::Load().
+  BOOST_REQUIRE(data::Load("test_file.bin", expected));
+
+  BOOST_REQUIRE_EQUAL(expected.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(expected.n_cols, 2);
+
+  for (size_t idx = 0; idx < 8; ++idx)
+    BOOST_REQUIRE_CLOSE(expected[idx], double(idx + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.bin");
+}
+
+/**
+ * A matrix written through Save() to a .bin file must load back through
+ * Load() with the same shape and contents.
+ */
+BOOST_AUTO_TEST_CASE(SaveArmaBinaryTest)
+{
+  arma::mat roundTrip = "1 5;"
+                        "2 6;"
+                        "3 7;"
+                        "4 8;";
+
+  BOOST_REQUIRE(data::Save("test_file.bin", roundTrip));
+  BOOST_REQUIRE(data::Load("test_file.bin", roundTrip));
+
+  BOOST_REQUIRE_EQUAL(roundTrip.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(roundTrip.n_cols, 2);
+
+  for (size_t idx = 0; idx < 8; ++idx)
+    BOOST_REQUIRE_CLOSE(roundTrip[idx], double(idx + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.bin");
+}
+
+/**
+ * A headerless .bin file must be loaded as raw binary. Raw binary carries no
+ * shape, so the eight values come back as a single row after transposition.
+ */
+BOOST_AUTO_TEST_CASE(LoadRawBinaryTest)
+{
+  arma::mat expected = "1 2;"
+                       "3 4;"
+                       "5 6;"
+                       "7 8;";
+
+  // Store the transpose so the raw element order on disk is 1..8.
+  arma::mat onDisk = trans(expected);
+  BOOST_REQUIRE(onDisk.quiet_save("test_file.bin", arma::raw_binary));
+
+  // Read it back through data::Load().
+  BOOST_REQUIRE(data::Load("test_file.bin", expected));
+
+  BOOST_REQUIRE_EQUAL(expected.n_rows, 1);
+  BOOST_REQUIRE_EQUAL(expected.n_cols, 8);
+
+  for (size_t idx = 0; idx < 8; ++idx)
+    BOOST_REQUIRE_CLOSE(expected[idx], double(idx + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.bin");
+}
+
+/**
+ * A PGM image written directly by Armadillo must load through data::Load()
+ * as the transposed 4x2 matrix.
+ */
+BOOST_AUTO_TEST_CASE(LoadPGMBinaryTest)
+{
+  arma::mat expected = "1 5;"
+                       "2 6;"
+                       "3 7;"
+                       "4 8;";
+
+  // Store the transpose, since Load() transposes when reading.
+  arma::mat onDisk = trans(expected);
+  BOOST_REQUIRE(onDisk.quiet_save("test_file.pgm", arma::pgm_binary));
+
+  // Read it back through data::Load().
+  BOOST_REQUIRE(data::Load("test_file.pgm", expected));
+
+  BOOST_REQUIRE_EQUAL(expected.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(expected.n_cols, 2);
+
+  for (size_t idx = 0; idx < 8; ++idx)
+    BOOST_REQUIRE_CLOSE(expected[idx], double(idx + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.pgm");
+}
+
+/**
+ * A matrix written through Save() as PGM must load back through Load() with
+ * the same shape and contents.
+ */
+BOOST_AUTO_TEST_CASE(SavePGMBinaryTest)
+{
+  arma::mat roundTrip = "1 5;"
+                        "2 6;"
+                        "3 7;"
+                        "4 8;";
+
+  BOOST_REQUIRE(data::Save("test_file.pgm", roundTrip));
+
+  // Read it back through data::Load().
+  BOOST_REQUIRE(data::Load("test_file.pgm", roundTrip));
+
+  BOOST_REQUIRE_EQUAL(roundTrip.n_rows, 4);
+  BOOST_REQUIRE_EQUAL(roundTrip.n_cols, 2);
+
+  for (size_t idx = 0; idx < 8; ++idx)
+    BOOST_REQUIRE_CLOSE(roundTrip[idx], double(idx + 1), 1e-5);
+
+  // Clean up the temporary file.
+  remove("test_file.pgm");
+}
Added: mlpack/trunk/src/mlpack/core/data/save.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/data/save.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/core/data/save.hpp 2011-11-03 00:12:30 UTC (rev 10122)
@@ -0,0 +1,52 @@
+/**
+ * @file save.hpp
+ * @author Ryan Curtin
+ *
+ * Save an Armadillo matrix to file. This is necessary because Armadillo does
+ * not transpose matrices upon saving, and it allows us to give better error
+ * output.
+ */
+#ifndef __MLPACK_CORE_DATA_SAVE_HPP
+#define __MLPACK_CORE_DATA_SAVE_HPP
+
+#include <mlpack/core/io/log.hpp>
+#include <mlpack/core/arma_extend/arma_extend.h> // Includes Armadillo.
+#include <string>
+
+namespace mlpack {
+namespace data /** Functions to load and save matrices. */ {
+
+/**
+ * Saves a matrix to file, choosing the filetype from the extension. This
+ * will transpose the matrix at save time. If the filetype cannot be
+ * determined, an error will be given.
+ *
+ * The extension (the text after the last '.' in the filename, matched
+ * case-sensitively) selects the output format as follows:
+ *
+ *  - .csv: CSV (csv_ascii).
+ *  - .txt: raw ASCII (raw_ascii).
+ *  - .bin: Armadillo binary (arma_binary), which includes a header.
+ *  - .pgm: PGM (pgm_binary).
+ *
+ * Any other extension (including .ppm) is reported as an error.
+ *
+ * @param filename Name of file to save to.
+ * @param matrix Matrix to save into file.
+ * @param fatal If true, errors are reported through Log::Fatal instead of
+ *     Log::Warn (default false).
+ * @return Boolean value indicating success or failure of save.
+ */
+template<typename eT>
+bool Save(const std::string& filename,
+          const arma::Mat<eT>& matrix,
+          bool fatal = false);
+
+}; // namespace data
+}; // namespace mlpack
+
+// Include implementation.
+#include "save_impl.hpp"
+
+#endif
Added: mlpack/trunk/src/mlpack/core/data/save_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/data/save_impl.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/core/data/save_impl.hpp 2011-11-03 00:12:30 UTC (rev 10122)
@@ -0,0 +1,117 @@
+/**
+ * @file save_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of save functionality.
+ */
+#ifndef __MLPACK_CORE_DATA_SAVE_HPP
+#error "Don't include this file directly; include mlpack/core/data/save.hpp."
+#endif
+
+#ifndef __MLPACK_CORE_DATA_SAVE_IMPL_HPP
+#define __MLPACK_CORE_DATA_SAVE_IMPL_HPP
+
+#include <fstream>
+#include <string>
+
+namespace mlpack {
+namespace data {
+
+template<typename eT>
+bool Save(const std::string& filename, const arma::Mat<eT>& matrix, bool fatal)
+{
+  // First we will try to discriminate by file extension.
+  const size_t ext = filename.rfind('.');
+  if (ext == std::string::npos)
+  {
+    if (fatal)
+      Log::Fatal << "No extension given with filename '" << filename << "'; "
+          << "type unknown. Save failed." << std::endl;
+    else
+      Log::Warn << "No extension given with filename '" << filename << "'; "
+          << "type unknown. Save failed." << std::endl;
+
+    return false;
+  }
+
+  // Get the actual extension.
+  const std::string extension = filename.substr(ext + 1);
+
+  // Pick the format before touching the file, so an unsupported extension
+  // neither creates nor truncates it.
+  bool unknownType = false;
+  arma::file_type saveType = arma::auto_detect; // Placeholder; set below.
+  std::string stringType;
+
+  if (extension == "csv")
+  {
+    saveType = arma::csv_ascii;
+    stringType = "CSV data";
+  }
+  else if (extension == "txt")
+  {
+    saveType = arma::raw_ascii;
+    stringType = "raw ASCII formatted data";
+  }
+  else if (extension == "bin")
+  {
+    saveType = arma::arma_binary;
+    stringType = "Armadillo binary formatted data";
+  }
+  else if (extension == "pgm")
+  {
+    saveType = arma::pgm_binary;
+    stringType = "PGM data";
+  }
+  else
+  {
+    unknownType = true;
+  }
+
+  // Provide error if we don't know the type. Returning here is essential:
+  // otherwise the save would proceed with no valid format selected.
+  if (unknownType)
+  {
+    if (fatal)
+      Log::Fatal << "Unable to determine format to save to from filename '"
+          << filename << "'. Save failed." << std::endl;
+    else
+      Log::Warn << "Unable to determine format to save to from filename '"
+          << filename << "'. Save failed." << std::endl;
+
+    return false;
+  }
+
+  // Binary formats must be written in binary mode; in text mode, platforms
+  // that translate line endings (e.g. Windows) would corrupt the output.
+  std::ios_base::openmode mode = std::ios_base::out;
+  if (saveType == arma::arma_binary || saveType == arma::pgm_binary)
+    mode |= std::ios_base::binary;
+
+  // Catch errors opening the file.
+  std::fstream stream(filename.c_str(), mode);
+
+  if (!stream.is_open())
+  {
+    if (fatal)
+      Log::Fatal << "Cannot open file '" << filename << "' for writing. "
+          << "Save failed." << std::endl;
+    else
+      Log::Warn << "Cannot open file '" << filename << "' for writing; save "
+          << "failed." << std::endl;
+
+    return false;
+  }
+
+  // Try to save the file.
+  Log::Info << "Saving " << stringType << " to '" << filename << "'."
+      << std::endl;
+
+  // Armadillo does not transpose on save, so write the transpose.
+  arma::Mat<eT> tmp = trans(matrix);
+
+  if (!tmp.quiet_save(stream, saveType))
+  {
+    if (fatal)
+      Log::Fatal << "Save to '" << filename << "' failed." << std::endl;
+    else
+      Log::Warn << "Save to '" << filename << "' failed." << std::endl;
+
+    return false;
+  }
+
+  // Finally return success.
+  return true;
+}
+
+}; // namespace data
+}; // namespace mlpack
+
+#endif
More information about the mlpack-svn
mailing list