[mlpack-git] master: Add serialization for sparse matrices. (7686349)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Jul 10 18:59:20 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/4a97187bbba7ce8a6191b714949dd818ef0f37d2...e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0

>---------------------------------------------------------------

commit 76863492a8cf58be6b227f12b503d7b80888644f
Author: ryan <ryan at ratml.org>
Date:   Fri Apr 17 00:50:13 2015 -0400

    Add serialization for sparse matrices.


>---------------------------------------------------------------

76863492a8cf58be6b227f12b503d7b80888644f
 src/mlpack/core/arma_extend/SpMat_extra_bones.hpp |  6 +-
 src/mlpack/core/arma_extend/SpMat_extra_meat.hpp  | 34 ++++++++-
 src/mlpack/tests/serialization_test.cpp           | 92 ++++++++++++-----------
 3 files changed, 86 insertions(+), 46 deletions(-)

diff --git a/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp b/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
index 51b724c..6a84ddc 100644
--- a/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
+++ b/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
@@ -2,8 +2,12 @@
  * @file SpMat_extra_bones.hpp
  * @author Ryan Curtin
  *
- * Add a batch constructor for SpMat, if the version is older than 3.810.0.
+ * Add a batch constructor for SpMat, if the version is older than 3.810.0, and
+ * also declare a serialize() member function so that SpMat can be serialized.
  */
+template<typename Archive>
+void serialize(Archive& ar, const unsigned int version);
+
 #if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 810
 template<typename T1, typename T2>
 inline SpMat(
diff --git a/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp b/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
index d2ad10e..341aec4 100644
--- a/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
+++ b/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
@@ -3,8 +3,40 @@
  * @author Ryan Curtin
  *
  * Take the Armadillo batch sparse matrix constructor function from newer
- * Armadillo versions and port it to versions earlier than 3.810.0.
+ * Armadillo versions and port it to versions earlier than 3.810.0, and also add
+ * a serialization function.
  */
+template<typename eT>
+template<typename Archive>
+void SpMat<eT>::serialize(Archive& ar, const unsigned int /* version */)
+{
+  using boost::serialization::make_nvp;
+  using boost::serialization::make_array;
+
+  // This is accurate from Armadillo 3.6.0 onwards.
+  // We can't use BOOST_SERIALIZATION_NVP() because of the access::rw() call.
+  ar & make_nvp("n_rows", access::rw(n_rows));
+  ar & make_nvp("n_cols", access::rw(n_cols));
+  ar & make_nvp("n_elem", access::rw(n_elem));
+  ar & make_nvp("n_nonzero", access::rw(n_nonzero));
+  ar & make_nvp("vec_state", access::rw(vec_state));
+
+  // Now we have to serialize the values, row indices, and column pointers.
+  // If we are loading, we need to initialize space for these things.
+  if (Archive::is_loading::value)
+  {
+    const uword new_n_nonzero = n_nonzero; // Save this; we're about to nuke it.
+    init(n_rows, n_cols); // Allocate column pointers.
+    mem_resize(new_n_nonzero); // Allocate storage.
+    // These calls will set the sentinel values at the end of the storage and
+    // column pointers, if necessary, so we don't need to worry about them.
+  }
+
+  ar & make_array(access::rwp(values), n_nonzero);
+  ar & make_array(access::rwp(row_indices), n_nonzero);
+  ar & make_array(access::rwp(col_ptrs), n_cols + 1);
+}
+
 #if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 810
 
 //! Insert a large number of values at once.
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index fa1dea9..878f5f6 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -71,34 +71,29 @@ void TestArmadilloSerialization(MatType& x)
 
   for (size_t i = 0; i < x.n_cols; ++i)
     for (size_t j = 0; j < x.n_rows; ++j)
-      if (orig(j, i) == 0.0)
-        BOOST_REQUIRE_SMALL(x(j, i), 1e-8);
+      if (double(orig(j, i)) == 0.0)
+        BOOST_REQUIRE_SMALL(double(x(j, i)), 1e-8);
       else
-        BOOST_REQUIRE_CLOSE(orig(j, i), x(j, i), 1e-8);
+        BOOST_REQUIRE_CLOSE(double(orig(j, i)), double(x(j, i)), 1e-8);
 }
 
-/**
- * Can we load and save an Armadillo matrix from XML?
- */
-BOOST_AUTO_TEST_CASE(MatrixSerializeXMLTest)
+// Test all serialization strategies.
+template<typename MatType>
+void TestAllArmadilloSerialization(MatType& x)
 {
-  arma::mat m;
-  m.randu(50, 50);
-  TestArmadilloSerialization<arma::mat, xml_iarchive, xml_oarchive>(m);
+  TestArmadilloSerialization<MatType, xml_iarchive, xml_oarchive>(x);
+  TestArmadilloSerialization<MatType, text_iarchive, text_oarchive>(x);
+  TestArmadilloSerialization<MatType, binary_iarchive, binary_oarchive>(x);
 }
 
-BOOST_AUTO_TEST_CASE(MatrixSerializeTextTest)
-{
-  arma::mat m;
-  m.randu(50, 50);
-  TestArmadilloSerialization<arma::mat, text_iarchive, text_oarchive>(m);
-}
-
-BOOST_AUTO_TEST_CASE(MatrixSerializeBinaryTest)
+/**
+ * Can we load and save an Armadillo matrix?
+ */
+BOOST_AUTO_TEST_CASE(MatrixSerializeXMLTest)
 {
   arma::mat m;
   m.randu(50, 50);
-  TestArmadilloSerialization<arma::mat, binary_iarchive, binary_oarchive>(m);
+  TestAllArmadilloSerialization(m);
 }
 
 /**
@@ -108,52 +103,61 @@ BOOST_AUTO_TEST_CASE(ColSerializeXMLTest)
 {
   arma::vec m;
   m.randu(50, 1);
-  TestArmadilloSerialization<arma::vec, xml_iarchive, xml_oarchive>(m);
+  TestAllArmadilloSerialization(m);
 }
 
-BOOST_AUTO_TEST_CASE(ColSerializeTextTest)
+/**
+ * How about rows?
+ */
+BOOST_AUTO_TEST_CASE(RowSerializeXMLTest)
 {
-  arma::vec m;
-  m.randu(50, 1);
-  TestArmadilloSerialization<arma::vec, text_iarchive, text_oarchive>(m);
+  arma::rowvec m;
+  m.randu(1, 50);
+  TestAllArmadilloSerialization(m);
 }
 
-BOOST_AUTO_TEST_CASE(ColSerializeBinaryTest)
+// A quick test with an empty matrix.
+BOOST_AUTO_TEST_CASE(EmptyMatrixSerializeTest)
 {
-  arma::vec m;
-  m.randu(50, 1);
-  TestArmadilloSerialization<arma::vec, binary_iarchive, binary_oarchive>(m);
+  arma::mat m;
+  TestAllArmadilloSerialization(m);
 }
 
 /**
- * How about rows?
+ * Can we load and save a sparse Armadillo matrix?
  */
-BOOST_AUTO_TEST_CASE(RowSerializeXMLTest)
+BOOST_AUTO_TEST_CASE(SparseMatrixSerializeXMLTest)
 {
-  arma::rowvec m;
-  m.randu(1, 50);
-  TestArmadilloSerialization<arma::rowvec, xml_iarchive, xml_oarchive>(m);
+  arma::sp_mat m;
+  m.sprandu(50, 50, 0.3);
+  TestAllArmadilloSerialization(m);
 }
 
-BOOST_AUTO_TEST_CASE(RowSerializeTextTest)
+/**
+ * How about columns?
+ */
+BOOST_AUTO_TEST_CASE(SparseColSerializeXMLTest)
 {
-  arma::rowvec m;
-  m.randu(1, 50);
-  TestArmadilloSerialization<arma::rowvec, text_iarchive, text_oarchive>(m);
+  arma::sp_vec m;
+  m.sprandu(50, 1, 0.3);
+  TestAllArmadilloSerialization(m);
 }
 
-BOOST_AUTO_TEST_CASE(RowSerializeBinaryTest)
+/**
+ * How about rows?
+ */
+BOOST_AUTO_TEST_CASE(SparseRowSerializeXMLTest)
 {
-  arma::rowvec m;
-  m.randu(1, 50);
-  TestArmadilloSerialization<arma::rowvec, binary_iarchive, binary_oarchive>(m);
+  arma::sp_rowvec m;
+  m.sprandu(1, 50, 0.3);
+  TestAllArmadilloSerialization(m);
 }
 
 // A quick test with an empty matrix.
-BOOST_AUTO_TEST_CASE(EmptyMatrixSerializeTest)
+BOOST_AUTO_TEST_CASE(EmptySparseMatrixSerializeTest)
 {
-  arma::mat m;
-  TestArmadilloSerialization<arma::mat, xml_iarchive, xml_oarchive>(m);
+  arma::sp_mat m;
+  TestAllArmadilloSerialization(m);
 }
 
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list