[mlpack-git] master: Add serialization for arma::mat. (22d6906)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Jul 10 18:59:18 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/4a97187bbba7ce8a6191b714949dd818ef0f37d2...e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0

>---------------------------------------------------------------

commit 22d69065cebadd517772b118e816fa7135ddfcf4
Author: ryan <ryan at ratml.org>
Date:   Thu Apr 16 20:47:49 2015 -0400

    Add serialization for arma::mat.


>---------------------------------------------------------------

22d69065cebadd517772b118e816fa7135ddfcf4
 src/mlpack/core/arma_extend/Mat_extra_bones.hpp |   4 +
 src/mlpack/core/arma_extend/Mat_extra_meat.hpp  |  27 ++++
 src/mlpack/core/arma_extend/arma_extend.hpp     |   5 +
 src/mlpack/tests/CMakeLists.txt                 |   1 +
 src/mlpack/tests/serialization_test.cpp         | 159 ++++++++++++++++++++++++
 5 files changed, 196 insertions(+)

diff --git a/src/mlpack/core/arma_extend/Mat_extra_bones.hpp b/src/mlpack/core/arma_extend/Mat_extra_bones.hpp
index 45b25d5..433d3d4 100644
--- a/src/mlpack/core/arma_extend/Mat_extra_bones.hpp
+++ b/src/mlpack/core/arma_extend/Mat_extra_bones.hpp
@@ -1,3 +1,7 @@
+//! Save or load this matrix through a boost::serialization archive.
+template<typename Archive>
+void serialize(Archive& ar, const unsigned int version);
+
 /*
  * Add row_col_iterator and row_col_const_iterator to arma::Mat.
  */
diff --git a/src/mlpack/core/arma_extend/Mat_extra_meat.hpp b/src/mlpack/core/arma_extend/Mat_extra_meat.hpp
index fe7ba1b..4e2700d 100644
--- a/src/mlpack/core/arma_extend/Mat_extra_meat.hpp
+++ b/src/mlpack/core/arma_extend/Mat_extra_meat.hpp
@@ -1,3 +1,46 @@
+// Serialization operator: save or load the dimensions, vec_state, and
+// elements of the matrix to or from a boost::serialization archive.  The
+// version parameter is currently unused.
+template<typename eT>
+template<typename Archive>
+void Mat<eT>::serialize(Archive& ar, const unsigned int /* version */)
+{
+  using boost::serialization::make_nvp;
+  using boost::serialization::make_array;
+
+  // Remember the current size: on load it tells us whether the old elements
+  // live in a heap allocation that has to be released.
+  const uword old_n_elem = n_elem;
+
+  // This is accurate from Armadillo 3.6.0 onwards.
+  // We can't use BOOST_SERIALIZATION_NVP() because of the access::rw() call.
+  ar & make_nvp("n_rows", access::rw(n_rows));
+  ar & make_nvp("n_cols", access::rw(n_cols));
+  ar & make_nvp("n_elem", access::rw(n_elem));
+  ar & make_nvp("vec_state", access::rw(vec_state));
+
+  // mem_state will always be 0 on load, so we don't need to save it.
+  if (Archive::is_loading::value)
+  {
+    // init_cold() below points mem at fresh storage but does not free what
+    // mem previously referenced, so release any heap block this matrix owns
+    // first to avoid a leak.  mem_local (small matrices) and memory not owned
+    // by this object (mem_state != 0) must not be freed.
+    if (mem_state == 0 && mem != NULL &&
+        old_n_elem > arma_config::mat_prealloc)
+    {
+      memory::release(access::rw(mem));
+    }
+
+    access::rw(mem_state) = 0;
+
+    // We also need to allocate the memory we're using.
+    init_cold();
+  }
+
+  ar & make_array(access::rwp(mem), n_elem);
+}
+
 #if ARMA_VERSION_MAJOR < 4 || \
     (ARMA_VERSION_MAJOR == 4 && ARMA_VERSION_MINOR < 349)
 ///////////////////////////////////////////////////////////////////////////////
diff --git a/src/mlpack/core/arma_extend/arma_extend.hpp b/src/mlpack/core/arma_extend/arma_extend.hpp
index ec50735..8e0ccdd 100644
--- a/src/mlpack/core/arma_extend/arma_extend.hpp
+++ b/src/mlpack/core/arma_extend/arma_extend.hpp
@@ -25,6 +25,11 @@
   #define ARMA_USE_U64S64
 #endif
 
+// Include everything we'll need for serialize().
+#include <boost/serialization/serialization.hpp>
+#include <boost/serialization/nvp.hpp>
+#include <boost/serialization/array.hpp>
+
 #include <armadillo>
 
 namespace arma {
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 0caa3bb..c0566f8 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -48,6 +48,7 @@ add_executable(mlpack_test
   save_restore_utility_test.cpp
   sdp_primal_dual_test.cpp
   sgd_test.cpp
+  serialization_test.cpp
   softmax_regression_test.cpp
   sort_policy_test.cpp
   sparse_autoencoder_test.cpp
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
new file mode 100644
index 0000000..fa1dea9
--- /dev/null
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -0,0 +1,159 @@
+/**
+ * @file serialization_test.cpp
+ * @author Ryan Curtin
+ *
+ * Test serialization of mlpack objects.
+ */
+#include <boost/serialization/serialization.hpp>
+#include <boost/archive/xml_iarchive.hpp>
+#include <boost/archive/xml_oarchive.hpp>
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/archive/text_oarchive.hpp>
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/archive/binary_oarchive.hpp>
+#include <mlpack/core.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace arma;
+using namespace boost;
+using namespace boost::archive;
+using namespace boost::serialization;
+using namespace std;
+
+BOOST_AUTO_TEST_SUITE(SerializationTest); // All tests below belong to this suite.
+
+// Save x with OArchiveType, load it back with IArchiveType, and check that
+// the loaded object has the same dimensions and elements as the original.
+template<typename MatType,
+         typename IArchiveType,
+         typename OArchiveType>
+void TestArmadilloSerialization(MatType& x)
+{
+  // First save it.  The archive lives inside the try block so it is destroyed
+  // (writing any trailer, e.g. XML end tags) before the stream is closed.
+  ofstream ofs("test", std::ios::binary);
+  bool success = true;
+  try
+  {
+    OArchiveType o(ofs);
+    o << BOOST_SERIALIZATION_NVP(x);
+  }
+  catch (const archive_exception&)
+  {
+    success = false;
+  }
+
+  BOOST_REQUIRE_EQUAL(success, true);
+  ofs.close();
+
+  // Now load it (binary mode on both streams, as binary archives require).
+  MatType orig(x);
+  success = true;
+  ifstream ifs("test", std::ios::binary);
+  try
+  {
+    IArchiveType i(ifs);
+    i >> BOOST_SERIALIZATION_NVP(x);
+  }
+  catch (const archive_exception&)
+  {
+    success = false;
+  }
+
+  BOOST_REQUIRE_EQUAL(success, true);
+
+  BOOST_REQUIRE_EQUAL(x.n_rows, orig.n_rows);
+  BOOST_REQUIRE_EQUAL(x.n_cols, orig.n_cols);
+  BOOST_REQUIRE_EQUAL(x.n_elem, orig.n_elem);
+
+  for (size_t col = 0; col < x.n_cols; ++col)
+    for (size_t row = 0; row < x.n_rows; ++row)
+      if (orig(row, col) == 0.0)
+        BOOST_REQUIRE_SMALL(x(row, col), 1e-8);
+      else
+        BOOST_REQUIRE_CLOSE(orig(row, col), x(row, col), 1e-8);
+}
+
+/**
+ * Round-trip a random 50x50 Armadillo matrix through an XML archive.
+ */
+BOOST_AUTO_TEST_CASE(MatrixSerializeXMLTest)
+{
+  arma::mat matrix = arma::randu<arma::mat>(50, 50);
+
+  TestArmadilloSerialization<arma::mat, xml_iarchive, xml_oarchive>(matrix);
+}
+
+BOOST_AUTO_TEST_CASE(MatrixSerializeTextTest)
+{
+  // Random 50x50 matrix, round-tripped through a text archive.
+  arma::mat matrix = arma::randu<arma::mat>(50, 50);
+  TestArmadilloSerialization<arma::mat, text_iarchive, text_oarchive>(matrix);
+}
+
+BOOST_AUTO_TEST_CASE(MatrixSerializeBinaryTest)
+{
+  arma::mat matrix = arma::randu<arma::mat>(50, 50);
+  TestArmadilloSerialization<arma::mat, binary_iarchive, binary_oarchive>(
+      matrix);
+}
+
+/**
+ * Round-trip a random 50-element column vector through an XML archive.
+ */
+BOOST_AUTO_TEST_CASE(ColSerializeXMLTest)
+{
+  arma::vec column = arma::randu<arma::vec>(50);
+
+  TestArmadilloSerialization<arma::vec, xml_iarchive, xml_oarchive>(column);
+}
+
+BOOST_AUTO_TEST_CASE(ColSerializeTextTest)
+{
+  // Random 50-element column vector through a text archive.
+  arma::vec column = arma::randu<arma::vec>(50);
+  TestArmadilloSerialization<arma::vec, text_iarchive, text_oarchive>(column);
+}
+
+BOOST_AUTO_TEST_CASE(ColSerializeBinaryTest)
+{
+  arma::vec column = arma::randu<arma::vec>(50);
+  TestArmadilloSerialization<arma::vec, binary_iarchive, binary_oarchive>(
+      column);
+}
+
+/**
+ * Round-trip a random 50-element row vector through an XML archive.
+ */
+BOOST_AUTO_TEST_CASE(RowSerializeXMLTest)
+{
+  arma::rowvec row = arma::randu<arma::rowvec>(50);
+
+  TestArmadilloSerialization<arma::rowvec, xml_iarchive, xml_oarchive>(row);
+}
+
+BOOST_AUTO_TEST_CASE(RowSerializeTextTest)
+{
+  // Random 50-element row vector through a text archive.
+  arma::rowvec row = arma::randu<arma::rowvec>(50);
+  TestArmadilloSerialization<arma::rowvec, text_iarchive, text_oarchive>(row);
+}
+
+BOOST_AUTO_TEST_CASE(RowSerializeBinaryTest)
+{
+  arma::rowvec row = arma::randu<arma::rowvec>(50);
+  TestArmadilloSerialization<arma::rowvec, binary_iarchive, binary_oarchive>(
+      row);
+}
+
+// Round-trip a default-constructed (0x0) matrix through an XML archive.
+BOOST_AUTO_TEST_CASE(EmptyMatrixSerializeTest)
+{
+  arma::mat empty;
+  TestArmadilloSerialization<arma::mat, xml_iarchive, xml_oarchive>(empty);
+}
+
+BOOST_AUTO_TEST_SUITE_END(); // Closes the SerializationTest suite.



More information about the mlpack-git mailing list