[mlpack-git] master: Add Load() and Save() for models. (8514fd5)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Jul 10 18:59:08 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/4a97187bbba7ce8a6191b714949dd818ef0f37d2...e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0

>---------------------------------------------------------------

commit 8514fd541b256cba80b4d974f3ebe76e28831625
Author: ryan <ryan at ratml.org>
Date:   Wed Apr 15 13:08:31 2015 -0400

    Add Load() and Save() for models.


>---------------------------------------------------------------

8514fd541b256cba80b4d974f3ebe76e28831625
 src/mlpack/core/data/load.hpp      |  40 +++++++++++++--
 src/mlpack/core/data/load_impl.hpp | 102 +++++++++++++++++++++++++++++--------
 src/mlpack/core/data/save.hpp      |  38 ++++++++++++--
 src/mlpack/core/data/save_impl.hpp |  85 ++++++++++++++++++++++++++++---
 4 files changed, 228 insertions(+), 37 deletions(-)

diff --git a/src/mlpack/core/data/load.hpp b/src/mlpack/core/data/load.hpp
index 72b847e..c1bc1d2 100644
--- a/src/mlpack/core/data/load.hpp
+++ b/src/mlpack/core/data/load.hpp
@@ -14,7 +14,7 @@
 #include <string>
 
 namespace mlpack {
-namespace data /** Functions to load and save matrices. */ {
+namespace data /** Functions to load and save matrices and models. */ {
 
 /**
  * Loads a matrix from file, guessing the filetype from the extension.  This
@@ -51,11 +51,43 @@ namespace data /** Functions to load and save matrices. */ {
 template<typename eT>
 bool Load(const std::string& filename,
           arma::Mat<eT>& matrix,
-          bool fatal = false,
+          const bool fatal = false,
           bool transpose = true);
 
-}; // namespace data
-}; // namespace mlpack
+/**
+ * Load a model from a file, guessing the filetype from the extension, or,
+ * optionally, using the specified format.  If automatic extension detection
+ * is used and the filetype cannot be determined, an error will be given.
+ *
+ * The supported types of files are the same as what is supported by the
+ * boost::serialization library:
+ *
+ *  - text, denoted by .txt
+ *  - xml, denoted by .xml
+ *  - binary, denoted by .bin
+ *
+ * The format parameter can take any of the values in the 'format' enum:
+ * 'format::autodetect', 'format::text', 'format::xml', and 'format::binary'.
+ * The autodetect functionality operates on the file extension (so, "file.txt"
+ * would be autodetected as text).
+ *
+ * The name parameter should be specified to indicate the name of the structure
+ * to be loaded.  This should be the same as the name that was used to save the
+ * structure (otherwise, the loading procedure will fail).
+ *
+ * If the parameter 'fatal' is set to true, then an exception will be thrown in
+ * the event of a load failure.  Otherwise, the method will return false and the
+ * relevant error information will be printed to Log::Warn.
+ */
+template<typename T>
+bool Load(const std::string& filename,
+          T& t,
+          const std::string& name,
+          const bool fatal = false,
+          format f = format::autodetect);
+
+} // namespace data
+} // namespace mlpack
 
 // Include implementation.
 #include "load_impl.hpp"
diff --git a/src/mlpack/core/data/load_impl.hpp b/src/mlpack/core/data/load_impl.hpp
index 9407346..95b6826 100644
--- a/src/mlpack/core/data/load_impl.hpp
+++ b/src/mlpack/core/data/load_impl.hpp
@@ -9,6 +9,7 @@
 
 // In case it hasn't already been included.
 #include "load.hpp"
+#include "extension.hpp"
 
 #include <algorithm>
 #include <mlpack/core/util/timers.hpp>
@@ -42,30 +43,13 @@ bool inline inplace_transpose(arma::Mat<eT>& X)
 template<typename eT>
 bool Load(const std::string& filename,
           arma::Mat<eT>& matrix,
-          bool fatal,
+          const bool fatal,
           bool transpose)
 {
   Timer::Start("loading_data");
 
-  // First we will try to discriminate by file extension.
-  size_t ext = filename.rfind('.');
-  if (ext == std::string::npos)
-  {
-    Timer::Stop("loading_data");
-    if (fatal)
-      Log::Fatal << "Cannot determine type of file '" << filename << "'; "
-          << "no extension is present." << std::endl;
-    else
-      Log::Warn << "Cannot determine type of file '" << filename << "'; "
-          << "no extension is present.  Load failed." << std::endl;
-
-    return false;
-  }
-
-  // Get the extension and force it to lowercase.
-  std::string extension = filename.substr(ext + 1);
-  std::transform(extension.begin(), extension.end(), extension.begin(),
-      ::tolower);
+  // Get the file extension, which is used to discriminate the file type.
+  std::string extension = Extension(filename);
 
   // Catch nonexistent files by opening the stream ourselves.
   std::fstream stream;
@@ -253,7 +237,81 @@ bool Load(const std::string& filename,
  return success;
 }
 
-}; // namespace data
-}; // namespace mlpack
+// Load a model from file, detecting the format from the extension when 'f' is
+// format::autodetect.  See the declaration in load.hpp for the full contract.
+template<typename T>
+bool Load(const std::string& filename,
+          T& t,
+          const std::string& name,
+          const bool fatal,
+          format f)
+{
+  if (f == format::autodetect)
+  {
+    std::string extension = Extension(filename);
+
+    if (extension == "xml")
+      f = format::xml;
+    else if (extension == "bin")
+      f = format::binary;
+    else if (extension == "txt")
+      f = format::text;
+    else
+    {
+      if (fatal)
+        Log::Fatal << "Unable to detect type of '" << filename << "'; incorrect"
+            << " extension?" << std::endl;
+      else
+        Log::Warn << "Unable to detect type of '" << filename << "'; load "
+            << "failed.  Incorrect extension?" << std::endl;
+      return false;
+    }
+  }
+
+  // Now open the file; binary archives must be read in binary mode.
+  std::ifstream ifs(filename, (f == format::binary) ?
+      (std::ios::in | std::ios::binary) : std::ios::in);
+  if (!ifs.is_open())
+  {
+    if (fatal)
+      Log::Fatal << "Unable to open file '" << filename << "'." << std::endl;
+    else
+      Log::Warn << "Unable to open file '" << filename << "'." << std::endl;
+    return false;
+  }
+
+  // Deserialize with the archive type that matches the chosen format.
+  try
+  {
+    if (f == format::xml)
+    {
+      boost::archive::xml_iarchive ar(ifs);
+      ar >> util::CreateNVP(t, name);
+    }
+    else if (f == format::text)
+    {
+      boost::archive::text_iarchive ar(ifs);
+      ar >> util::CreateNVP(t, name);
+    }
+    else if (f == format::binary)
+    {
+      boost::archive::binary_iarchive ar(ifs);
+      ar >> util::CreateNVP(t, name);
+    }
+
+    return true;
+  }
+  catch (const boost::archive::archive_exception& e)
+  {
+    if (fatal)
+      Log::Fatal << e.what() << std::endl;
+    else
+      Log::Warn << e.what() << std::endl;
+    return false;
+  }
+}
+
+} // namespace data
+} // namespace mlpack
 
 #endif
diff --git a/src/mlpack/core/data/save.hpp b/src/mlpack/core/data/save.hpp
index 840b455..10b64b1 100644
--- a/src/mlpack/core/data/save.hpp
+++ b/src/mlpack/core/data/save.hpp
@@ -48,11 +48,43 @@ namespace data /** Functions to load and save matrices. */ {
 template<typename eT>
 bool Save(const std::string& filename,
           const arma::Mat<eT>& matrix,
-          bool fatal = false,
+          const bool fatal = false,
           bool transpose = true);
 
-}; // namespace data
-}; // namespace mlpack
+/**
+ * Saves a model to a file, guessing the filetype from the extension, or,
+ * optionally, using the specified format.  If automatic extension detection is
+ * used and the filetype cannot be determined, an error will be given.
+ *
+ * The supported types of files are the same as what is supported by the
+ * boost::serialization library:
+ *
+ *  - text, denoted by .txt
+ *  - xml, denoted by .xml
+ *  - binary, denoted by .bin
+ *
+ * The format parameter can take any of the values in the 'format' enum:
+ * 'format::autodetect', 'format::text', 'format::xml', and 'format::binary'.
+ * The autodetect functionality operates on the file extension (so, "file.txt"
+ * would be autodetected as text).
+ *
+ * The name parameter should be specified to indicate the name of the structure
+ * to be saved.  If Load() is later called on the generated file, the name used
+ * to load should be the same as the name used for this call to Save().
+ *
+ * If the parameter 'fatal' is set to true, then an exception will be thrown in
+ * the event of a save failure.  Otherwise, the method will return false and the
+ * relevant error information will be printed to Log::Warn.
+ */
+template<typename T>
+bool Save(const std::string& filename,
+          T& t,
+          const std::string& name,
+          const bool fatal = false,
+          format f = format::autodetect);
+
+} // namespace data
+} // namespace mlpack
 
 // Include implementation.
 #include "save_impl.hpp"
diff --git a/src/mlpack/core/data/save_impl.hpp b/src/mlpack/core/data/save_impl.hpp
index 2581e16..019ff83 100644
--- a/src/mlpack/core/data/save_impl.hpp
+++ b/src/mlpack/core/data/save_impl.hpp
@@ -16,14 +16,14 @@ namespace data {
 template<typename eT>
 bool Save(const std::string& filename,
           const arma::Mat<eT>& matrix,
-          bool fatal,
+          const bool fatal,
           bool transpose)
 {
   Timer::Start("saving_data");
 
   // First we will try to discriminate by file extension.
-  size_t ext = filename.rfind('.');
-  if (ext == std::string::npos)
+  std::string extension = Extension(filename);
+  if (extension.empty())
   {
     Timer::Stop("saving_data");
     if (fatal)
@@ -36,9 +36,6 @@ bool Save(const std::string& filename,
     return false;
   }
 
-  // Get the actual extension.
-  std::string extension = filename.substr(ext + 1);
-
   // Catch errors opening the file.
   std::fstream stream;
   stream.open(filename.c_str(), std::fstream::out);
@@ -161,7 +158,79 @@ bool Save(const std::string& filename,
   return true;
 }
 
-}; // namespace data
-}; // namespace mlpack
+//! Save a model to file; see the declaration in save.hpp for the contract.
+template<typename T>
+bool Save(const std::string& filename,
+          T& t,
+          const std::string& name,
+          const bool fatal,
+          format f)
+{
+  if (f == format::autodetect)
+  {
+    std::string extension = Extension(filename);
+
+    if (extension == "xml")
+      f = format::xml;
+    else if (extension == "bin")
+      f = format::binary;
+    else if (extension == "txt")
+      f = format::text;
+    else
+    {
+      if (fatal)
+        Log::Fatal << "Unable to detect type of '" << filename << "'; incorrect"
+            << " extension?" << std::endl;
+      else
+        Log::Warn << "Unable to detect type of '" << filename << "'; save "
+            << "failed.  Incorrect extension?" << std::endl;
+      return false;
+    }
+  }
+
+  // Open the file to save to; binary archives need a binary-mode stream.
+  std::ofstream ofs(filename, (f == format::binary) ?
+      (std::ios::out | std::ios::binary) : std::ios::out);
+  if (!ofs.is_open())
+  {
+    if (fatal)
+      Log::Fatal << "Unable to open file '" << filename << "'." << std::endl;
+    else
+      Log::Warn << "Unable to open file '" << filename << "'." << std::endl;
+    return false;
+  }
+
+  try
+  {
+    if (f == format::xml)
+    {
+      boost::archive::xml_oarchive ar(ofs);
+      ar << util::CreateNVP(t, name);
+    }
+    else if (f == format::text)
+    {
+      boost::archive::text_oarchive ar(ofs);
+      ar << util::CreateNVP(t, name);
+    }
+    else if (f == format::binary)
+    {
+      boost::archive::binary_oarchive ar(ofs);
+      ar << util::CreateNVP(t, name);
+    }
+
+    return true;
+  }
+  catch (const boost::archive::archive_exception& e)
+  {
+    if (fatal)
+      Log::Fatal << e.what() << std::endl;
+    else
+      Log::Warn << e.what() << std::endl;
+    return false;
+  }
+}
+
+} // namespace data
+} // namespace mlpack
 
 #endif



More information about the mlpack-git mailing list