[mlpack-git] master: Add Serialize() to SparseCoding. (4fb456e)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Dec 11 12:47:02 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/dd7c8b93fe5f299cb534cda70c1c786456f9a78f...3b926fd86ab143eb8af7327b9fb89fead7538df0
>---------------------------------------------------------------
commit 4fb456e21a66c1b60db7ee202eb6118b0e80e050
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Dec 11 04:49:46 2015 +0000
Add Serialize() to SparseCoding.
>---------------------------------------------------------------
4fb456e21a66c1b60db7ee202eb6118b0e80e050
src/mlpack/methods/sparse_coding/sparse_coding.hpp | 43 ++++++++++++++++++
src/mlpack/tests/sparse_coding_test.cpp | 52 ++++++++++++++++++++++
2 files changed, 95 insertions(+)
diff --git a/src/mlpack/methods/sparse_coding/sparse_coding.hpp b/src/mlpack/methods/sparse_coding/sparse_coding.hpp
index 07c5d07..f586ac2 100644
--- a/src/mlpack/methods/sparse_coding/sparse_coding.hpp
+++ b/src/mlpack/methods/sparse_coding/sparse_coding.hpp
@@ -214,6 +214,49 @@ class SparseCoding
//! Modify the dictionary (returns a mutable reference to the stored matrix).
arma::mat& Dictionary() { return dictionary; }
+ //! Get the number of atoms.
+ size_t Atoms() const { return atoms; }
+ //! Modify the number of atoms (only changes the stored count; the dictionary itself is not touched).
+ size_t& Atoms() { return atoms; }
+
+ //! Get the L1 regularization parameter (lambda1).
+ double Lambda1() const { return lambda1; }
+ //! Modify the L1 regularization parameter (lambda1).
+ double& Lambda1() { return lambda1; }
+
+ //! Get the L2 regularization parameter (lambda2).
+ double Lambda2() const { return lambda2; }
+ //! Modify the L2 regularization parameter (lambda2).
+ double& Lambda2() { return lambda2; }
+
+ //! Get the maximum number of iterations (maxIterations).
+ size_t MaxIterations() const { return maxIterations; }
+ //! Modify the maximum number of iterations (maxIterations).
+ size_t& MaxIterations() { return maxIterations; }
+
+ //! Get the tolerance on the objective value (objTolerance).
+ double ObjTolerance() const { return objTolerance; }
+ //! Modify the tolerance on the objective value (objTolerance).
+ double& ObjTolerance() { return objTolerance; }
+
+ //! Get the tolerance used by Newton's method in the dictionary optimization step.
+ double NewtonTolerance() const { return newtonTolerance; }
+ //! Modify the tolerance used by Newton's method in the dictionary optimization step.
+ double& NewtonTolerance() { return newtonTolerance; }
+
+ //! Serialize the sparse coding model: the dictionary plus atoms, lambda1, lambda2, maxIterations, objTolerance and newtonTolerance.
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version: unused */)
+ {
+ ar & data::CreateNVP(atoms, "atoms"); // presumably the same routine saves and loads (boost-style operator&), so keep this field order stable
+ ar & data::CreateNVP(dictionary, "dictionary");
+ ar & data::CreateNVP(lambda1, "lambda1");
+ ar & data::CreateNVP(lambda2, "lambda2");
+ ar & data::CreateNVP(maxIterations, "maxIterations");
+ ar & data::CreateNVP(objTolerance, "objTolerance");
+ ar & data::CreateNVP(newtonTolerance, "newtonTolerance");
+ }
+
//! Return a string representation of this object.
std::string ToString() const;
diff --git a/src/mlpack/tests/sparse_coding_test.cpp b/src/mlpack/tests/sparse_coding_test.cpp
index 0af0e30..716e434 100644
--- a/src/mlpack/tests/sparse_coding_test.cpp
+++ b/src/mlpack/tests/sparse_coding_test.cpp
@@ -12,6 +12,7 @@
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
+#include "serialization.hpp"
using namespace arma;
using namespace mlpack;
@@ -131,5 +132,56 @@ BOOST_AUTO_TEST_CASE(SparseCodingTestDictionaryStep)
BOOST_REQUIRE_SMALL(normGradient, tol);
}
+BOOST_AUTO_TEST_CASE(SerializationTest)
+{ // A trained model must survive XML, text and binary serialization unchanged.
+ mat X = randu<mat>(100, 100);
+ size_t nAtoms = 25;
+
+ SparseCoding sc(nAtoms, 0.05, 0.1);
+ sc.Train(X);
+
+ mat Y = randu<mat>(100, 200); // New points, encoded by every model below.
+ mat codes;
+ sc.OptimizeCode(Y, codes); // Reference codes from the original model.
+
+ SparseCoding scXml(50, 0.01), scText(nAtoms, 0.05), scBinary(0, 0.0); // Start from other settings; deserialization should overwrite them.
+ SerializeObjectAll(sc, scXml, scText, scBinary); // presumably saves sc and loads it into each target (see serialization.hpp)
+
+ CheckMatrices(sc.Dictionary(), scXml.Dictionary(), scText.Dictionary(),
+ scBinary.Dictionary());
+
+ mat xmlCodes, textCodes, binaryCodes;
+ scXml.OptimizeCode(Y, xmlCodes);
+ scText.OptimizeCode(Y, textCodes);
+ scBinary.OptimizeCode(Y, binaryCodes);
+
+ CheckMatrices(codes, xmlCodes, textCodes, binaryCodes); // Restored models must encode Y exactly like the original.
+
+ // Every hyperparameter must also survive the round trip.
+ BOOST_REQUIRE_EQUAL(sc.Atoms(), scXml.Atoms());
+ BOOST_REQUIRE_EQUAL(sc.Atoms(), scText.Atoms());
+ BOOST_REQUIRE_EQUAL(sc.Atoms(), scBinary.Atoms());
+
+ BOOST_REQUIRE_CLOSE(sc.Lambda1(), scXml.Lambda1(), 1e-5);
+ BOOST_REQUIRE_CLOSE(sc.Lambda1(), scText.Lambda1(), 1e-5);
+ BOOST_REQUIRE_CLOSE(sc.Lambda1(), scBinary.Lambda1(), 1e-5);
+
+ BOOST_REQUIRE_CLOSE(sc.Lambda2(), scXml.Lambda2(), 1e-5);
+ BOOST_REQUIRE_CLOSE(sc.Lambda2(), scText.Lambda2(), 1e-5);
+ BOOST_REQUIRE_CLOSE(sc.Lambda2(), scBinary.Lambda2(), 1e-5);
+
+ BOOST_REQUIRE_EQUAL(sc.MaxIterations(), scXml.MaxIterations());
+ BOOST_REQUIRE_EQUAL(sc.MaxIterations(), scText.MaxIterations());
+ BOOST_REQUIRE_EQUAL(sc.MaxIterations(), scBinary.MaxIterations());
+
+ BOOST_REQUIRE_CLOSE(sc.ObjTolerance(), scXml.ObjTolerance(), 1e-5);
+ BOOST_REQUIRE_CLOSE(sc.ObjTolerance(), scText.ObjTolerance(), 1e-5);
+ BOOST_REQUIRE_CLOSE(sc.ObjTolerance(), scBinary.ObjTolerance(), 1e-5);
+
+ BOOST_REQUIRE_CLOSE(sc.NewtonTolerance(), scXml.NewtonTolerance(), 1e-5);
+ BOOST_REQUIRE_CLOSE(sc.NewtonTolerance(), scText.NewtonTolerance(), 1e-5);
+ BOOST_REQUIRE_CLOSE(sc.NewtonTolerance(), scBinary.NewtonTolerance(), 1e-5);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list