[mlpack-git] master: Force high-precision Armadillo saves. This will fix the failing tests in GMMTest and HMMTest. We now use enable_if_c<> to catch when an Armadillo dense or sparse object is being saved/loaded, and use Armadillo's diskio functionality to write to (or read from) a stream. This allows high-precision saving, whereas just printing via `stream << mat` is pretty low-precision and can't be changed with std::setprecision(). (fd599b9)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Feb 16 15:46:07 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/c42ca0ee5228ba83123767ea8899fcf6e4817b42...fd599b9f713be002b62278bfe417f56856717083

>---------------------------------------------------------------

commit fd599b9f713be002b62278bfe417f56856717083
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Feb 16 15:44:36 2015 -0500

    Force high-precision Armadillo saves.
    This will fix the failing tests in GMMTest and HMMTest.
    We now use enable_if_c<> to catch when an Armadillo dense or sparse object is
    being saved/loaded, and use Armadillo's diskio functionality to write to (or
    read from) a stream.  This allows high-precision saving, whereas just printing
    via `stream << mat` is pretty low-precision and can't be changed with
    std::setprecision().


>---------------------------------------------------------------

fd599b9f713be002b62278bfe417f56856717083
 src/mlpack/core/util/save_restore_utility.cpp      |  32 -----
 src/mlpack/core/util/save_restore_utility.hpp      |  55 ++++++--
 src/mlpack/core/util/save_restore_utility_impl.hpp | 143 +++++++++++++++++----
 3 files changed, 162 insertions(+), 68 deletions(-)

diff --git a/src/mlpack/core/util/save_restore_utility.cpp b/src/mlpack/core/util/save_restore_utility.cpp
index b294773..7eeaf6d 100644
--- a/src/mlpack/core/util/save_restore_utility.cpp
+++ b/src/mlpack/core/util/save_restore_utility.cpp
@@ -79,22 +79,6 @@ void SaveRestoreUtility::WriteFile(xmlNode* n)
   }
 }
 
-arma::mat& SaveRestoreUtility::LoadParameter(arma::mat& matrix,
-                                             const std::string& name) const
-{
-  std::map<std::string, std::string>::const_iterator it = parameters.find(name);
-  if (it != parameters.end())
-  {
-    std::istringstream input((*it).second);
-    matrix.load(input);
-  }
-  else
-  {
-    Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
-  }
-  return matrix;
-}
-
 std::string SaveRestoreUtility::LoadParameter(std::string& str,
                                               const std::string& name) const
 {
@@ -136,24 +120,8 @@ void SaveRestoreUtility::SaveParameter(const char c, const std::string& name)
   parameters[name] = output.str();
 }
 
-// Special template specializations for vectors.
-
-namespace mlpack {
-namespace util {
-
-template<>
-arma::vec& SaveRestoreUtility::LoadParameter(arma::vec& t,
-                                             const std::string& name) const
-{
-  return (arma::vec&) LoadParameter((arma::mat&) t, name);
-}
-
 void SaveRestoreUtility::AddChild(SaveRestoreUtility& mn, const std::string&
     name)
 {
   children[name] = mn;
 }
-
-
-}; // namespace util
-}; // namespace mlpack
diff --git a/src/mlpack/core/util/save_restore_utility.hpp b/src/mlpack/core/util/save_restore_utility.hpp
index 9a30822..9536437 100644
--- a/src/mlpack/core/util/save_restore_utility.hpp
+++ b/src/mlpack/core/util/save_restore_utility.hpp
@@ -59,10 +59,16 @@ class SaveRestoreUtility
   bool WriteFile(const std::string& filename);
 
   /**
-   * LoadParameter loads a parameter from the parameters map.
+   * LoadParameter loads a parameter from the parameters map.  This overload is
+   * not called for Armadillo objects (via the enable_if).
    */
   template<typename T>
-  T& LoadParameter(T& t, const std::string& name) const;
+  T& LoadParameter(T& t,
+                   const std::string& name,
+                   const typename boost::enable_if_c<
+                       (!arma::is_arma_type<T>::value &&
+                        !arma::is_arma_sparse_type<T>::value)
+                       >::type* junk = 0) const;
 
   /**
    * LoadParameter loads a parameter from the parameters map.
@@ -82,15 +88,42 @@ class SaveRestoreUtility
   std::string LoadParameter(std::string& str, const std::string& name) const;
 
   /**
-   * LoadParameter loads an arma::mat from the parameters map.
+   * LoadParameter loads an Armadillo matrix from the parameters map.
    */
-  arma::mat& LoadParameter(arma::mat& matrix, const std::string& name) const;
+  template<typename eT>
+  arma::Mat<eT>& LoadParameter(arma::Mat<eT>& matrix, const std::string& name)
+      const;
 
   /**
-   * SaveParameter saves a parameter to the parameters map.
+   * LoadParameter loads an Armadillo sparse matrix from the parameters map.
+   */
+  template<typename eT>
+  arma::SpMat<eT>& LoadParameter(arma::SpMat<eT>& matrix,
+                                 const std::string& name) const;
+
+  /**
+   * SaveParameter saves a dense Armadillo object to the parameters map.
+   */
+  template<typename eT, typename T1>
+  void SaveParameter(const arma::Base<eT, T1>& t, const std::string& name);
+
+  /**
+   * SaveParameter saves a sparse Armadillo object to the parameters map.
+   */
+  template<typename eT, typename T1>
+  void SaveParameter(const arma::SpBase<eT, T1>& t, const std::string& name);
+
+  /**
+   * SaveParameter saves a parameter to the parameters map.  This is not called
+   * for Armadillo objects, via the enable_if.
    */
   template<typename T>
-  void SaveParameter(const T& t, const std::string& name);
+  void SaveParameter(const T& t,
+                     const std::string& name,
+                     const typename boost::enable_if_c<
+                         (!arma::is_arma_type<T>::value &&
+                          !arma::is_arma_sparse_type<T>::value)
+                         >::type* junk = 0);
 
   /**
    * SaveParameter saves a parameter to the parameters map.
@@ -131,14 +164,8 @@ class SaveRestoreUtility
   void ReadFile(xmlNode* n);
 };
 
-//! Specialization for arma::vec.
-template<>
-arma::vec& SaveRestoreUtility::LoadParameter(arma::vec& t,
-                                             const std::string& name) const;
-
-
-}; /* namespace util */
-}; /* namespace mlpack */
+} /* namespace util */
+} /* namespace mlpack */
 
 // Include implementation.
 #include "save_restore_utility_impl.hpp"
diff --git a/src/mlpack/core/util/save_restore_utility_impl.hpp b/src/mlpack/core/util/save_restore_utility_impl.hpp
index e21c5f4..44a5d95 100644
--- a/src/mlpack/core/util/save_restore_utility_impl.hpp
+++ b/src/mlpack/core/util/save_restore_utility_impl.hpp
@@ -16,15 +16,52 @@ namespace mlpack {
 namespace util {
 
 template<typename T>
-T& SaveRestoreUtility::LoadParameter(T& t, const std::string& name) const
+std::vector<T>& SaveRestoreUtility::LoadParameter(std::vector<T>& v,
+                                                  const std::string& name) const
 {
   std::map<std::string, std::string>::const_iterator it = parameters.find(name);
   if (it != parameters.end())
   {
+    v.clear();
     std::string value = (*it).second;
-    std::istringstream input (value);
-    input >> t;
-    return t;
+    boost::char_separator<char> sep (",");
+    boost::tokenizer<boost::char_separator<char> > tok (value, sep);
+    std::list<std::list<double> > rows;
+    for (boost::tokenizer<boost::char_separator<char> >::iterator
+        tokIt = tok.begin(); tokIt != tok.end(); ++tokIt)
+    {
+      T t;
+      std::istringstream iss(*tokIt);
+      iss >> t;
+      v.push_back(t);
+    }
+  }
+  else
+  {
+    Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
+  }
+  return v;
+}
+
+// Load Armadillo matrices specially, in order to preserve precision.  This
+// catches dense objects.
+template<typename eT>
+arma::Mat<eT>& SaveRestoreUtility::LoadParameter(
+    arma::Mat<eT>& t,
+    const std::string& name) const
+{
+  std::map<std::string, std::string>::const_iterator it = parameters.find(name);
+  if (it != parameters.end())
+  {
+    std::string value = (*it).second;
+    std::istringstream input(value);
+
+    std::string err; // Store a possible error message.
+    if (!arma::diskio::load_csv_ascii(t, input, err))
+    {
+      Log::Fatal << "LoadParameter(): error while loading node '" << name
+          << "': " << err << ".\n";
+    }
   }
   else
   {
@@ -33,38 +70,101 @@ T& SaveRestoreUtility::LoadParameter(T& t, const std::string& name) const
   return t;
 }
 
-template<typename T>
-std::vector<T>& SaveRestoreUtility::LoadParameter(std::vector<T>& v,
-                                                  const std::string& name) const
+// Load Armadillo matrices specially, in order to preserve precision.  This
+// catches sparse objects.
+template<typename eT>
+arma::SpMat<eT>& SaveRestoreUtility::LoadParameter(
+    arma::SpMat<eT>& t,
+    const std::string& name) const
 {
   std::map<std::string, std::string>::const_iterator it = parameters.find(name);
   if (it != parameters.end())
   {
-    v.clear();
     std::string value = (*it).second;
-    boost::char_separator<char> sep (",");
-    boost::tokenizer<boost::char_separator<char> > tok (value, sep);
-    std::list<std::list<double> > rows;
-    for (boost::tokenizer<boost::char_separator<char> >::iterator
-           tokIt = tok.begin();
-         tokIt != tok.end();
-         ++tokIt)
+    std::istringstream input(value);
+
+    std::string err; // Store a possible error message.
+    if (!arma::diskio::load_coord_ascii(t, input, err))
     {
-      T t;
-      std::istringstream iss (*tokIt);
-      iss >> t;
-      v.push_back(t);
+      Log::Fatal << "LoadParameter(): error while loading node '" << name
+          << "': " << err << ".\n";
     }
   }
   else
   {
     Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
   }
-  return v;
+  return t;
+}
+
+template<typename T>
+T& SaveRestoreUtility::LoadParameter(
+    T& t,
+    const std::string& name,
+    const typename boost::enable_if_c<(!arma::is_arma_type<T>::value &&
+                                       !arma::is_arma_sparse_type<T>::value)
+                                     >::type* /* junk */) const
+{
+  std::map<std::string, std::string>::const_iterator it = parameters.find(name);
+  if (it != parameters.end())
+  {
+    std::string value = (*it).second;
+    std::istringstream input(value);
+    input >> t;
+    return t;
+  }
+  else
+  {
+    Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
+  }
+  return t;
+}
+
+// Print Armadillo matrices specially, in order to preserve precision.  This
+// catches dense objects.
+template<typename eT, typename T1>
+void SaveRestoreUtility::SaveParameter(
+    const arma::Base<eT, T1>& t,
+    const std::string& name)
+{
+  // Create a matrix to give to save_csv_ascii().  This may incur a copy,
+  // depending on the compiler's intelligence.  But the disk bandwidth is going
+  // to be the main slowdown anyway...
+  arma::Mat<eT> temp(t.get_ref());
+
+  // Use save_csv_ascii().  This is *slightly* imprecise and it may be better to
+  // store this raw.  But this is readable...
+  std::ostringstream output;
+  arma::diskio::save_csv_ascii(temp, output);
+  parameters[name] = output.str();
+}
+
+// Print sparse Armadillo matrices specially, in order to preserve precision.
+// This catches sparse objects.
+template<typename eT, typename T1>
+void SaveRestoreUtility::SaveParameter(
+    const arma::SpBase<eT, T1>& t,
+    const std::string& name)
+{
+  // Create a matrix to give to save_coord_ascii().  This may incur a copy,
+  // depending on the compiler's intelligence.  But the disk bandwidth is going
+  // to be the main slowdown anyway...
+  arma::SpMat<eT> temp(t.get_ref());
+
+  // Use save_coord_ascii().  This is *slightly* imprecise and it may be better
+  // to store this raw.  But this is readable...
+  std::ostringstream output;
+  arma::diskio::save_coord_ascii(temp, output);
+  parameters[name] = output.str();
 }
 
 template<typename T>
-void SaveRestoreUtility::SaveParameter(const T& t, const std::string& name)
+void SaveRestoreUtility::SaveParameter(
+    const T& t,
+    const std::string& name,
+    const typename boost::enable_if_c<(!arma::is_arma_type<T>::value &&
+                                       !arma::is_arma_sparse_type<T>::value)
+                                     >::type* /* junk */)
 {
   std::ostringstream output;
   // Manually increase precision to solve #313 for now, until we have a way to
@@ -87,7 +187,6 @@ void SaveRestoreUtility::SaveParameter(const std::vector<T>& t,
   parameters[name] = vectorAsStr;
 }
 
-    
 }; // namespace util
 }; // namespace mlpack
 



More information about the mlpack-git mailing list