[mlpack-git] master: Add serialization for arma::mat. (22d6906)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Jul 10 18:59:18 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/4a97187bbba7ce8a6191b714949dd818ef0f37d2...e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0
>---------------------------------------------------------------
commit 22d69065cebadd517772b118e816fa7135ddfcf4
Author: ryan <ryan at ratml.org>
Date: Thu Apr 16 20:47:49 2015 -0400
Add serialization for arma::mat.
>---------------------------------------------------------------
22d69065cebadd517772b118e816fa7135ddfcf4
src/mlpack/core/arma_extend/Mat_extra_bones.hpp | 4 +
src/mlpack/core/arma_extend/Mat_extra_meat.hpp | 27 ++++
src/mlpack/core/arma_extend/arma_extend.hpp | 5 +
src/mlpack/tests/CMakeLists.txt | 1 +
src/mlpack/tests/serialization_test.cpp | 159 ++++++++++++++++++++++++
5 files changed, 196 insertions(+)
diff --git a/src/mlpack/core/arma_extend/Mat_extra_bones.hpp b/src/mlpack/core/arma_extend/Mat_extra_bones.hpp
index 45b25d5..433d3d4 100644
--- a/src/mlpack/core/arma_extend/Mat_extra_bones.hpp
+++ b/src/mlpack/core/arma_extend/Mat_extra_bones.hpp
@@ -1,3 +1,7 @@
+//! Serialize the matrix to or from a boost::serialization archive.
+template<typename Archive>
+void serialize(Archive& ar, const unsigned int version);
+
/*
* Add row_col_iterator and row_col_const_iterator to arma::Mat.
*/
diff --git a/src/mlpack/core/arma_extend/Mat_extra_meat.hpp b/src/mlpack/core/arma_extend/Mat_extra_meat.hpp
index fe7ba1b..4e2700d 100644
--- a/src/mlpack/core/arma_extend/Mat_extra_meat.hpp
+++ b/src/mlpack/core/arma_extend/Mat_extra_meat.hpp
@@ -1,3 +1,46 @@
+// Serialize the matrix to or from a boost::serialization archive.
+template<typename eT>
+template<typename Archive>
+void Mat<eT>::serialize(Archive& ar, const unsigned int /* version */)
+{
+  using boost::serialization::make_nvp;
+  using boost::serialization::make_array;
+
+  // Remember the current size: on load it tells us whether the memory we
+  // hold right now was allocated on the heap and so has to be released.
+  const uword old_n_elem = n_elem;
+
+  // This is accurate from Armadillo 3.6.0 onwards.
+  // We can't use BOOST_SERIALIZATION_NVP() because of the access::rw() call.
+  ar & make_nvp("n_rows", access::rw(n_rows));
+  ar & make_nvp("n_cols", access::rw(n_cols));
+  ar & make_nvp("n_elem", access::rw(n_elem));
+  ar & make_nvp("vec_state", access::rw(vec_state));
+
+  // mem_state will always be 0 on load, so we don't need to save it.
+  if (Archive::is_loading::value)
+  {
+    // init_cold() points mem at new storage without freeing the old block, so
+    // release it here or it leaks.  Only memory this matrix owns
+    // (mem_state == 0) that is not the small local buffer used for at most
+    // arma_config::mat_prealloc elements may be released.
+    if (mem_state == 0 && mem != NULL &&
+        old_n_elem > arma_config::mat_prealloc)
+    {
+      memory::release(access::rw(mem));
+      access::rw(mem) = NULL; // Don't keep a dangling pointer around.
+    }
+
+    access::rw(mem_state) = 0;
+
+    // We also need to allocate the memory we're using.
+    init_cold();
+  }
+
+  // Save or load the elements themselves.
+  ar & make_array(access::rwp(mem), n_elem);
+}
+
#if ARMA_VERSION_MAJOR < 4 || \
(ARMA_VERSION_MAJOR == 4 && ARMA_VERSION_MINOR < 349)
///////////////////////////////////////////////////////////////////////////////
diff --git a/src/mlpack/core/arma_extend/arma_extend.hpp b/src/mlpack/core/arma_extend/arma_extend.hpp
index ec50735..8e0ccdd 100644
--- a/src/mlpack/core/arma_extend/arma_extend.hpp
+++ b/src/mlpack/core/arma_extend/arma_extend.hpp
@@ -25,6 +25,11 @@
#define ARMA_USE_U64S64
#endif
+// Include everything we'll need for serialize().
+#include <boost/serialization/serialization.hpp>
+#include <boost/serialization/nvp.hpp>
+#include <boost/serialization/array.hpp>
+
#include <armadillo>
namespace arma {
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 0caa3bb..c0566f8 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -48,6 +48,7 @@ add_executable(mlpack_test
save_restore_utility_test.cpp
sdp_primal_dual_test.cpp
sgd_test.cpp
+ serialization_test.cpp
softmax_regression_test.cpp
sort_policy_test.cpp
sparse_autoencoder_test.cpp
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
new file mode 100644
index 0000000..fa1dea9
--- /dev/null
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -0,0 +1,159 @@
+/**
+ * @file serialization_test.cpp
+ * @author Ryan Curtin
+ *
+ * Test serialization of mlpack objects.
+ */
+#include <boost/serialization/serialization.hpp>
+#include <boost/archive/xml_iarchive.hpp>
+#include <boost/archive/xml_oarchive.hpp>
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/archive/text_oarchive.hpp>
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/archive/binary_oarchive.hpp>
+#include <mlpack/core.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace arma;
+using namespace boost;
+using namespace boost::archive;
+using namespace boost::serialization;
+using namespace std;
+
+BOOST_AUTO_TEST_SUITE(SerializationTest);
+
+/**
+ * Save x to the file "test" with an OArchiveType archive, load it back into x
+ * with an IArchiveType archive, and check that the result matches the original.
+ *
+ * @tparam IArchiveType Input archive type (e.g. xml_iarchive).
+ * @tparam OArchiveType Output archive type matching IArchiveType.
+ * @param x Object to round-trip; on return it holds the reloaded copy.
+ */
+template<typename MatType,
+         typename IArchiveType,
+         typename OArchiveType>
+void TestArmadilloSerialization(MatType& x)
+{
+  // First save it.  Binary archives need a stream opened in binary mode (this
+  // is harmless for text and XML archives).  The archive lives in its own
+  // scope so that its destructor, which finishes the archive (e.g. writes the
+  // closing XML tag), runs before the file is closed.
+  {
+    ofstream ofs("test", ios::binary);
+    OArchiveType o(ofs);
+    BOOST_REQUIRE_NO_THROW(o << BOOST_SERIALIZATION_NVP(x));
+  }
+
+  // Now load it back into x, keeping a copy of the original to compare with.
+  MatType orig(x);
+  {
+    ifstream ifs("test", ios::binary);
+    IArchiveType ia(ifs);
+    BOOST_REQUIRE_NO_THROW(ia >> BOOST_SERIALIZATION_NVP(x));
+  }
+
+  // Don't leave the temporary file behind.
+  remove("test");
+
+  BOOST_REQUIRE_EQUAL(x.n_rows, orig.n_rows);
+  BOOST_REQUIRE_EQUAL(x.n_cols, orig.n_cols);
+  BOOST_REQUIRE_EQUAL(x.n_elem, orig.n_elem);
+
+  // Compare element by element.  A relative tolerance is meaningless at zero,
+  // so entries that were exactly zero get an absolute tolerance instead.
+  for (size_t col = 0; col < x.n_cols; ++col)
+  {
+    for (size_t row = 0; row < x.n_rows; ++row)
+    {
+      if (orig(row, col) == 0.0)
+        BOOST_REQUIRE_SMALL(x(row, col), 1e-8);
+      else
+        BOOST_REQUIRE_CLOSE(orig(row, col), x(row, col), 1e-8);
+    }
+  }
+}
+
+/**
+ * Can an Armadillo matrix survive a round trip through an XML archive?
+ */
+BOOST_AUTO_TEST_CASE(MatrixSerializeXMLTest)
+{
+  // A random 50x50 matrix.
+  arma::mat m = arma::randu<arma::mat>(50, 50);
+  TestArmadilloSerialization<arma::mat, xml_iarchive, xml_oarchive>(m);
+}
+
+BOOST_AUTO_TEST_CASE(MatrixSerializeTextTest)
+{
+  // Round-trip a random 50x50 matrix through a text archive.
+  arma::mat m = arma::randu<arma::mat>(50, 50);
+  TestArmadilloSerialization<arma::mat, text_iarchive, text_oarchive>(m);
+}
+
+BOOST_AUTO_TEST_CASE(MatrixSerializeBinaryTest)
+{
+  // Round-trip a random 50x50 matrix through a binary archive.
+  arma::mat m = arma::randu<arma::mat>(50, 50);
+  TestArmadilloSerialization<arma::mat, binary_iarchive, binary_oarchive>(m);
+}
+
+/**
+ * Can an Armadillo column vector survive a round trip through XML?
+ */
+BOOST_AUTO_TEST_CASE(ColSerializeXMLTest)
+{
+  // A random column vector with 50 elements.
+  arma::vec m = arma::randu<arma::vec>(50);
+  TestArmadilloSerialization<arma::vec, xml_iarchive, xml_oarchive>(m);
+}
+
+BOOST_AUTO_TEST_CASE(ColSerializeTextTest)
+{
+  // Round-trip a random 50-element column vector through a text archive.
+  arma::vec m = arma::randu<arma::vec>(50);
+  TestArmadilloSerialization<arma::vec, text_iarchive, text_oarchive>(m);
+}
+
+BOOST_AUTO_TEST_CASE(ColSerializeBinaryTest)
+{
+  // Round-trip a random 50-element column vector through a binary archive.
+  arma::vec m = arma::randu<arma::vec>(50);
+  TestArmadilloSerialization<arma::vec, binary_iarchive, binary_oarchive>(m);
+}
+
+/**
+ * Can an Armadillo row vector survive a round trip through XML?
+ */
+BOOST_AUTO_TEST_CASE(RowSerializeXMLTest)
+{
+  // A random row vector with 50 elements.
+  arma::rowvec m = arma::randu<arma::rowvec>(50);
+  TestArmadilloSerialization<arma::rowvec, xml_iarchive, xml_oarchive>(m);
+}
+
+BOOST_AUTO_TEST_CASE(RowSerializeTextTest)
+{
+  // Round-trip a random 50-element row vector through a text archive.
+  arma::rowvec m = arma::randu<arma::rowvec>(50);
+  TestArmadilloSerialization<arma::rowvec, text_iarchive, text_oarchive>(m);
+}
+
+BOOST_AUTO_TEST_CASE(RowSerializeBinaryTest)
+{
+  // Round-trip a random 50-element row vector through a binary archive.
+  arma::rowvec m = arma::randu<arma::rowvec>(50);
+  TestArmadilloSerialization<arma::rowvec, binary_iarchive, binary_oarchive>(m);
+}
+
+// Serializing an empty (0x0) matrix must work too.
+BOOST_AUTO_TEST_CASE(EmptyMatrixSerializeTest)
+{
+  arma::mat empty;
+  TestArmadilloSerialization<arma::mat, xml_iarchive, xml_oarchive>(empty);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list