[mlpack-git] master: Add serialization for sparse matrices. (7686349)
gitdub at big.cc.gt.atl.ga.us
Fri Jul 10 18:59:20 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/4a97187bbba7ce8a6191b714949dd818ef0f37d2...e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0
>---------------------------------------------------------------
commit 76863492a8cf58be6b227f12b503d7b80888644f
Author: ryan <ryan at ratml.org>
Date: Fri Apr 17 00:50:13 2015 -0400
Add serialization for sparse matrices.
>---------------------------------------------------------------
76863492a8cf58be6b227f12b503d7b80888644f
src/mlpack/core/arma_extend/SpMat_extra_bones.hpp | 6 +-
src/mlpack/core/arma_extend/SpMat_extra_meat.hpp | 34 ++++++++-
src/mlpack/tests/serialization_test.cpp | 92 ++++++++++++-----------
3 files changed, 86 insertions(+), 46 deletions(-)
diff --git a/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp b/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
index 51b724c..6a84ddc 100644
--- a/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
+++ b/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
@@ -2,8 +2,12 @@
* @file SpMat_extra_bones.hpp
* @author Ryan Curtin
*
- * Add a batch constructor for SpMat, if the version is older than 3.810.0.
+ * Add a batch constructor for SpMat, if the version is older than 3.810.0, and
+ * also a serialize() function for Armadillo.
*/
+template<typename Archive>
+void serialize(Archive& ar, const unsigned int version);
+
#if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 810
template<typename T1, typename T2>
inline SpMat(
diff --git a/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp b/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
index d2ad10e..341aec4 100644
--- a/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
+++ b/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
@@ -3,8 +3,40 @@
* @author Ryan Curtin
*
* Take the Armadillo batch sparse matrix constructor function from newer
- * Armadillo versions and port it to versions earlier than 3.810.0.
+ * Armadillo versions and port it to versions earlier than 3.810.0, and also add
+ * a serialization function.
*/
+template<typename eT>
+template<typename Archive>
+void SpMat<eT>::serialize(Archive& ar, const unsigned int /* version */)
+{
+ using boost::serialization::make_nvp;
+ using boost::serialization::make_array;
+
+ // This is accurate from Armadillo 3.6.0 onwards.
+ // We can't use BOOST_SERIALIZATION_NVP() because of the access::rw() call.
+ ar & make_nvp("n_rows", access::rw(n_rows));
+ ar & make_nvp("n_cols", access::rw(n_cols));
+ ar & make_nvp("n_elem", access::rw(n_elem));
+ ar & make_nvp("n_nonzero", access::rw(n_nonzero));
+ ar & make_nvp("vec_state", access::rw(vec_state));
+
+ // Now we have to serialize the values, row indices, and column pointers.
+ // If we are loading, we need to initialize space for these things.
+ if (Archive::is_loading::value)
+ {
+ const uword new_n_nonzero = n_nonzero; // Save this; we're about to nuke it.
+ init(n_rows, n_cols); // Allocate column pointers.
+ mem_resize(new_n_nonzero); // Allocate storage.
+ // These calls will set the sentinel values at the end of the storage and
+ // column pointers, if necessary, so we don't need to worry about them.
+ }
+
+ ar & make_array(access::rwp(values), n_nonzero);
+ ar & make_array(access::rwp(row_indices), n_nonzero);
+ ar & make_array(access::rwp(col_ptrs), n_cols + 1);
+}
+
#if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 810
//! Insert a large number of values at once.
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index fa1dea9..878f5f6 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -71,34 +71,29 @@ void TestArmadilloSerialization(MatType& x)
for (size_t i = 0; i < x.n_cols; ++i)
for (size_t j = 0; j < x.n_rows; ++j)
- if (orig(j, i) == 0.0)
- BOOST_REQUIRE_SMALL(x(j, i), 1e-8);
+ if (double(orig(j, i)) == 0.0)
+ BOOST_REQUIRE_SMALL(double(x(j, i)), 1e-8);
else
- BOOST_REQUIRE_CLOSE(orig(j, i), x(j, i), 1e-8);
+ BOOST_REQUIRE_CLOSE(double(orig(j, i)), double(x(j, i)), 1e-8);
}
-/**
- * Can we load and save an Armadillo matrix from XML?
- */
-BOOST_AUTO_TEST_CASE(MatrixSerializeXMLTest)
+// Test all serialization strategies.
+template<typename MatType>
+void TestAllArmadilloSerialization(MatType& x)
{
- arma::mat m;
- m.randu(50, 50);
- TestArmadilloSerialization<arma::mat, xml_iarchive, xml_oarchive>(m);
+ TestArmadilloSerialization<MatType, xml_iarchive, xml_oarchive>(x);
+ TestArmadilloSerialization<MatType, text_iarchive, text_oarchive>(x);
+ TestArmadilloSerialization<MatType, binary_iarchive, binary_oarchive>(x);
}
-BOOST_AUTO_TEST_CASE(MatrixSerializeTextTest)
-{
- arma::mat m;
- m.randu(50, 50);
- TestArmadilloSerialization<arma::mat, text_iarchive, text_oarchive>(m);
-}
-
-BOOST_AUTO_TEST_CASE(MatrixSerializeBinaryTest)
+/**
+ * Can we load and save an Armadillo matrix?
+ */
+BOOST_AUTO_TEST_CASE(MatrixSerializeXMLTest)
{
arma::mat m;
m.randu(50, 50);
- TestArmadilloSerialization<arma::mat, binary_iarchive, binary_oarchive>(m);
+ TestAllArmadilloSerialization(m);
}
/**
@@ -108,52 +103,61 @@ BOOST_AUTO_TEST_CASE(ColSerializeXMLTest)
{
arma::vec m;
m.randu(50, 1);
- TestArmadilloSerialization<arma::vec, xml_iarchive, xml_oarchive>(m);
+ TestAllArmadilloSerialization(m);
}
-BOOST_AUTO_TEST_CASE(ColSerializeTextTest)
+/**
+ * How about rows?
+ */
+BOOST_AUTO_TEST_CASE(RowSerializeXMLTest)
{
- arma::vec m;
- m.randu(50, 1);
- TestArmadilloSerialization<arma::vec, text_iarchive, text_oarchive>(m);
+ arma::rowvec m;
+ m.randu(1, 50);
+ TestAllArmadilloSerialization(m);
}
-BOOST_AUTO_TEST_CASE(ColSerializeBinaryTest)
+// A quick test with an empty matrix.
+BOOST_AUTO_TEST_CASE(EmptyMatrixSerializeTest)
{
- arma::vec m;
- m.randu(50, 1);
- TestArmadilloSerialization<arma::vec, binary_iarchive, binary_oarchive>(m);
+ arma::mat m;
+ TestAllArmadilloSerialization(m);
}
/**
- * How about rows?
+ * Can we load and save a sparse Armadillo matrix?
*/
-BOOST_AUTO_TEST_CASE(RowSerializeXMLTest)
+BOOST_AUTO_TEST_CASE(SparseMatrixSerializeXMLTest)
{
- arma::rowvec m;
- m.randu(1, 50);
- TestArmadilloSerialization<arma::rowvec, xml_iarchive, xml_oarchive>(m);
+ arma::sp_mat m;
+ m.sprandu(50, 50, 0.3);
+ TestAllArmadilloSerialization(m);
}
-BOOST_AUTO_TEST_CASE(RowSerializeTextTest)
+/**
+ * How about columns?
+ */
+BOOST_AUTO_TEST_CASE(SparseColSerializeXMLTest)
{
- arma::rowvec m;
- m.randu(1, 50);
- TestArmadilloSerialization<arma::rowvec, text_iarchive, text_oarchive>(m);
+ arma::sp_vec m;
+ m.sprandu(50, 1, 0.3);
+ TestAllArmadilloSerialization(m);
}
-BOOST_AUTO_TEST_CASE(RowSerializeBinaryTest)
+/**
+ * How about rows?
+ */
+BOOST_AUTO_TEST_CASE(SparseRowSerializeXMLTest)
{
- arma::rowvec m;
- m.randu(1, 50);
- TestArmadilloSerialization<arma::rowvec, binary_iarchive, binary_oarchive>(m);
+ arma::sp_rowvec m;
+ m.sprandu(1, 50, 0.3);
+ TestAllArmadilloSerialization(m);
}
// A quick test with an empty matrix.
-BOOST_AUTO_TEST_CASE(EmptyMatrixSerializeTest)
+BOOST_AUTO_TEST_CASE(EmptySparseMatrixSerializeTest)
{
- arma::mat m;
- TestArmadilloSerialization<arma::mat, xml_iarchive, xml_oarchive>(m);
+ arma::sp_mat m;
+ TestAllArmadilloSerialization(m);
}
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list.