[mlpack-git] master: Refactor serialization tests; add SparseCoding implementation. (fc50782)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Dec 8 11:11:11 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/157595c68e3d26679e90152f07e1ee28e5e563c2...fc50782ab165567b0f04b11534b4ddc499262330
>---------------------------------------------------------------
commit fc50782ab165567b0f04b11534b4ddc499262330
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Dec 8 16:00:53 2015 +0000
Refactor serialization tests; add SparseCoding implementation.
>---------------------------------------------------------------
fc50782ab165567b0f04b11534b4ddc499262330
src/mlpack/methods/sparse_coding/sparse_coding.cpp | 283 +++++++++++++++++++++
src/mlpack/tests/CMakeLists.txt | 2 +
src/mlpack/tests/serialization.cpp | 76 ++++++
src/mlpack/tests/serialization.hpp | 199 +++++++++++++++
src/mlpack/tests/serialization_test.cpp | 226 +---------------
5 files changed, 561 insertions(+), 225 deletions(-)
diff --git a/src/mlpack/methods/sparse_coding/sparse_coding.cpp b/src/mlpack/methods/sparse_coding/sparse_coding.cpp
new file mode 100644
index 0000000..9349b7f
--- /dev/null
+++ b/src/mlpack/methods/sparse_coding/sparse_coding.cpp
@@ -0,0 +1,283 @@
+/**
+ * @file sparse_coding.cpp
+ * @author Nishant Mehta
+ *
+ * Implementation of Sparse Coding with Dictionary Learning using l1 (LASSO) or
+ * l1+l2 (Elastic Net) regularization.
+ */
+#include "sparse_coding.hpp"
+
+namespace mlpack {
+namespace sparse_coding {
+
+// Construct a SparseCoding object by storing the given parameters.  No
+// dictionary or codes are computed here; the constructor body is empty.
+SparseCoding::SparseCoding(
+    const size_t atoms,
+    const double lambda1,
+    const double lambda2,
+    const size_t maxIterations,
+    const double objTolerance,
+    const double newtonTolerance)
+    : atoms(atoms)
+    , lambda1(lambda1)
+    , lambda2(lambda2)
+    , maxIterations(maxIterations)
+    , objTolerance(objTolerance)
+    , newtonTolerance(newtonTolerance)
+{ }
+
+// Compute a sparse code for every point (column) of 'data' with respect to the
+// current dictionary, solving one LARS problem (with lambda1 and lambda2) per
+// point.  'codes' is resized to atoms x data.n_cols and filled column by
+// column.
+void SparseCoding::OptimizeCode(const arma::mat& data, arma::mat& codes)
+{
+  // The Gram matrix D^T D is shared by all points.  When using the Cholesky
+  // version of LARS, this is correct even if lambda2 > 0.
+  arma::mat matGram = trans(dictionary) * dictionary;
+
+  codes.set_size(atoms, data.n_cols);
+
+  const bool useCholesky = true;
+  for (size_t point = 0; point < data.n_cols; ++point)
+  {
+    // Report progress every 100 points.
+    if (point % 100 == 0)
+      Log::Debug << "Optimization at point " << point << "." << std::endl;
+
+    // A fresh solver is built for each point.
+    regression::LARS lars(useCholesky, matGram, lambda1, lambda2);
+
+    // 'code' shares memory with column 'point' of 'codes' (unsafe_col), so
+    // LARS writes its solution straight into 'codes' without an extra copy.
+    arma::vec code = codes.unsafe_col(point);
+    lars.Train(dictionary, data.unsafe_col(point), code, false);
+  }
+}
+
+// Dictionary step for optimization.
+//
+// With the codes Z held fixed, solve for a new dictionary via Newton's method
+// in the dual.  There is one dual variable per active atom, and each Newton
+// step is followed by an Armijo backtracking line search.  An atom is inactive
+// when its row of 'codes' is entirely zero.  Inactive atoms are left out of
+// the solve and are re-initialized as the normalized sum of three randomly
+// chosen columns of 'data'.
+//
+// @param data Data matrix; each column is a point.
+// @param codes Current codes; row j holds the coefficients of atom j.
+// @param adjacencies Not used by this implementation (kept for interface
+//     compatibility).
+// @param newtonTolerance Newton's method stops once the 2-norm of the dual
+//     gradient falls below this value.
+// @param maxIterations Maximum number of Newton iterations (0 means no limit).
+// @return The 2-norm of the dual gradient computed in the last Newton
+//     iteration, or 0 if no iteration ran.
+double SparseCoding::OptimizeDictionary(const arma::mat& data,
+                                        const arma::mat& codes,
+                                        const arma::uvec& adjacencies,
+                                        const double newtonTolerance,
+                                        const size_t maxIterations)
+{
+  // 'adjacencies' is currently unused; silence unused-parameter warnings.
+  (void) adjacencies;
+
+  // Find inactive atoms (atoms not used in the given coding).
+  std::vector<size_t> inactiveAtoms;
+  for (size_t j = 0; j < atoms; ++j)
+  {
+    if (arma::accu(codes.row(j) != 0) == 0)
+      inactiveAtoms.push_back(j);
+  }
+
+  const size_t nInactiveAtoms = inactiveAtoms.size();
+  const size_t nActiveAtoms = atoms - nInactiveAtoms;
+
+  // Efficient construction of Z restricted to active atoms.
+  arma::mat matActiveZ;
+  if (nInactiveAtoms > 0)
+  {
+    math::RemoveRows(codes, inactiveAtoms, matActiveZ);
+
+    Log::Warn << "There are " << nInactiveAtoms
+        << " inactive atoms. They will be re-initialized randomly.\n";
+  }
+
+  Log::Debug << "Solving Dual via Newton's Method.\n";
+
+  // Solve using Newton's method in the dual - note that the final dot
+  // multiplication with inv(A) seems to be unavoidable. Although more
+  // expensive, the code written this way (we use solve()) should be more
+  // numerically stable than just using inv(A) for everything.
+  arma::vec dualVars = arma::zeros<arma::vec>(nActiveAtoms);
+
+  // Alternative initializations that have been tried:
+  //vec dualVars = 1e-14 * ones<vec>(nActiveAtoms);
+
+  // Method used by feature sign code - fails miserably here. Perhaps the
+  // MATLAB optimizer fmincon does something clever?
+  //vec dualVars = 10.0 * randu(nActiveAtoms, 1);
+
+  //vec dualVars = diagvec(solve(dictionary, data * trans(codes))
+  //    - codes * trans(codes));
+  //for (size_t i = 0; i < dualVars.n_elem; i++)
+  //  if (dualVars(i) < 0)
+  //    dualVars(i) = 0;
+
+  // Precompute Z X^T and Z Z^T, restricted to the active atoms if any atom is
+  // inactive.
+  arma::mat codesXT;
+  arma::mat codesZT;
+  if (inactiveAtoms.empty())
+  {
+    codesXT = codes * trans(data);
+    codesZT = codes * trans(codes);
+  }
+  else
+  {
+    codesXT = matActiveZ * trans(data);
+    codesZT = matActiveZ * trans(matActiveZ);
+  }
+
+  bool converged = false;
+  double normGradient = 0;
+  double improvement = 0;
+  // t counts Newton iterations from 1, so 't <= maxIterations' runs exactly
+  // maxIterations iterations; maxIterations == 0 means no limit.
+  for (size_t t = 1; (maxIterations == 0 || t <= maxIterations) && !converged;
+       ++t)
+  {
+    arma::mat A = codesZT + diagmat(dualVars);
+
+    arma::mat matAInvZXT = solve(A, codesXT);
+
+    arma::vec gradient = -arma::sum(arma::square(matAInvZXT), 1);
+    gradient += 1;
+
+    arma::mat hessian = -(-2 * (matAInvZXT * trans(matAInvZXT)) % inv(A));
+
+    arma::vec searchDirection = -solve(hessian, gradient);
+
+    // Armijo line search.
+    const double c = 1e-4;
+    double alpha = 1.0;
+    const double rho = 0.9;
+    double sufficientDecrease = c * dot(gradient, searchDirection);
+
+    // A maxIterations parameter for the Armijo line search may be a good idea,
+    // but it doesn't seem to be causing any problems for now.
+    while (true)
+    {
+      // Calculate objective.
+      double sumDualVars = arma::sum(dualVars);
+      double fOld = -(-trace(trans(codesXT) * matAInvZXT) - sumDualVars);
+      double fNew = -(-trace(trans(codesXT) * solve(codesZT +
+          diagmat(dualVars + alpha * searchDirection), codesXT)) -
+          (sumDualVars + alpha * arma::sum(searchDirection)));
+
+      if (fNew <= fOld + alpha * sufficientDecrease)
+      {
+        searchDirection = alpha * searchDirection;
+        improvement = fOld - fNew;
+        break;
+      }
+
+      alpha *= rho;
+    }
+
+    // Take step and print useful information.
+    dualVars += searchDirection;
+    normGradient = arma::norm(gradient, 2);
+    Log::Debug << "Newton Method iteration " << t << ":" << std::endl;
+    Log::Debug << " Gradient norm: " << std::scientific << normGradient
+        << "." << std::endl;
+    Log::Debug << " Improvement: " << std::scientific << improvement << ".\n";
+
+    if (normGradient < newtonTolerance)
+      converged = true;
+  }
+
+  if (inactiveAtoms.empty())
+  {
+    // Directly update dictionary.
+    dictionary = trans(solve(codesZT + diagmat(dualVars), codesXT));
+  }
+  else
+  {
+    arma::mat activeDictionary = trans(solve(codesZT +
+        diagmat(dualVars), codesXT));
+
+    // Update all atoms.  inactiveAtoms is in increasing order, so walk it in
+    // step with i.
+    size_t currentInactiveIndex = 0;
+    for (size_t i = 0; i < atoms; ++i)
+    {
+      // Check the bound first: once every inactive atom has been handled, all
+      // remaining atoms are active, and indexing past the end of the vector
+      // would be undefined behavior.
+      if (currentInactiveIndex < nInactiveAtoms &&
+          inactiveAtoms[currentInactiveIndex] == i)
+      {
+        // This atom is inactive. Reinitialize it randomly.
+        dictionary.col(i) = (data.col(math::RandInt(data.n_cols)) +
+                             data.col(math::RandInt(data.n_cols)) +
+                             data.col(math::RandInt(data.n_cols)));
+
+        dictionary.col(i) /= arma::norm(dictionary.col(i), 2);
+
+        // Increment inactive index counter.
+        ++currentInactiveIndex;
+      }
+      else
+      {
+        // Active atom i maps to column (i - #inactive atoms before it).
+        dictionary.col(i) = activeDictionary.col(i - currentInactiveIndex);
+      }
+    }
+  }
+
+  return normGradient;
+}
+
+// Project each atom of the dictionary back into the unit ball (if necessary).
+// Any atom (column) whose 2-norm exceeds 1 is rescaled to unit norm and logged.
+// Atoms already inside the unit ball are left untouched.
+void SparseCoding::ProjectDictionary()
+{
+  for (size_t atom = 0; atom < atoms; ++atom)
+  {
+    const double atomNorm = arma::norm(dictionary.col(atom), 2);
+    const bool outsideUnitBall = (atomNorm > 1);
+    if (!outsideUnitBall)
+      continue;
+
+    Log::Info << "Norm of atom " << atom << " exceeds 1 (" << std::scientific
+        << atomNorm << "). Shrinking...\n";
+    dictionary.col(atom) /= atomNorm;
+  }
+}
+
+// Compute the objective function:
+//
+//   0.5 * ||data - dictionary * codes||_F^2 + lambda1 * sum(|codes|)
+//
+// plus 0.5 * lambda2 * ||codes||_F^2 when lambda2 > 0.
+double SparseCoding::Objective(const arma::mat& data, const arma::mat& codes)
+    const
+{
+  // l1,1 norm of Z: the sum of the absolute values of all entries.
+  const double l11NormZ = arma::accu(arma::abs(codes));
+
+  // Squared Frobenius norms are accumulated directly.  This avoids taking a
+  // norm() (a sqrt) and then squaring it again with pow().
+  const arma::mat residual = data - dictionary * codes;
+  const double froNormResidualSq = arma::accu(arma::square(residual));
+
+  if (lambda2 > 0)
+  {
+    const double froNormZSq = arma::accu(arma::square(codes));
+    return 0.5 * (froNormResidualSq + lambda2 * froNormZSq) +
+        lambda1 * l11NormZ;
+  }
+
+  // Without the l2 term the objective is simpler.
+  return 0.5 * froNormResidualSq + lambda1 * l11NormZ;
+}
+
+// Human-readable summary: the object's address followed by the number of
+// atoms and both regularization parameters, one per line.
+std::string SparseCoding::ToString() const
+{
+  std::ostringstream summary;
+  summary << "Sparse Coding [" << this << "]" << '\n'
+          << " Atoms: " << atoms << '\n'
+          << " Lambda 1: " << lambda1 << '\n'
+          << " Lambda 2: " << lambda2 << '\n';
+  return summary.str();
+}
+
+} // namespace sparse_coding
+} // namespace mlpack
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index b4f3a37..06d2515 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -57,6 +57,8 @@ add_executable(mlpack_test
sa_test.cpp
sdp_primal_dual_test.cpp
sgd_test.cpp
+ serialization.hpp
+ serialization.cpp
serialization_test.cpp
softmax_regression_test.cpp
sort_policy_test.cpp
diff --git a/src/mlpack/tests/serialization.cpp b/src/mlpack/tests/serialization.cpp
new file mode 100644
index 0000000..b32b730
--- /dev/null
+++ b/src/mlpack/tests/serialization.cpp
@@ -0,0 +1,76 @@
+/**
+ * @file serialization.cpp
+ * @author Ryan Curtin
+ *
+ * Miscellaneous utility functions for serialization tests.
+ */
+#include "serialization.hpp"
+
+namespace mlpack {
+
+namespace {
+
+// Require that 'other' has exactly the same shape (rows, columns, and number
+// of elements) as 'x'.
+template<typename MatType>
+void RequireSameShape(const MatType& x, const MatType& other)
+{
+  BOOST_REQUIRE_EQUAL(x.n_rows, other.n_rows);
+  BOOST_REQUIRE_EQUAL(x.n_cols, other.n_cols);
+  BOOST_REQUIRE_EQUAL(x.n_elem, other.n_elem);
+}
+
+// Require that 'actual' matches 'expected'.  If 'expected' is exactly zero,
+// |actual| must be below 1e-8.  Otherwise the two must agree to within a
+// relative tolerance of 1e-8 percent.
+void RequireCloseOrSmall(const double expected, const double actual)
+{
+  if (expected == 0.0)
+  {
+    BOOST_REQUIRE_SMALL(actual, 1e-8);
+  }
+  else
+  {
+    BOOST_REQUIRE_CLOSE(expected, actual, 1e-8);
+  }
+}
+
+} // anonymous namespace
+
+// Utility function to check the equality of two Armadillo matrices.  Each of
+// xmlX, textX, and binaryX must have the same shape as x.  Each element must
+// equal the corresponding element of x to within the tolerance used by
+// RequireCloseOrSmall().
+void CheckMatrices(const arma::mat& x,
+                   const arma::mat& xmlX,
+                   const arma::mat& textX,
+                   const arma::mat& binaryX)
+{
+  // First check dimensions.
+  RequireSameShape(x, xmlX);
+  RequireSameShape(x, textX);
+  RequireSameShape(x, binaryX);
+
+  // Now check elements.
+  for (size_t i = 0; i < x.n_elem; ++i)
+  {
+    RequireCloseOrSmall(x[i], xmlX[i]);
+    RequireCloseOrSmall(x[i], textX[i]);
+    RequireCloseOrSmall(x[i], binaryX[i]);
+  }
+}
+
+// Overload for integer (size_t) matrices: shapes must match and every element
+// must be exactly equal.
+void CheckMatrices(const arma::Mat<size_t>& x,
+                   const arma::Mat<size_t>& xmlX,
+                   const arma::Mat<size_t>& textX,
+                   const arma::Mat<size_t>& binaryX)
+{
+  // First check dimensions.
+  RequireSameShape(x, xmlX);
+  RequireSameShape(x, textX);
+  RequireSameShape(x, binaryX);
+
+  // Now check elements.
+  for (size_t i = 0; i < x.n_elem; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(x[i], xmlX[i]);
+    BOOST_REQUIRE_EQUAL(x[i], textX[i]);
+    BOOST_REQUIRE_EQUAL(x[i], binaryX[i]);
+  }
+}
+
+} // namespace mlpack
diff --git a/src/mlpack/tests/serialization.hpp b/src/mlpack/tests/serialization.hpp
new file mode 100644
index 0000000..530217b
--- /dev/null
+++ b/src/mlpack/tests/serialization.hpp
@@ -0,0 +1,199 @@
+/**
+ * @file serialization.hpp
+ * @author Ryan Curtin
+ *
+ * Miscellaneous utility functions for serialization tests.
+ */
+#ifndef __MLPACK_TESTS_SERIALIZATION_HPP
+#define __MLPACK_TESTS_SERIALIZATION_HPP
+
+#include <boost/serialization/serialization.hpp>
+#include <boost/archive/xml_iarchive.hpp>
+#include <boost/archive/xml_oarchive.hpp>
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/archive/text_oarchive.hpp>
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/archive/binary_oarchive.hpp>
+#include <mlpack/core.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+namespace mlpack {
+
+// Test function for loading and saving Armadillo objects.
+//
+// Writes x to a temporary file named "test" with OArchiveType, then reloads
+// it into x with IArchiveType.  It requires that the reloaded object has the
+// same dimensions as the original and that every element matches: |value|
+// below 1e-8 where the original is exactly zero, otherwise within 1e-8
+// percent.  The temporary file is removed once it has been read back.
+template<typename MatType,
+         typename IArchiveType,
+         typename OArchiveType>
+void TestArmadilloSerialization(MatType& x)
+{
+  // First save it.  The archive lives in an inner scope so that its destructor
+  // (which may write trailing data, e.g. the closing XML tag) runs before the
+  // stream it writes to is closed.
+  bool success = true;
+  {
+    std::ofstream ofs("test", std::ios::binary);
+    OArchiveType o(ofs);
+
+    try
+    {
+      o << BOOST_SERIALIZATION_NVP(x);
+    }
+    catch (const boost::archive::archive_exception&)
+    {
+      success = false;
+    }
+  }
+
+  BOOST_REQUIRE_EQUAL(success, true);
+
+  // Now load it.
+  MatType orig(x);
+  success = true;
+  {
+    std::ifstream ifs("test", std::ios::binary);
+    IArchiveType i(ifs);
+
+    try
+    {
+      i >> BOOST_SERIALIZATION_NVP(x);
+    }
+    catch (const boost::archive::archive_exception&)
+    {
+      success = false;
+    }
+  }
+
+  // The input stream is closed now, so the file can be removed (this also
+  // works on platforms that refuse to delete open files).
+  remove("test");
+
+  BOOST_REQUIRE_EQUAL(success, true);
+
+  BOOST_REQUIRE_EQUAL(x.n_rows, orig.n_rows);
+  BOOST_REQUIRE_EQUAL(x.n_cols, orig.n_cols);
+  BOOST_REQUIRE_EQUAL(x.n_elem, orig.n_elem);
+
+  for (size_t col = 0; col < x.n_cols; ++col)
+  {
+    for (size_t row = 0; row < x.n_rows; ++row)
+    {
+      if (double(orig(row, col)) == 0.0)
+      {
+        BOOST_REQUIRE_SMALL(double(x(row, col)), 1e-8);
+      }
+      else
+      {
+        BOOST_REQUIRE_CLOSE(double(orig(row, col)), double(x(row, col)), 1e-8);
+      }
+    }
+  }
+}
+
+// Test all serialization strategies: round-trip x through XML, text, and
+// binary Boost archives, in that order.
+template<typename MatType>
+void TestAllArmadilloSerialization(MatType& x)
+{
+  namespace ba = boost::archive;
+
+  TestArmadilloSerialization<MatType, ba::xml_iarchive, ba::xml_oarchive>(x);
+  TestArmadilloSerialization<MatType, ba::text_iarchive, ba::text_oarchive>(x);
+  TestArmadilloSerialization<MatType, ba::binary_iarchive,
+                             ba::binary_oarchive>(x);
+}
+
+// Save and load an mlpack object.
+// 't' is written to a temporary file named "test" with OArchiveType and read
+// back with IArchiveType; the re-loaded copy is placed in 'newT'.  The
+// temporary file is removed afterwards.
+template<typename T, typename IArchiveType, typename OArchiveType>
+void SerializeObject(T& t, T& newT)
+{
+  bool success = true;
+  {
+    // Scope the archive so it is destroyed (and finishes writing any trailing
+    // data) before the stream is closed.
+    std::ofstream ofs("test", std::ios::binary);
+    OArchiveType o(ofs);
+
+    try
+    {
+      o << data::CreateNVP(t, "t");
+    }
+    catch (const boost::archive::archive_exception&)
+    {
+      success = false;
+    }
+  }
+
+  BOOST_REQUIRE_EQUAL(success, true);
+
+  {
+    std::ifstream ifs("test", std::ios::binary);
+    IArchiveType i(ifs);
+
+    try
+    {
+      i >> data::CreateNVP(newT, "t");
+    }
+    catch (const boost::archive::archive_exception&)
+    {
+      success = false;
+    }
+  }
+
+  // Both streams are closed at this point, so the file can be removed.
+  remove("test");
+
+  BOOST_REQUIRE_EQUAL(success, true);
+}
+
+// Test mlpack serialization with all three archive types.  t is round-tripped
+// through text, binary, and XML archives (in that order), and the reloaded
+// copies are placed in textT, binaryT, and xmlT respectively.
+template<typename T>
+void SerializeObjectAll(T& t, T& xmlT, T& textT, T& binaryT)
+{
+  namespace ba = boost::archive;
+
+  SerializeObject<T, ba::text_iarchive, ba::text_oarchive>(t, textT);
+  SerializeObject<T, ba::binary_iarchive, ba::binary_oarchive>(t, binaryT);
+  SerializeObject<T, ba::xml_iarchive, ba::xml_oarchive>(t, xmlT);
+}
+
+// Save and load a non-default-constructible mlpack object.
+// '*t' is written to a temporary file named "test".  A new object is then
+// heap-allocated by constructing T directly from the input archive
+// (new T(archive)), and the caller takes ownership of it through 'newT'.  The
+// temporary file is removed afterwards.
+template<typename T, typename IArchiveType, typename OArchiveType>
+void SerializePointerObject(T* t, T*& newT)
+{
+  bool success = true;
+  {
+    // Scope the archive so it is destroyed (and finishes writing any trailing
+    // data) before the stream is closed.
+    std::ofstream ofs("test", std::ios::binary);
+    OArchiveType o(ofs);
+
+    try
+    {
+      o << data::CreateNVP(*t, "t");
+    }
+    catch (const boost::archive::archive_exception&)
+    {
+      success = false;
+    }
+  }
+
+  BOOST_REQUIRE_EQUAL(success, true);
+
+  {
+    std::ifstream ifs("test", std::ios::binary);
+    IArchiveType i(ifs);
+
+    try
+    {
+      newT = new T(i);
+    }
+    catch (const std::exception&)
+    {
+      success = false;
+    }
+  }
+
+  // Both streams are closed at this point, so the file can be removed.
+  remove("test");
+
+  BOOST_REQUIRE_EQUAL(success, true);
+}
+
+// Test pointer-style (archive-constructed) serialization with all three
+// archive types.  *t is round-tripped through text, binary, and XML archives
+// (in that order), and newly allocated copies are stored in textT, binaryT,
+// and xmlT respectively.
+template<typename T>
+void SerializePointerObjectAll(T* t, T*& xmlT, T*& textT, T*& binaryT)
+{
+  namespace ba = boost::archive;
+
+  SerializePointerObject<T, ba::text_iarchive, ba::text_oarchive>(t, textT);
+  SerializePointerObject<T, ba::binary_iarchive,
+                         ba::binary_oarchive>(t, binaryT);
+  SerializePointerObject<T, ba::xml_iarchive, ba::xml_oarchive>(t, xmlT);
+}
+
+/**
+ * Utility function to check the equality of two Armadillo matrices.
+ *
+ * Requires (via BOOST_REQUIRE_* macros) that xmlX, textX, and binaryX each
+ * have the same number of rows, columns, and elements as x.  Every element
+ * must also match the corresponding element of x: its magnitude must be below
+ * 1e-8 where the element of x is exactly zero, and otherwise the two must
+ * agree to within a relative tolerance of 1e-8 percent.
+ */
+void CheckMatrices(const arma::mat& x,
+                   const arma::mat& xmlX,
+                   const arma::mat& textX,
+                   const arma::mat& binaryX);
+
+/**
+ * Overload for matrices of size_t: requires identical dimensions and exactly
+ * equal elements.
+ */
+void CheckMatrices(const arma::Mat<size_t>& x,
+                   const arma::Mat<size_t>& xmlX,
+                   const arma::Mat<size_t>& textX,
+                   const arma::Mat<size_t>& binaryX);
+
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 4a9a182..a278249 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -4,17 +4,11 @@
*
* Test serialization of mlpack objects.
*/
-#include <boost/serialization/serialization.hpp>
-#include <boost/archive/xml_iarchive.hpp>
-#include <boost/archive/xml_oarchive.hpp>
-#include <boost/archive/text_iarchive.hpp>
-#include <boost/archive/text_oarchive.hpp>
-#include <boost/archive/binary_iarchive.hpp>
-#include <boost/archive/binary_oarchive.hpp>
#include <mlpack/core.hpp>
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
+#include "serialization.hpp"
#include <mlpack/core/dists/regression_distribution.hpp>
#include <mlpack/core/tree/ballbound.hpp>
@@ -55,69 +49,6 @@ using namespace std;
BOOST_AUTO_TEST_SUITE(SerializationTest);
-// Test function for loading and saving Armadillo objects.
-template<typename MatType,
- typename IArchiveType,
- typename OArchiveType>
-void TestArmadilloSerialization(MatType& x)
-{
- // First save it.
- ofstream ofs("test", ios::binary);
- OArchiveType o(ofs);
-
- bool success = true;
- try
- {
- o << BOOST_SERIALIZATION_NVP(x);
- }
- catch (archive_exception& e)
- {
- success = false;
- }
-
- BOOST_REQUIRE_EQUAL(success, true);
- ofs.close();
-
- // Now load it.
- MatType orig(x);
- success = true;
- ifstream ifs("test", ios::binary);
- IArchiveType i(ifs);
-
- try
- {
- i >> BOOST_SERIALIZATION_NVP(x);
- }
- catch (archive_exception& e)
- {
- success = false;
- }
-
- BOOST_REQUIRE_EQUAL(success, true);
-
- BOOST_REQUIRE_EQUAL(x.n_rows, orig.n_rows);
- BOOST_REQUIRE_EQUAL(x.n_cols, orig.n_cols);
- BOOST_REQUIRE_EQUAL(x.n_elem, orig.n_elem);
-
- for (size_t i = 0; i < x.n_cols; ++i)
- for (size_t j = 0; j < x.n_rows; ++j)
- if (double(orig(j, i)) == 0.0)
- BOOST_REQUIRE_SMALL(double(x(j, i)), 1e-8);
- else
- BOOST_REQUIRE_CLOSE(double(orig(j, i)), double(x(j, i)), 1e-8);
-
- remove("test");
-}
-
-// Test all serialization strategies.
-template<typename MatType>
-void TestAllArmadilloSerialization(MatType& x)
-{
- TestArmadilloSerialization<MatType, xml_iarchive, xml_oarchive>(x);
- TestArmadilloSerialization<MatType, text_iarchive, text_oarchive>(x);
- TestArmadilloSerialization<MatType, binary_iarchive, binary_oarchive>(x);
-}
-
/**
* Can we load and save an Armadillo matrix?
*/
@@ -192,161 +123,6 @@ BOOST_AUTO_TEST_CASE(EmptySparseMatrixSerializeTest)
TestAllArmadilloSerialization(m);
}
-// Save and load an mlpack object.
-// The re-loaded copy is placed in 'newT'.
-template<typename T, typename IArchiveType, typename OArchiveType>
-void SerializeObject(T& t, T& newT)
-{
- ofstream ofs("test", ios::binary);
- OArchiveType o(ofs);
-
- bool success = true;
- try
- {
- o << data::CreateNVP(t, "t");
- }
- catch (archive_exception& e)
- {
- success = false;
- }
- ofs.close();
-
- BOOST_REQUIRE_EQUAL(success, true);
-
- ifstream ifs("test", ios::binary);
- IArchiveType i(ifs);
-
- try
- {
- i >> data::CreateNVP(newT, "t");
- }
- catch (archive_exception& e)
- {
- success = false;
- }
- ifs.close();
-
- BOOST_REQUIRE_EQUAL(success, true);
-}
-
-// Test mlpack serialization with all three archive types.
-template<typename T>
-void SerializeObjectAll(T& t, T& xmlT, T& textT, T& binaryT)
-{
- SerializeObject<T, text_iarchive, text_oarchive>(t, textT);
- SerializeObject<T, binary_iarchive, binary_oarchive>(t, binaryT);
- SerializeObject<T, xml_iarchive, xml_oarchive>(t, xmlT);
-}
-
-// Save and load a non-default-constructible mlpack object.
-template<typename T, typename IArchiveType, typename OArchiveType>
-void SerializePointerObject(T* t, T*& newT)
-{
- ofstream ofs("test", ios::binary);
- OArchiveType o(ofs);
-
- bool success = true;
- try
- {
- o << data::CreateNVP(*t, "t");
- }
- catch (archive_exception& e)
- {
- success = false;
- }
- ofs.close();
-
- BOOST_REQUIRE_EQUAL(success, true);
-
- ifstream ifs("test", ios::binary);
- IArchiveType i(ifs);
-
- try
- {
- newT = new T(i);
- }
- catch (std::exception& e)
- {
- success = false;
- }
- ifs.close();
-
- BOOST_REQUIRE_EQUAL(success, true);
-}
-
-template<typename T>
-void SerializePointerObjectAll(T* t, T*& xmlT, T*& textT, T*& binaryT)
-{
- SerializePointerObject<T, text_iarchive, text_oarchive>(t, textT);
- SerializePointerObject<T, binary_iarchive, binary_oarchive>(t, binaryT);
- SerializePointerObject<T, xml_iarchive, xml_oarchive>(t, xmlT);
-}
-
-// Utility function to check the equality of two Armadillo matrices.
-void CheckMatrices(const mat& x,
- const mat& xmlX,
- const mat& textX,
- const mat& binaryX)
-{
- // First check dimensions.
- BOOST_REQUIRE_EQUAL(x.n_rows, xmlX.n_rows);
- BOOST_REQUIRE_EQUAL(x.n_rows, textX.n_rows);
- BOOST_REQUIRE_EQUAL(x.n_rows, binaryX.n_rows);
-
- BOOST_REQUIRE_EQUAL(x.n_cols, xmlX.n_cols);
- BOOST_REQUIRE_EQUAL(x.n_cols, textX.n_cols);
- BOOST_REQUIRE_EQUAL(x.n_cols, binaryX.n_cols);
-
- BOOST_REQUIRE_EQUAL(x.n_elem, xmlX.n_elem);
- BOOST_REQUIRE_EQUAL(x.n_elem, textX.n_elem);
- BOOST_REQUIRE_EQUAL(x.n_elem, binaryX.n_elem);
-
- // Now check elements.
- for (size_t i = 0; i < x.n_elem; ++i)
- {
- const double val = x[i];
- if (val == 0.0)
- {
- BOOST_REQUIRE_SMALL(xmlX[i], 1e-8);
- BOOST_REQUIRE_SMALL(textX[i], 1e-8);
- BOOST_REQUIRE_SMALL(binaryX[i], 1e-8);
- }
- else
- {
- BOOST_REQUIRE_CLOSE(val, xmlX[i], 1e-8);
- BOOST_REQUIRE_CLOSE(val, textX[i], 1e-8);
- BOOST_REQUIRE_CLOSE(val, binaryX[i], 1e-8);
- }
- }
-}
-
-void CheckMatrices(const Mat<size_t>& x,
- const Mat<size_t>& xmlX,
- const Mat<size_t>& textX,
- const Mat<size_t>& binaryX)
-{
- // First check dimensions.
- BOOST_REQUIRE_EQUAL(x.n_rows, xmlX.n_rows);
- BOOST_REQUIRE_EQUAL(x.n_rows, textX.n_rows);
- BOOST_REQUIRE_EQUAL(x.n_rows, binaryX.n_rows);
-
- BOOST_REQUIRE_EQUAL(x.n_cols, xmlX.n_cols);
- BOOST_REQUIRE_EQUAL(x.n_cols, textX.n_cols);
- BOOST_REQUIRE_EQUAL(x.n_cols, binaryX.n_cols);
-
- BOOST_REQUIRE_EQUAL(x.n_elem, xmlX.n_elem);
- BOOST_REQUIRE_EQUAL(x.n_elem, textX.n_elem);
- BOOST_REQUIRE_EQUAL(x.n_elem, binaryX.n_elem);
-
- // Now check elements.
- for (size_t i = 0; i < x.n_elem; ++i)
- {
- BOOST_REQUIRE_EQUAL(x[i], xmlX[i]);
- BOOST_REQUIRE_EQUAL(x[i], textX[i]);
- BOOST_REQUIRE_EQUAL(x[i], binaryX[i]);
- }
-}
-
// Now, test mlpack objects.
BOOST_AUTO_TEST_CASE(DiscreteDistributionTest)
{
More information about the mlpack-git mailing list