[mlpack-git] master: Add Serialize() to SparseCoding. (4fb456e)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Dec 11 12:47:02 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/dd7c8b93fe5f299cb534cda70c1c786456f9a78f...3b926fd86ab143eb8af7327b9fb89fead7538df0

>---------------------------------------------------------------

commit 4fb456e21a66c1b60db7ee202eb6118b0e80e050
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Dec 11 04:49:46 2015 +0000

    Add Serialize() to SparseCoding.


>---------------------------------------------------------------

4fb456e21a66c1b60db7ee202eb6118b0e80e050
 src/mlpack/methods/sparse_coding/sparse_coding.hpp | 43 ++++++++++++++++++
 src/mlpack/tests/sparse_coding_test.cpp            | 52 ++++++++++++++++++++++
 2 files changed, 95 insertions(+)

diff --git a/src/mlpack/methods/sparse_coding/sparse_coding.hpp b/src/mlpack/methods/sparse_coding/sparse_coding.hpp
index 07c5d07..f586ac2 100644
--- a/src/mlpack/methods/sparse_coding/sparse_coding.hpp
+++ b/src/mlpack/methods/sparse_coding/sparse_coding.hpp
@@ -214,6 +214,49 @@ class SparseCoding
   //! Modify the dictionary.
   arma::mat& Dictionary() { return dictionary; }
 
+  //! Get the number of atoms.
+  size_t Atoms() const { return atoms; }
+  //! Modify the number of atoms; this does not resize the current dictionary.
+  size_t& Atoms() { return atoms; }
+
+  //! Get the L1 regularization term.
+  double Lambda1() const { return lambda1; }
+  //! Modify the L1 regularization term.
+  double& Lambda1() { return lambda1; }
+
+  //! Get the L2 regularization term.
+  double Lambda2() const { return lambda2; }
+  //! Modify the L2 regularization term.
+  double& Lambda2() { return lambda2; }
+
+  //! Get the maximum number of iterations.
+  size_t MaxIterations() const { return maxIterations; }
+  //! Modify the maximum number of iterations.
+  size_t& MaxIterations() { return maxIterations; }
+
+  //! Get the objective tolerance.
+  double ObjTolerance() const { return objTolerance; }
+  //! Modify the objective tolerance.
+  double& ObjTolerance() { return objTolerance; }
+
+  //! Get the tolerance for Newton's method (dictionary optimization step).
+  double NewtonTolerance() const { return newtonTolerance; }
+  //! Modify the tolerance for Newton's method (dictionary optimization step).
+  double& NewtonTolerance() { return newtonTolerance; }
+
+  //! Serialize the dictionary, atoms, lambdas, iteration limit and tolerances.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    ar & data::CreateNVP(atoms, "atoms");
+    ar & data::CreateNVP(dictionary, "dictionary");
+    ar & data::CreateNVP(lambda1, "lambda1");
+    ar & data::CreateNVP(lambda2, "lambda2");
+    ar & data::CreateNVP(maxIterations, "maxIterations");
+    ar & data::CreateNVP(objTolerance, "objTolerance");
+    ar & data::CreateNVP(newtonTolerance, "newtonTolerance");
+  }
+
   // Returns a string representation of this object.
   std::string ToString() const;
 
diff --git a/src/mlpack/tests/sparse_coding_test.cpp b/src/mlpack/tests/sparse_coding_test.cpp
index 0af0e30..716e434 100644
--- a/src/mlpack/tests/sparse_coding_test.cpp
+++ b/src/mlpack/tests/sparse_coding_test.cpp
@@ -12,6 +12,7 @@
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
+#include "serialization.hpp"
 
 using namespace arma;
 using namespace mlpack;
@@ -131,5 +132,56 @@ BOOST_AUTO_TEST_CASE(SparseCodingTestDictionaryStep)
   BOOST_REQUIRE_SMALL(normGradient, tol);
 }
 
+BOOST_AUTO_TEST_CASE(SerializationTest)
+{
+  mat X = randu<mat>(100, 100);
+  size_t nAtoms = 25;
+
+  SparseCoding sc(nAtoms, 0.05, 0.1);
+  sc.Train(X);
+
+  mat Y = randu<mat>(100, 200);
+  mat codes;
+  sc.OptimizeCode(Y, codes);
+  // scXml and scBinary start with settings that differ from sc's.
+  SparseCoding scXml(50, 0.01), scText(nAtoms, 0.05), scBinary(0, 0.0);
+  SerializeObjectAll(sc, scXml, scText, scBinary);
+
+  CheckMatrices(sc.Dictionary(), scXml.Dictionary(), scText.Dictionary(),
+      scBinary.Dictionary());
+
+  mat xmlCodes, textCodes, binaryCodes;
+  scXml.OptimizeCode(Y, xmlCodes);
+  scText.OptimizeCode(Y, textCodes);
+  scBinary.OptimizeCode(Y, binaryCodes);
+
+  CheckMatrices(codes, xmlCodes, textCodes, binaryCodes);
+
+  // All parameters must match too (BOOST_REQUIRE_CLOSE tolerance is in %).
+  BOOST_REQUIRE_EQUAL(sc.Atoms(), scXml.Atoms());
+  BOOST_REQUIRE_EQUAL(sc.Atoms(), scText.Atoms());
+  BOOST_REQUIRE_EQUAL(sc.Atoms(), scBinary.Atoms());
+
+  BOOST_REQUIRE_CLOSE(sc.Lambda1(), scXml.Lambda1(), 1e-5);
+  BOOST_REQUIRE_CLOSE(sc.Lambda1(), scText.Lambda1(), 1e-5);
+  BOOST_REQUIRE_CLOSE(sc.Lambda1(), scBinary.Lambda1(), 1e-5);
+
+  BOOST_REQUIRE_CLOSE(sc.Lambda2(), scXml.Lambda2(), 1e-5);
+  BOOST_REQUIRE_CLOSE(sc.Lambda2(), scText.Lambda2(), 1e-5);
+  BOOST_REQUIRE_CLOSE(sc.Lambda2(), scBinary.Lambda2(), 1e-5);
+
+  BOOST_REQUIRE_EQUAL(sc.MaxIterations(), scXml.MaxIterations());
+  BOOST_REQUIRE_EQUAL(sc.MaxIterations(), scText.MaxIterations());
+  BOOST_REQUIRE_EQUAL(sc.MaxIterations(), scBinary.MaxIterations());
+
+  BOOST_REQUIRE_CLOSE(sc.ObjTolerance(), scXml.ObjTolerance(), 1e-5);
+  BOOST_REQUIRE_CLOSE(sc.ObjTolerance(), scText.ObjTolerance(), 1e-5);
+  BOOST_REQUIRE_CLOSE(sc.ObjTolerance(), scBinary.ObjTolerance(), 1e-5);
+
+  BOOST_REQUIRE_CLOSE(sc.NewtonTolerance(), scXml.NewtonTolerance(), 1e-5);
+  BOOST_REQUIRE_CLOSE(sc.NewtonTolerance(), scText.NewtonTolerance(), 1e-5);
+  BOOST_REQUIRE_CLOSE(sc.NewtonTolerance(), scBinary.NewtonTolerance(), 1e-5);
+}
+
 
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list