[mlpack-git] master: Add _impl.hpp file and serialization. (982e17e)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 16 14:12:36 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/cd5986e141b41781fdc13a9c89443f9be33b56bd...31c10fef76ac1d85c6415c92d2ccd429c430105f
>---------------------------------------------------------------
commit 982e17e3bbb993b44cee46b7ce64eef24b93b384
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Dec 16 16:17:31 2015 +0000
Add _impl.hpp file and serialization.
>---------------------------------------------------------------
982e17e3bbb993b44cee46b7ce64eef24b93b384
src/mlpack/methods/local_coordinate_coding/lcc.hpp | 4 +
.../methods/local_coordinate_coding/lcc_impl.hpp | 117 +++++++++++++++++++++
src/mlpack/tests/local_coordinate_coding_test.cpp | 44 ++++++++
3 files changed, 165 insertions(+)
diff --git a/src/mlpack/methods/local_coordinate_coding/lcc.hpp b/src/mlpack/methods/local_coordinate_coding/lcc.hpp
index ec1e801..6738ca9 100644
--- a/src/mlpack/methods/local_coordinate_coding/lcc.hpp
+++ b/src/mlpack/methods/local_coordinate_coding/lcc.hpp
@@ -189,6 +189,10 @@ class LocalCoordinateCoding
//! Modify the objective tolerance.
double& Tolerance() { return tolerance; }
+ /**
+ * Serialize the model to or from the given archive. The same function is
+ * used for saving and loading, and it handles the fields atoms, dictionary,
+ * lambda, maxIterations, and tolerance, in that order.
+ *
+ * @param ar Archive to serialize to or from.
+ * The second (unnamed) argument is the class version; it is currently unused.
+ */
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+
private:
//! Number of atoms in dictionary.
size_t atoms;
diff --git a/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp b/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp
new file mode 100644
index 0000000..531474d
--- /dev/null
+++ b/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp
@@ -0,0 +1,117 @@
+/**
+ * @file lcc_impl.hpp
+ * @author Nishant Mehta
+ *
+ * Implementation of Local Coordinate Coding
+ */
+#ifndef __MLPACK_METHODS_LOCAL_COORDINATE_CODING_LCC_IMPL_HPP
+#define __MLPACK_METHODS_LOCAL_COORDINATE_CODING_LCC_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "lcc.hpp"
+
+namespace mlpack {
+namespace lcc {
+
+/**
+ * Create an LCC model with the given hyperparameters and immediately train it
+ * on the given data; all of the actual fitting is delegated to Train().
+ *
+ * @param data Data matrix to train on (forwarded unchanged to Train()).
+ * @param atoms Number of atoms in the dictionary.
+ * @param lambda Value stored in the lambda member; its role is defined by the
+ *     objective/encoding code, which is not part of this block.
+ * @param maxIterations Iteration limit passed on to Train() via the member.
+ * @param tolerance Minimum objective improvement Train() requires to keep
+ *     iterating.
+ * @param initializer Object used by Train() to create the starting dictionary.
+ */
+template<typename DictionaryInitializer>
+LocalCoordinateCoding::LocalCoordinateCoding(
+    const arma::mat& data,
+    const size_t atoms,
+    const double lambda,
+    const size_t maxIterations,
+    const double tolerance,
+    const DictionaryInitializer& initializer) :
+    atoms(atoms), lambda(lambda), maxIterations(maxIterations),
+    tolerance(tolerance)
+{
+  // Hyperparameters are stored above; building the initial dictionary and
+  // running the optimization both happen inside Train().
+  Train(data, initializer);
+}
+
+/**
+ * Train the model on the given data. The starting dictionary is built with
+ * the given initializer, and then dictionary steps and coding steps alternate
+ * until one of the following happens:
+ *  - the objective improves by less than the tolerance,
+ *  - the objective rises during a coding step, or
+ *  - the iteration limit is reached.
+ * The codes computed along the way are local to this function and are
+ * discarded. The learned dictionary is left in the dictionary member
+ * (presumably OptimizeDictionary() updates it in place; that code is not
+ * shown here).
+ *
+ * @param data Data matrix to learn from; data.n_cols is used as the number of
+ *     points when reporting the sparsity level.
+ * @param initializer Object whose Initialize(data, atoms, dictionary) call
+ *     produces the starting dictionary.
+ */
+template<typename DictionaryInitializer>
+void LocalCoordinateCoding::Train(
+    const arma::mat& data,
+    const DictionaryInitializer& initializer)
+{
+  Timer::Start("local_coordinate_coding");
+
+  // Initialize the dictionary.
+  initializer.Initialize(data, atoms, dictionary);
+
+  // Take the initial coding step, which has to happen before entering the main
+  // loop.
+  Log::Info << "Initial Coding Step." << std::endl;
+
+  arma::mat codes;
+  Encode(data, codes);
+  arma::uvec adjacencies = find(codes);
+
+  Log::Info << " Sparsity level: " << 100.0 * ((double)(adjacencies.n_elem)) /
+      ((double)(atoms * data.n_cols)) << "%.\n";
+
+  // Seed the convergence check with the objective of the initial coding. With
+  // a DBL_MAX seed, the first iteration's "improvement" was always ~1.8e308,
+  // so the tolerance check could never pass on iteration 1 and the logged
+  // improvement was meaningless.
+  double lastObjVal = Objective(data, codes, adjacencies);
+  Log::Info << " Objective value: " << lastObjVal << "." << std::endl;
+
+  // t starts at 1, so maxIterations == 0 imposes no iteration limit (only
+  // convergence or an objective increase ends the loop). With
+  // maxIterations == 1, training stops after the initial coding step above.
+  for (size_t t = 1; t != maxIterations; t++)
+  {
+    Log::Info << "Iteration " << t << " of " << maxIterations << "."
+        << std::endl;
+
+    // First step: optimize the dictionary.
+    Log::Info << "Performing dictionary step..." << std::endl;
+    OptimizeDictionary(data, codes, adjacencies);
+    double dsObjVal = Objective(data, codes, adjacencies);
+    Log::Info << " Objective value: " << dsObjVal << "." << std::endl;
+
+    // Second step: perform the coding.
+    Log::Info << "Performing coding step..." << std::endl;
+    Encode(data, codes);
+    adjacencies = find(codes);
+    Log::Info << " Sparsity level: " << 100.0 * ((double) (adjacencies.n_elem))
+        / ((double)(atoms * data.n_cols)) << "%.\n";
+
+    // Terminate if the objective increased in the coding step.
+    double curObjVal = Objective(data, codes, adjacencies);
+    if (curObjVal > dsObjVal)
+    {
+      Log::Warn << "Objective increased in coding step! Terminating."
+          << std::endl;
+      break;
+    }
+
+    // Find the new objective value and improvement so we can check for
+    // convergence.
+    double improvement = lastObjVal - curObjVal;
+    // NOTE(review): std::scientific is never reset, so later floating-point
+    // values written to Log::Info are presumably also printed in scientific
+    // notation.
+    Log::Info << "Objective value: " << curObjVal << " (improvement "
+        << std::scientific << improvement << ")." << std::endl;
+
+    if (improvement < tolerance)
+    {
+      Log::Info << "Converged within tolerance " << tolerance << ".\n";
+      break;
+    }
+
+    lastObjVal = curObjVal;
+  }
+
+  Timer::Stop("local_coordinate_coding");
+}
+
+/**
+ * Serialize the model to or from the given archive. The same body is used for
+ * both saving and loading, so the fields are read back in exactly the order
+ * they are written here. Keep the order stable:
+ *   atoms, dictionary, lambda, maxIterations, tolerance.
+ * The training hyperparameters are stored alongside the dictionary, so a
+ * loaded model reports the same Atoms()/Lambda()/MaxIterations()/Tolerance()
+ * as the saved one.
+ */
+template<typename Archive>
+void LocalCoordinateCoding::Serialize(Archive& ar,
+ const unsigned int /* version */)
+{
+ // Each field is wrapped in a named value pair so that name-based archive
+ // formats can label it.
+ ar & data::CreateNVP(atoms, "atoms");
+ ar & data::CreateNVP(dictionary, "dictionary");
+ ar & data::CreateNVP(lambda, "lambda");
+ ar & data::CreateNVP(maxIterations, "maxIterations");
+ ar & data::CreateNVP(tolerance, "tolerance");
+}
+
+} // namespace lcc
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/tests/local_coordinate_coding_test.cpp b/src/mlpack/tests/local_coordinate_coding_test.cpp
index 98fea7f..2e51d90 100644
--- a/src/mlpack/tests/local_coordinate_coding_test.cpp
+++ b/src/mlpack/tests/local_coordinate_coding_test.cpp
@@ -11,6 +11,7 @@
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
+#include "serialization.hpp"
using namespace arma;
using namespace mlpack;
@@ -117,4 +118,47 @@ BOOST_AUTO_TEST_CASE(LocalCoordinateCodingTestDictionaryStep)
BOOST_REQUIRE_SMALL(norm(grad, "fro"), tol);
}
+/**
+ * Train a model, then serialize it through all three archive types into
+ * models that were constructed with different parameters. Check that the
+ * dictionary, the encoding of unseen data, and every hyperparameter survive
+ * the round trip.
+ */
+BOOST_AUTO_TEST_CASE(SerializationTest)
+{
+ mat X = randu<mat>(100, 100);
+ size_t nAtoms = 25;
+
+ LocalCoordinateCoding lcc(nAtoms, 0.05);
+ lcc.Train(X);
+
+ // Reference encoding of fresh data, produced by the original model.
+ mat Y = randu<mat>(100, 200);
+ mat codes;
+ lcc.Encode(Y, codes);
+
+ // Deliberately mismatched parameters, so a successful comparison proves the
+ // values came from the archive rather than from the constructor.
+ LocalCoordinateCoding lccXml(50, 0.1), lccText(12, 0.0), lccBinary(0, 0.0);
+ SerializeObjectAll(lcc, lccXml, lccText, lccBinary);
+
+ CheckMatrices(lcc.Dictionary(), lccXml.Dictionary(), lccText.Dictionary(),
+ lccBinary.Dictionary());
+
+ // The deserialized models must encode Y exactly as the original did.
+ mat xmlCodes, textCodes, binaryCodes;
+ lccXml.Encode(Y, xmlCodes);
+ lccText.Encode(Y, textCodes);
+ lccBinary.Encode(Y, binaryCodes);
+
+ CheckMatrices(codes, xmlCodes, textCodes, binaryCodes);
+
+ // Check the parameters, too.
+ BOOST_REQUIRE_EQUAL(lcc.Atoms(), lccXml.Atoms());
+ BOOST_REQUIRE_EQUAL(lcc.Atoms(), lccText.Atoms());
+ BOOST_REQUIRE_EQUAL(lcc.Atoms(), lccBinary.Atoms());
+
+ BOOST_REQUIRE_CLOSE(lcc.Tolerance(), lccXml.Tolerance(), 1e-5);
+ BOOST_REQUIRE_CLOSE(lcc.Tolerance(), lccText.Tolerance(), 1e-5);
+ BOOST_REQUIRE_CLOSE(lcc.Tolerance(), lccBinary.Tolerance(), 1e-5);
+
+ BOOST_REQUIRE_CLOSE(lcc.Lambda(), lccXml.Lambda(), 1e-5);
+ BOOST_REQUIRE_CLOSE(lcc.Lambda(), lccText.Lambda(), 1e-5);
+ BOOST_REQUIRE_CLOSE(lcc.Lambda(), lccBinary.Lambda(), 1e-5);
+
+ BOOST_REQUIRE_EQUAL(lcc.MaxIterations(), lccXml.MaxIterations());
+ BOOST_REQUIRE_EQUAL(lcc.MaxIterations(), lccText.MaxIterations());
+ BOOST_REQUIRE_EQUAL(lcc.MaxIterations(), lccBinary.MaxIterations());
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list