[mlpack-git] master: Add serialization test. (f86d10c)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Dec 21 20:14:22 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/e01baa6582c6ce3cfda243157159de9c30e2cbd8...f86d10c28df45b70e30d17854d85eac74d650509
>---------------------------------------------------------------
commit f86d10c28df45b70e30d17854d85eac74d650509
Author: ryan <ryan at ratml.org>
Date: Mon Dec 21 20:14:12 2015 -0500
Add serialization test.
>---------------------------------------------------------------
f86d10c28df45b70e30d17854d85eac74d650509
src/mlpack/tests/cf_test.cpp | 79 ++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 79 insertions(+)
diff --git a/src/mlpack/tests/cf_test.cpp b/src/mlpack/tests/cf_test.cpp
index 8aad547..834e326 100644
--- a/src/mlpack/tests/cf_test.cpp
+++ b/src/mlpack/tests/cf_test.cpp
@@ -11,6 +11,7 @@
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
+#include "serialization.hpp"
BOOST_AUTO_TEST_SUITE(CFTest);
@@ -460,4 +461,82 @@ BOOST_AUTO_TEST_CASE(EmptyConstructorTrainTest)
}
}
+/**
+ * Ensure we can save and load the CF model.
+ *
+ * A CF model built from the GroupLens100k ratings is round-tripped through
+ * SerializeObjectAll() into three other CF objects (one per archive format:
+ * XML, text and binary).  Those objects deliberately start in states that
+ * differ from the original, so the checks below only pass if loading really
+ * overwrites them.  Afterwards every loaded model must agree with the
+ * original on NumUsersForSimilarity(), Rank(), the factor matrices W and H,
+ * and the cleaned sparse ratings matrix (shape, sparsity pattern and values).
+ */
+BOOST_AUTO_TEST_CASE(SerializationTest)
+{
+  // Load a dataset to train on.
+  arma::mat dataset;
+  data::Load("GroupLens100k.csv", dataset);
+
+  arma::sp_mat cleanedData;
+  CF::CleanData(dataset, cleanedData);
+
+  CF c(cleanedData);
+
+  // The objects that will receive the deserialized model are set up
+  // differently from c: one is built from random sparse data, one is
+  // default-constructed, and one uses another factorizer and other numeric
+  // arguments (presumably numUsersForSimilarity and rank -- TODO confirm).
+  // Any state left over after loading would therefore be caught below.
+  arma::sp_mat randomData;
+  randomData.sprandu(100, 100, 0.3);
+
+  CF cXml(randomData);
+  CF cBinary;
+  CF cText(cleanedData, amf::NMFALSFactorizer(), 5, 5);
+
+  SerializeObjectAll(c, cXml, cText, cBinary);
+
+  // Check the internals.
+  BOOST_REQUIRE_EQUAL(c.NumUsersForSimilarity(), cXml.NumUsersForSimilarity());
+  BOOST_REQUIRE_EQUAL(c.NumUsersForSimilarity(),
+      cBinary.NumUsersForSimilarity());
+  BOOST_REQUIRE_EQUAL(c.NumUsersForSimilarity(), cText.NumUsersForSimilarity());
+
+  BOOST_REQUIRE_EQUAL(c.Rank(), cXml.Rank());
+  BOOST_REQUIRE_EQUAL(c.Rank(), cBinary.Rank());
+  BOOST_REQUIRE_EQUAL(c.Rank(), cText.Rank());
+
+  CheckMatrices(c.W(), cXml.W(), cBinary.W(), cText.W());
+  CheckMatrices(c.H(), cXml.H(), cBinary.H(), cText.H());
+
+  // Compare the cleaned sparse matrices through their compressed sparse
+  // column (CSC) representation.
+  const arma::sp_mat& cleaned = c.CleanedData();
+  const arma::sp_mat& cleanedXml = cXml.CleanedData();
+  const arma::sp_mat& cleanedBinary = cBinary.CleanedData();
+  const arma::sp_mat& cleanedText = cText.CleanedData();
+
+  BOOST_REQUIRE_EQUAL(cleaned.n_rows, cleanedXml.n_rows);
+  BOOST_REQUIRE_EQUAL(cleaned.n_rows, cleanedBinary.n_rows);
+  BOOST_REQUIRE_EQUAL(cleaned.n_rows, cleanedText.n_rows);
+
+  BOOST_REQUIRE_EQUAL(cleaned.n_cols, cleanedXml.n_cols);
+  BOOST_REQUIRE_EQUAL(cleaned.n_cols, cleanedBinary.n_cols);
+  BOOST_REQUIRE_EQUAL(cleaned.n_cols, cleanedText.n_cols);
+
+  BOOST_REQUIRE_EQUAL(cleaned.n_nonzero, cleanedXml.n_nonzero);
+  BOOST_REQUIRE_EQUAL(cleaned.n_nonzero, cleanedBinary.n_nonzero);
+  BOOST_REQUIRE_EQUAL(cleaned.n_nonzero, cleanedText.n_nonzero);
+
+  // In CSC form col_ptrs holds n_cols + 1 offsets (one start offset per
+  // column plus the end of the last column), so index n_cols is valid here.
+  for (size_t i = 0; i <= cleaned.n_cols; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(cleaned.col_ptrs[i], cleanedXml.col_ptrs[i]);
+    BOOST_REQUIRE_EQUAL(cleaned.col_ptrs[i], cleanedBinary.col_ptrs[i]);
+    BOOST_REQUIRE_EQUAL(cleaned.col_ptrs[i], cleanedText.col_ptrs[i]);
+  }
+
+  // row_indices and values hold exactly one entry per stored nonzero, so
+  // only indices [0, n_nonzero) are meaningful; anything at n_nonzero is
+  // past the last stored element.
+  for (size_t i = 0; i < cleaned.n_nonzero; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(cleaned.row_indices[i], cleanedXml.row_indices[i]);
+    BOOST_REQUIRE_EQUAL(cleaned.row_indices[i], cleanedBinary.row_indices[i]);
+    BOOST_REQUIRE_EQUAL(cleaned.row_indices[i], cleanedText.row_indices[i]);
+
+    // BOOST_REQUIRE_CLOSE takes its tolerance as a percentage.
+    BOOST_REQUIRE_CLOSE(cleaned.values[i], cleanedXml.values[i], 1e-5);
+    BOOST_REQUIRE_CLOSE(cleaned.values[i], cleanedBinary.values[i], 1e-5);
+    BOOST_REQUIRE_CLOSE(cleaned.values[i], cleanedText.values[i], 1e-5);
+  }
+}
+
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list