[mlpack-git] master: Force high-precision Armadillo saves. This will fix the failing tests in GMMTest and HMMTest. We now use enable_if_c<> to catch when an Armadillo dense or sparse object is being saved/loaded, and use Armadillo's diskio functionality to write to (or read from) a stream. This allows high-precision saving, whereas just printing via `stream << mat` is pretty low-precision and can't be changed with std::setprecision(). (fd599b9)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Feb 16 15:46:07 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/c42ca0ee5228ba83123767ea8899fcf6e4817b42...fd599b9f713be002b62278bfe417f56856717083
>---------------------------------------------------------------
commit fd599b9f713be002b62278bfe417f56856717083
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Feb 16 15:44:36 2015 -0500
Force high-precision Armadillo saves.
This will fix the failing tests in GMMTest and HMMTest.
We now use enable_if_c<> to catch when an Armadillo dense or sparse object is
being saved/loaded, and use Armadillo's diskio functionality to write to (or
read from) a stream. This allows high-precision saving, whereas just printing
via `stream << mat` is pretty low-precision and can't be changed with
std::setprecision().
>---------------------------------------------------------------
fd599b9f713be002b62278bfe417f56856717083
src/mlpack/core/util/save_restore_utility.cpp | 32 -----
src/mlpack/core/util/save_restore_utility.hpp | 55 ++++++--
src/mlpack/core/util/save_restore_utility_impl.hpp | 143 +++++++++++++++++----
3 files changed, 162 insertions(+), 68 deletions(-)
diff --git a/src/mlpack/core/util/save_restore_utility.cpp b/src/mlpack/core/util/save_restore_utility.cpp
index b294773..7eeaf6d 100644
--- a/src/mlpack/core/util/save_restore_utility.cpp
+++ b/src/mlpack/core/util/save_restore_utility.cpp
@@ -79,22 +79,6 @@ void SaveRestoreUtility::WriteFile(xmlNode* n)
}
}
-arma::mat& SaveRestoreUtility::LoadParameter(arma::mat& matrix,
- const std::string& name) const
-{
- std::map<std::string, std::string>::const_iterator it = parameters.find(name);
- if (it != parameters.end())
- {
- std::istringstream input((*it).second);
- matrix.load(input);
- }
- else
- {
- Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
- }
- return matrix;
-}
-
std::string SaveRestoreUtility::LoadParameter(std::string& str,
const std::string& name) const
{
@@ -136,24 +120,8 @@ void SaveRestoreUtility::SaveParameter(const char c, const std::string& name)
parameters[name] = output.str();
}
-// Special template specializations for vectors.
-
-namespace mlpack {
-namespace util {
-
-template<>
-arma::vec& SaveRestoreUtility::LoadParameter(arma::vec& t,
- const std::string& name) const
-{
- return (arma::vec&) LoadParameter((arma::mat&) t, name);
-}
-
void SaveRestoreUtility::AddChild(SaveRestoreUtility& mn, const std::string&
name)
{
children[name] = mn;
}
-
-
-}; // namespace util
-}; // namespace mlpack
diff --git a/src/mlpack/core/util/save_restore_utility.hpp b/src/mlpack/core/util/save_restore_utility.hpp
index 9a30822..9536437 100644
--- a/src/mlpack/core/util/save_restore_utility.hpp
+++ b/src/mlpack/core/util/save_restore_utility.hpp
@@ -59,10 +59,16 @@ class SaveRestoreUtility
bool WriteFile(const std::string& filename);
/**
- * LoadParameter loads a parameter from the parameters map.
+ * LoadParameter loads a parameter from the parameters map. This overload is
+ * not called for Armadillo objects (via the enable_if).
*/
template<typename T>
- T& LoadParameter(T& t, const std::string& name) const;
+ T& LoadParameter(T& t,
+ const std::string& name,
+ const typename boost::enable_if_c<
+ (!arma::is_arma_type<T>::value &&
+ !arma::is_arma_sparse_type<T>::value)
+ >::type* junk = 0) const;
/**
* LoadParameter loads a parameter from the parameters map.
@@ -82,15 +88,42 @@ class SaveRestoreUtility
std::string LoadParameter(std::string& str, const std::string& name) const;
/**
- * LoadParameter loads an arma::mat from the parameters map.
+ * LoadParameter loads an Armadillo matrix from the parameters map.
*/
- arma::mat& LoadParameter(arma::mat& matrix, const std::string& name) const;
+ template<typename eT>
+ arma::Mat<eT>& LoadParameter(arma::Mat<eT>& matrix, const std::string& name)
+ const;
/**
- * SaveParameter saves a parameter to the parameters map.
+ * LoadParameter loads an Armadillo sparse matrix from the parameters map.
+ */
+ template<typename eT>
+ arma::SpMat<eT>& LoadParameter(arma::SpMat<eT>& matrix,
+ const std::string& name) const;
+
+ /**
+ * SaveParameter saves a dense Armadillo object to the parameters map.
+ */
+ template<typename eT, typename T1>
+ void SaveParameter(const arma::Base<eT, T1>& t, const std::string& name);
+
+ /**
+ * SaveParameter saves a sparse Armadillo object to the parameters map.
+ */
+ template<typename eT, typename T1>
+ void SaveParameter(const arma::SpBase<eT, T1>& t, const std::string& name);
+
+ /**
+ * SaveParameter saves a parameter to the parameters map. This is not called
+ * for Armadillo objects, via the enable_if.
*/
template<typename T>
- void SaveParameter(const T& t, const std::string& name);
+ void SaveParameter(const T& t,
+ const std::string& name,
+ const typename boost::enable_if_c<
+ (!arma::is_arma_type<T>::value &&
+ !arma::is_arma_sparse_type<T>::value)
+ >::type* junk = 0);
/**
* SaveParameter saves a parameter to the parameters map.
@@ -131,14 +164,8 @@ class SaveRestoreUtility
void ReadFile(xmlNode* n);
};
-//! Specialization for arma::vec.
-template<>
-arma::vec& SaveRestoreUtility::LoadParameter(arma::vec& t,
- const std::string& name) const;
-
-
-}; /* namespace util */
-}; /* namespace mlpack */
+} /* namespace util */
+} /* namespace mlpack */
// Include implementation.
#include "save_restore_utility_impl.hpp"
diff --git a/src/mlpack/core/util/save_restore_utility_impl.hpp b/src/mlpack/core/util/save_restore_utility_impl.hpp
index e21c5f4..44a5d95 100644
--- a/src/mlpack/core/util/save_restore_utility_impl.hpp
+++ b/src/mlpack/core/util/save_restore_utility_impl.hpp
@@ -16,15 +16,52 @@ namespace mlpack {
namespace util {
template<typename T>
-T& SaveRestoreUtility::LoadParameter(T& t, const std::string& name) const
+std::vector<T>& SaveRestoreUtility::LoadParameter(std::vector<T>& v,
+ const std::string& name) const
{
std::map<std::string, std::string>::const_iterator it = parameters.find(name);
if (it != parameters.end())
{
+ v.clear();
std::string value = (*it).second;
- std::istringstream input (value);
- input >> t;
- return t;
+ boost::char_separator<char> sep (",");
+ boost::tokenizer<boost::char_separator<char> > tok (value, sep);
+ std::list<std::list<double> > rows;
+ for (boost::tokenizer<boost::char_separator<char> >::iterator
+ tokIt = tok.begin(); tokIt != tok.end(); ++tokIt)
+ {
+ T t;
+ std::istringstream iss(*tokIt);
+ iss >> t;
+ v.push_back(t);
+ }
+ }
+ else
+ {
+ Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
+ }
+ return v;
+}
+
+// Load Armadillo matrices specially, in order to preserve precision. This
+// catches dense objects.
+template<typename eT>
+arma::Mat<eT>& SaveRestoreUtility::LoadParameter(
+ arma::Mat<eT>& t,
+ const std::string& name) const
+{
+ std::map<std::string, std::string>::const_iterator it = parameters.find(name);
+ if (it != parameters.end())
+ {
+ std::string value = (*it).second;
+ std::istringstream input(value);
+
+ std::string err; // Store a possible error message.
+ if (!arma::diskio::load_csv_ascii(t, input, err))
+ {
+ Log::Fatal << "LoadParameter(): error while loading node '" << name
+ << "': " << err << ".\n";
+ }
}
else
{
@@ -33,38 +70,101 @@ T& SaveRestoreUtility::LoadParameter(T& t, const std::string& name) const
return t;
}
-template<typename T>
-std::vector<T>& SaveRestoreUtility::LoadParameter(std::vector<T>& v,
- const std::string& name) const
+// Load Armadillo matrices specially, in order to preserve precision. This
+// catches sparse objects.
+template<typename eT>
+arma::SpMat<eT>& SaveRestoreUtility::LoadParameter(
+ arma::SpMat<eT>& t,
+ const std::string& name) const
{
std::map<std::string, std::string>::const_iterator it = parameters.find(name);
if (it != parameters.end())
{
- v.clear();
std::string value = (*it).second;
- boost::char_separator<char> sep (",");
- boost::tokenizer<boost::char_separator<char> > tok (value, sep);
- std::list<std::list<double> > rows;
- for (boost::tokenizer<boost::char_separator<char> >::iterator
- tokIt = tok.begin();
- tokIt != tok.end();
- ++tokIt)
+ std::istringstream input(value);
+
+ std::string err; // Store a possible error message.
+ if (!arma::diskio::load_coord_ascii(t, input, err))
{
- T t;
- std::istringstream iss (*tokIt);
- iss >> t;
- v.push_back(t);
+ Log::Fatal << "LoadParameter(): error while loading node '" << name
+ << "': " << err << ".\n";
}
}
else
{
Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
}
- return v;
+ return t;
+}
+
+template<typename T>
+T& SaveRestoreUtility::LoadParameter(
+ T& t,
+ const std::string& name,
+ const typename boost::enable_if_c<(!arma::is_arma_type<T>::value &&
+ !arma::is_arma_sparse_type<T>::value)
+ >::type* /* junk */) const
+{
+ std::map<std::string, std::string>::const_iterator it = parameters.find(name);
+ if (it != parameters.end())
+ {
+ std::string value = (*it).second;
+ std::istringstream input(value);
+ input >> t;
+ return t;
+ }
+ else
+ {
+ Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
+ }
+ return t;
+}
+
+// Print Armadillo matrices specially, in order to preserve precision. This
+// catches dense objects.
+template<typename eT, typename T1>
+void SaveRestoreUtility::SaveParameter(
+ const arma::Base<eT, T1>& t,
+ const std::string& name)
+{
+ // Create a matrix to give to save_csv_ascii(). This may incur a copy,
+ // depending on the compiler's intelligence. But the disk bandwidth is going
+ // to be the main slowdown anyway...
+ arma::Mat<eT> temp(t.get_ref());
+
+ // Use save_csv_ascii(). This is *slightly* imprecise and it may be better to
+ // store this raw. But this is readable...
+ std::ostringstream output;
+ arma::diskio::save_csv_ascii(temp, output);
+ parameters[name] = output.str();
+}
+
+// Print sparse Armadillo matrices specially, in order to preserve precision.
+// This catches sparse objects.
+template<typename eT, typename T1>
+void SaveRestoreUtility::SaveParameter(
+ const arma::SpBase<eT, T1>& t,
+ const std::string& name)
+{
+ // Create a matrix to give to save_coord_ascii(). This may incur a copy,
+ // depending on the compiler's intelligence. But the disk bandwidth is going
+ // to be the main slowdown anyway...
+ arma::SpMat<eT> temp(t.get_ref());
+
+ // Use save_coord_ascii(). This is *slightly* imprecise and it may be better
+ // to store this raw. But this is readable...
+ std::ostringstream output;
+ arma::diskio::save_coord_ascii(temp, output);
+ parameters[name] = output.str();
}
template<typename T>
-void SaveRestoreUtility::SaveParameter(const T& t, const std::string& name)
+void SaveRestoreUtility::SaveParameter(
+ const T& t,
+ const std::string& name,
+ const typename boost::enable_if_c<(!arma::is_arma_type<T>::value &&
+ !arma::is_arma_sparse_type<T>::value)
+ >::type* /* junk */)
{
std::ostringstream output;
// Manually increase precision to solve #313 for now, until we have a way to
@@ -87,7 +187,6 @@ void SaveRestoreUtility::SaveParameter(const std::vector<T>& t,
parameters[name] = vectorAsStr;
}
-
}; // namespace util
}; // namespace mlpack
More information about the mlpack-git mailing list