[mlpack-git] master: Add Load() and Save() for models. (8514fd5)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Jul 10 18:59:08 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/4a97187bbba7ce8a6191b714949dd818ef0f37d2...e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0
>---------------------------------------------------------------
commit 8514fd541b256cba80b4d974f3ebe76e28831625
Author: ryan <ryan at ratml.org>
Date: Wed Apr 15 13:08:31 2015 -0400
Add Load() and Save() for models.
>---------------------------------------------------------------
8514fd541b256cba80b4d974f3ebe76e28831625
src/mlpack/core/data/load.hpp | 40 +++++++++++++--
src/mlpack/core/data/load_impl.hpp | 102 +++++++++++++++++++++++++++++--------
src/mlpack/core/data/save.hpp | 38 ++++++++++++--
src/mlpack/core/data/save_impl.hpp | 85 ++++++++++++++++++++++++++++---
4 files changed, 228 insertions(+), 37 deletions(-)
diff --git a/src/mlpack/core/data/load.hpp b/src/mlpack/core/data/load.hpp
index 72b847e..c1bc1d2 100644
--- a/src/mlpack/core/data/load.hpp
+++ b/src/mlpack/core/data/load.hpp
@@ -14,7 +14,7 @@
#include <string>
namespace mlpack {
-namespace data /** Functions to load and save matrices. */ {
+namespace data /** Functions to load and save matrices and models. */ {
/**
* Loads a matrix from file, guessing the filetype from the extension. This
@@ -51,11 +51,43 @@ namespace data /** Functions to load and save matrices. */ {
template<typename eT>
bool Load(const std::string& filename,
arma::Mat<eT>& matrix,
- bool fatal = false,
+ const bool fatal = false,
bool transpose = true);
-}; // namespace data
-}; // namespace mlpack
+/**
+ * Load a model from a file, guessing the filetype from the extension, or,
+ * optionally, loading the specified format. If automatic extension detection
+ * is used and the filetype cannot be determined, an error will be given.
+ *
+ * The supported types of files are the same as what is supported by the
+ * boost::serialization library:
+ *
+ * - text, denoted by .txt
+ * - xml, denoted by .xml
+ * - binary, denoted by .bin
+ *
+ * The format parameter can take any of the values in the 'format' enum:
+ * 'format::autodetect', 'format::text', 'format::xml', and 'format::binary'.
+ * The autodetect functionality operates on the file extension (so, "file.txt"
+ * would be autodetected as text).
+ *
+ * The name parameter should be specified to indicate the name of the structure
+ * to be loaded. This should be the same as the name that was used to save the
+ * structure (otherwise, the loading procedure will fail).
+ *
+ * If the parameter 'fatal' is set to true, then an exception will be thrown in
+ * the event of load failure. Otherwise, the method will return false and the
+ * relevant error information will be printed to Log::Warn.
+ */
+template<typename T>
+bool Load(const std::string& filename,
+ T& t,
+ const std::string& name,
+ const bool fatal = false,
+ format f = format::autodetect);
+
+} // namespace data
+} // namespace mlpack
// Include implementation.
#include "load_impl.hpp"
diff --git a/src/mlpack/core/data/load_impl.hpp b/src/mlpack/core/data/load_impl.hpp
index 9407346..95b6826 100644
--- a/src/mlpack/core/data/load_impl.hpp
+++ b/src/mlpack/core/data/load_impl.hpp
@@ -9,6 +9,7 @@
// In case it hasn't already been included.
#include "load.hpp"
+#include "extension.hpp"
#include <algorithm>
#include <mlpack/core/util/timers.hpp>
@@ -42,30 +43,13 @@ bool inline inplace_transpose(arma::Mat<eT>& X)
template<typename eT>
bool Load(const std::string& filename,
arma::Mat<eT>& matrix,
- bool fatal,
+ const bool fatal,
bool transpose)
{
Timer::Start("loading_data");
- // First we will try to discriminate by file extension.
- size_t ext = filename.rfind('.');
- if (ext == std::string::npos)
- {
- Timer::Stop("loading_data");
- if (fatal)
- Log::Fatal << "Cannot determine type of file '" << filename << "'; "
- << "no extension is present." << std::endl;
- else
- Log::Warn << "Cannot determine type of file '" << filename << "'; "
- << "no extension is present. Load failed." << std::endl;
-
- return false;
- }
-
- // Get the extension and force it to lowercase.
- std::string extension = filename.substr(ext + 1);
- std::transform(extension.begin(), extension.end(), extension.begin(),
- ::tolower);
+ // Get the extension.
+ std::string extension = Extension(filename);
// Catch nonexistent files by opening the stream ourselves.
std::fstream stream;
@@ -253,7 +237,81 @@ bool Load(const std::string& filename,
return success;
}
-}; // namespace data
-}; // namespace mlpack
+// Load a model from file.
+template<typename T>
+bool Load(const std::string& filename,
+ T& t,
+ const std::string& name,
+ const bool fatal,
+ format f)
+{
+ if (f == format::autodetect)
+ {
+ std::string extension = Extension(filename);
+
+ if (extension == "xml")
+ f = format::xml;
+ else if (extension == "bin")
+ f = format::binary;
+ else if (extension == "txt")
+ f = format::text;
+ else
+ {
+ if (fatal)
+ Log::Fatal << "Unable to detect type of '" << filename << "'; incorrect"
+ << " extension?" << std::endl;
+ else
+ Log::Warn << "Unable to detect type of '" << filename << "'; load "
+ << "failed. Incorrect extension?" << std::endl;
+
+ return false;
+ }
+ }
+
+ // Now load the given format.
+ std::ifstream ifs(filename);
+ if (!ifs.is_open())
+ {
+ if (fatal)
+ Log::Fatal << "Unable to open file '" << filename << "'." << std::endl;
+ else
+ Log::Warn << "Unable to open file '" << filename << "'." << std::endl;
+
+ return false;
+ }
+
+ try
+ {
+ if (f == format::xml)
+ {
+ boost::archive::xml_iarchive ar(ifs);
+ ar >> util::CreateNVP(t, name);
+ }
+ else if (f == format::text)
+ {
+ boost::archive::text_iarchive ar(ifs);
+ ar >> util::CreateNVP(t, name);
+ }
+ else if (f == format::binary)
+ {
+ boost::archive::binary_iarchive ar(ifs);
+ ar >> util::CreateNVP(t, name);
+ }
+
+ return true;
+ }
+ catch (boost::serialization::archive_exception& e)
+ {
+ if (fatal)
+ Log::Fatal << e.what() << std::endl;
+ else
+ Log::Warn << e.what() << std::endl;
+
+ return false;
+ }
+}
+
+} // namespace data
+} // namespace mlpack
#endif
diff --git a/src/mlpack/core/data/save.hpp b/src/mlpack/core/data/save.hpp
index 840b455..10b64b1 100644
--- a/src/mlpack/core/data/save.hpp
+++ b/src/mlpack/core/data/save.hpp
@@ -48,11 +48,43 @@ namespace data /** Functions to load and save matrices. */ {
template<typename eT>
bool Save(const std::string& filename,
const arma::Mat<eT>& matrix,
- bool fatal = false,
+ const bool fatal = false,
bool transpose = true);
-}; // namespace data
-}; // namespace mlpack
+/**
+ * Saves a model to file, guessing the filetype from the extension, or,
+ * optionally, saving the specified format. If automatic extension detection is
+ * used and the filetype cannot be determined, an error will be given.
+ *
+ * The supported types of files are the same as what is supported by the
+ * boost::serialization library:
+ *
+ * - text, denoted by .txt
+ * - xml, denoted by .xml
+ * - binary, denoted by .bin
+ *
+ * The format parameter can take any of the values in the 'format' enum:
+ * 'format::autodetect', 'format::text', 'format::xml', and 'format::binary'.
+ * The autodetect functionality operates on the file extension (so, "file.txt"
+ * would be autodetected as text).
+ *
+ * The name parameter should be specified to indicate the name of the structure
+ * to be saved. If Load() is later called on the generated file, the name used
+ * to load should be the same as the name used for this call to Save().
+ *
+ * If the parameter 'fatal' is set to true, then an exception will be thrown in
+ * the event of a save failure. Otherwise, the method will return false and the
+ * relevant error information will be printed to Log::Warn.
+ */
+template<typename T>
+bool Save(const std::string& filename,
+ T& t,
+ const std::string& name,
+ const bool fatal = false,
+ format f = format::autodetect);
+
+} // namespace data
+} // namespace mlpack
// Include implementation.
#include "save_impl.hpp"
diff --git a/src/mlpack/core/data/save_impl.hpp b/src/mlpack/core/data/save_impl.hpp
index 2581e16..019ff83 100644
--- a/src/mlpack/core/data/save_impl.hpp
+++ b/src/mlpack/core/data/save_impl.hpp
@@ -16,14 +16,14 @@ namespace data {
template<typename eT>
bool Save(const std::string& filename,
const arma::Mat<eT>& matrix,
- bool fatal,
+ const bool fatal,
bool transpose)
{
Timer::Start("saving_data");
// First we will try to discriminate by file extension.
- size_t ext = filename.rfind('.');
- if (ext == std::string::npos)
+ std::string extension = Extension(filename);
+ if (extension == "")
{
Timer::Stop("saving_data");
if (fatal)
@@ -36,9 +36,6 @@ bool Save(const std::string& filename,
return false;
}
- // Get the actual extension.
- std::string extension = filename.substr(ext + 1);
-
// Catch errors opening the file.
std::fstream stream;
stream.open(filename.c_str(), std::fstream::out);
@@ -161,7 +158,79 @@ bool Save(const std::string& filename,
return true;
}
-}; // namespace data
-}; // namespace mlpack
+/**
+ * Save a model to file via a boost::serialization archive.
+ *
+ * @param filename Name of the file to save to.
+ * @param t Object to be written to the archive.
+ * @param name Name of the structure being saved; passed to util::CreateNVP().
+ * @param fatal If true, failures are reported to Log::Fatal; otherwise they
+ *     are reported to Log::Warn.
+ * @param f Archive format.  With format::autodetect, the format is chosen
+ *     from the file extension ("xml", "bin", or "txt").
+ * @return True on success; false if the format could not be detected, the
+ *     file could not be opened, or writing the archive raised an exception.
+ */
+template<typename T>
+bool Save(const std::string& filename,
+          T& t,
+          const std::string& name,
+          const bool fatal,
+          format f)
+{
+  if (f == format::autodetect)
+  {
+    const std::string extension = Extension(filename);
+
+    if (extension == "xml")
+      f = format::xml;
+    else if (extension == "bin")
+      f = format::binary;
+    else if (extension == "txt")
+      f = format::text;
+    else
+    {
+      if (fatal)
+        Log::Fatal << "Unable to detect type of '" << filename << "'; incorrect"
+            << " extension?" << std::endl;
+      else
+        Log::Warn << "Unable to detect type of '" << filename << "'; save "
+            << "failed.  Incorrect extension?" << std::endl;
+
+      return false;
+    }
+  }
+
+  // Binary archives must be written through a stream opened in binary mode;
+  // otherwise platforms that translate newlines will corrupt the data.
+  std::ios_base::openmode mode = std::ios_base::out;
+  if (f == format::binary)
+    mode |= std::ios_base::binary;
+
+  // Open the file to save to.
+  std::ofstream ofs(filename, mode);
+  if (!ofs.is_open())
+  {
+    if (fatal)
+      Log::Fatal << "Unable to open file '" << filename << "'." << std::endl;
+    else
+      Log::Warn << "Unable to open file '" << filename << "'." << std::endl;
+
+    return false;
+  }
+
+  try
+  {
+    // Each archive is scoped to its branch so that its destructor (which
+    // finishes writing the archive) runs before success is reported.
+    if (f == format::xml)
+    {
+      boost::archive::xml_oarchive ar(ofs);
+      ar << util::CreateNVP(t, name);
+    }
+    else if (f == format::text)
+    {
+      boost::archive::text_oarchive ar(ofs);
+      ar << util::CreateNVP(t, name);
+    }
+    else if (f == format::binary)
+    {
+      boost::archive::binary_oarchive ar(ofs);
+      ar << util::CreateNVP(t, name);
+    }
+  }
+  // Archive failures are reported as boost::archive::archive_exception (the
+  // exception type lives in the archive namespace, not serialization).
+  catch (const boost::archive::archive_exception& e)
+  {
+    if (fatal)
+      Log::Fatal << e.what() << std::endl;
+    else
+      Log::Warn << e.what() << std::endl;
+
+    return false;
+  }
+
+  // The model was written without error.
+  return true;
+}
+
+} // namespace data
+} // namespace mlpack
#endif
More information about the mlpack-git
mailing list