[mlpack-git] master, mlpack-1.0.x: Add sparse autoencoder contribution by Siddharth. This is the version given in ticket #345 as Changes3.tar.gz. (bdd4f83)
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:46:31 EST 2015
Repository : https://github.com/mlpack/mlpack
On branches: master,mlpack-1.0.x
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit bdd4f83edc9bded8087fd9a404e6ca90c85349ce
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Apr 16 18:30:27 2014 +0000
Add sparse autoencoder contribution by Siddharth. This is the version given in
ticket #345 as Changes3.tar.gz.
>---------------------------------------------------------------
bdd4f83edc9bded8087fd9a404e6ca90c85349ce
.../sparse_autoencoder/sparse_autoencoder.hpp | 182 +++++++++++++++++++
.../sparse_autoencoder_function.cpp | 192 +++++++++++++++++++++
.../sparse_autoencoder_function.hpp | 141 +++++++++++++++
.../sparse_autoencoder/sparse_autoencoder_impl.hpp | 62 +++++++
4 files changed, 577 insertions(+)
diff --git a/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder.hpp b/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder.hpp
new file mode 100644
index 0000000..ef4742d
--- /dev/null
+++ b/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder.hpp
@@ -0,0 +1,182 @@
+#include <mlpack/core.hpp>
+#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
+
+#include "sparse_autoencoder_function.hpp"
+
+/**
+ * Sparse Autoencoder is a neural network whose aim is to learn compressed
+ * representations of the data, typically for dimensionality reduction, with
+ * a constraint on the activity of the neurons in the network. Sparse
+ * Autoencoders can be stacked together to learn a hierarchy of features, which
+ * provide a better representation of the data for classification. This is a
+ * method used in the recently developed field of Deep Learning. More technical
+ * details about the model can be found on the following webpage:
+ * http://deeplearning.stanford.edu/wiki/index.php/UFLDL_Tutorial
+ *
+ * An example of how to use the interface is shown below:
+ * @code
+ * arma::mat data; // Data matrix.
+ * const size_t vSize = 64; // Size of visible layer, depends on the data.
+ * const size_t hSize = 25; // Size of hidden layer, depends on requirements.
+ *
+ * // Train the model using default options.
+ * SparseAutoencoder encoder1(data, vSize, hSize);
+ *
+ * const size_t numBasis = 5; // Parameter required for L-BFGS algorithm.
+ * const size_t numIterations = 100; // Maximum number of iterations.
+ *
+ * // Use an instantiated optimizer for the training.
+ * SparseAutoencoderFunction saf(data, vSize, hSize);
+ * L_BFGS<SparseAutoencoderFunction> optimizer(saf, numBasis, numIterations);
+ * SparseAutoencoder<L_BFGS> encoder2(optimizer);
+ *
+ * arma::mat features1, features2; // Matrices for storing new representations.
+ *
+ * // Get new representations from the trained models.
+ * encoder1.GetNewFeatures(data, features1);
+ * encoder2.GetNewFeatures(data, features2);
+ * @endcode
+ */
+
+namespace mlpack {
+namespace nn {
+
+template<
+ template<typename> class OptimizerType = mlpack::optimization::L_BFGS
+>
+class SparseAutoencoder
+{
+ public:
+
+ /**
+ * Construct the Sparse Autoencoder model with the given training data. This
+ * will train the model. The parameters 'lambda', 'beta' and 'rho' are
+ * optional; changing them affects the regularization and sparsity of the
+ * model.
+ *
+ * @param data Input data with each column as one example.
+ * @param visibleSize Size of input vector expected at the visible layer.
+ * @param hiddenSize Size of input vector expected at the hidden layer.
+ * @param lambda L2-regularization parameter.
+ * @param beta KL divergence parameter.
+ * @param rho Sparsity parameter.
+ */
+ SparseAutoencoder(const arma::mat& data,
+ const size_t visibleSize,
+ const size_t hiddenSize,
+ const double lambda = 0.0001,
+ const double beta = 3,
+ const double rho = 0.01);
+
+ /**
+ * Construct the Sparse Autoencoder model with the given training data. This
+ * will train the model. This overload takes an already instantiated optimizer
+ * and uses it to train the model. The optimizer should hold an instantiated
+ * SparseAutoencoderFunction object for the function to operate upon. This
+ * option should be preferred when the optimizer options are to be changed.
+ *
+ * @param optimizer Instantiated optimizer with instantiated error function.
+ */
+ SparseAutoencoder(OptimizerType<SparseAutoencoderFunction>& optimizer);
+
+ /**
+ * Transforms the provided data into a more meaningful and compact
+ * representation. The function performs a feedforward computation using the
+ * learned weights and returns the hidden layer activations.
+ *
+ * @param data Matrix of the provided data.
+ * @param features The hidden layer representation of the provided data.
+ */
+ void GetNewFeatures(arma::mat& data, arma::mat& features);
+
+ /**
+ * Returns the elementwise sigmoid of the passed matrix, where the sigmoid
+ * function of a real number 'x' is [1 / (1 + exp(-x))].
+ *
+ * @param x Matrix of real values for which we require the sigmoid activation.
+ */
+ arma::mat Sigmoid(const arma::mat& x) const
+ {
+ return (1.0 / (1 + arma::exp(-x)));
+ }
+
+ //! Sets size of the visible layer.
+ void VisibleSize(const size_t visible)
+ {
+ this->visibleSize = visible;
+ }
+
+ //! Gets size of the visible layer.
+ size_t VisibleSize() const
+ {
+ return visibleSize;
+ }
+
+ //! Sets size of the hidden layer.
+ void HiddenSize(const size_t hidden)
+ {
+ this->hiddenSize = hidden;
+ }
+
+ //! Gets the size of the hidden layer.
+ size_t HiddenSize() const
+ {
+ return hiddenSize;
+ }
+
+ //! Sets the L2-regularization parameter.
+ void Lambda(const double l)
+ {
+ this->lambda = l;
+ }
+
+ //! Gets the L2-regularization parameter.
+ double Lambda() const
+ {
+ return lambda;
+ }
+
+ //! Sets the KL divergence parameter.
+ void Beta(const double b)
+ {
+ this->beta = b;
+ }
+
+ //! Gets the KL divergence parameter.
+ double Beta() const
+ {
+ return beta;
+ }
+
+ //! Sets the sparsity parameter.
+ void Rho(const double r)
+ {
+ this->rho = r;
+ }
+
+ //! Gets the sparsity parameter.
+ double Rho() const
+ {
+ return rho;
+ }
+
+ private:
+
+ //! Parameters after optimization.
+ arma::mat parameters;
+ //! Size of the visible layer.
+ size_t visibleSize;
+ //! Size of the hidden layer.
+ size_t hiddenSize;
+ //! L2-regularization parameter.
+ double lambda;
+ //! KL divergence parameter.
+ double beta;
+ //! Sparsity parameter.
+ double rho;
+};
+
+}; // namespace nn
+}; // namespace mlpack
+
+// Include implementation.
+#include "sparse_autoencoder_impl.hpp"
diff --git a/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.cpp b/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.cpp
new file mode 100644
index 0000000..4837910
--- /dev/null
+++ b/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.cpp
@@ -0,0 +1,192 @@
+#include "sparse_autoencoder_function.hpp"
+
+using namespace mlpack;
+using namespace mlpack::nn;
+using namespace std;
+
+SparseAutoencoderFunction::SparseAutoencoderFunction(const arma::mat& data,
+ const size_t visibleSize,
+ const size_t hiddenSize,
+ const double lambda,
+ const double beta,
+ const double rho) :
+ data(data),
+ visibleSize(visibleSize),
+ hiddenSize(hiddenSize),
+ lambda(lambda),
+ beta(beta),
+ rho(rho)
+{
+ // Initialize the parameters to suitable values.
+ initialPoint = InitializeWeights();
+}
+
+/** Initializes the parameter weights if the initial point is not passed to the
+ * constructor. The weights w1, w2 are initialized randomly in the range
+ * [-r, r] where 'r' is decided using the sizes of the visible and hidden
+ * layers. The biases b1, b2 are initialized to 0.
+ */
+const arma::mat SparseAutoencoderFunction::InitializeWeights()
+{
+ // The module uses a matrix to store the parameters; its structure looks like:
+ // vSize 1
+ // | | |
+ // hSize| w1 |b1|
+ // |________|__|
+ // | | |
+ // hSize| w2' | |
+ // |________|__|
+ // 1| b2' | |
+ //
+ // There are (hiddenSize + 1) empty cells in the matrix, but their number is
+ // small compared to the total matrix size. The above structure allows for
+ // smooth matrix operations without making the code too ugly.
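+ //
+ // As a concrete sketch: with, for example, visibleSize = 64 and
+ // hiddenSize = 25, the parameter matrix would have 2*25 + 1 = 51 rows and
+ // 64 + 1 = 65 columns.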
+
+ arma::mat parameters;
+ parameters.zeros(2*hiddenSize + 1, visibleSize + 1);
+
+ // Initialize w1 and w2 to random values in the range [0, 1].
+ parameters.submat(0, 0, 2*hiddenSize - 1, visibleSize - 1).randu();
+
+ // Decide the parameter 'r' depending on the size of the visible and hidden
+ // layers. The formula used is r = sqrt(6) / sqrt(vSize + hSize + 1).
+ const double range = sqrt(6) / sqrt(visibleSize + hiddenSize + 1);
+
+ // Shift range of w1 and w2 values from [0, 1] to [-r, r].
+ parameters.submat(0, 0, 2*hiddenSize - 1, visibleSize - 1) = 2 * range *
+ (parameters.submat(0, 0, 2*hiddenSize - 1, visibleSize - 1) - 0.5);
+
+ return parameters;
+}
+
+/** Evaluates the objective function given the parameters.
+ */
+double SparseAutoencoderFunction::Evaluate(const arma::mat& parameters) const
+{
+ // The objective function is the average squared reconstruction error of the
+ // network. w1 and b1 are the weights and biases associated with the hidden
+ // layer, whereas w2 and b2 are associated with the output layer.
+ // f(w1,w2,b1,b2) = sum((data - sigmoid(w2*sigmoid(w1*data + b1) + b2))^2) / 2m
+ // 'm' is the number of training examples.
+ // The cost also takes into account the regularization and KL divergence terms
+ // to control the parameter weights and sparsity of the model respectively.
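+ //
+ // Written out (as a sketch, with m = data.n_cols), the objective computed
+ // below is:
+ //   J = (1/(2m)) * sum((outputLayer - data)^2)
+ //       + (lambda/2) * (||w1||^2 + ||w2||^2)
+ //       + beta * sum_j(rho*log(rho/rhoCap_j) + (1-rho)*log((1-rho)/(1-rhoCap_j)))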
+
+ // Compute the limits for the parameters w1, w2, b1 and b2.
+ const size_t l1 = hiddenSize;
+ const size_t l2 = visibleSize;
+ const size_t l3 = 2*hiddenSize;
+
+ // w1, w2, b1 and b2 are not extracted separately, 'parameters' is directly
+ // used in their place to avoid copying data. The following representations
+ // are used:
+ // w1 <- parameters.submat(0, 0, l1-1, l2-1)
+ // w2 <- parameters.submat(l1, 0, l3-1, l2-1).t()
+ // b1 <- parameters.submat(0, l2, l1-1, l2)
+ // b2 <- parameters.submat(l3, 0, l3, l2-1).t()
+
+ arma::mat hiddenLayer, outputLayer;
+
+ // Compute activations of the hidden and output layers.
+ hiddenLayer = Sigmoid(parameters.submat(0, 0, l1-1, l2-1) * data +
+ arma::repmat(parameters.submat(0, l2, l1-1, l2), 1, data.n_cols));
+
+ outputLayer = Sigmoid(parameters.submat(l1, 0, l3-1, l2-1).t() * hiddenLayer +
+ arma::repmat(parameters.submat(l3, 0, l3, l2-1).t(), 1, data.n_cols));
+
+ arma::mat rhoCap, diff;
+
+ // Average activations of the hidden layer.
+ rhoCap = arma::sum(hiddenLayer, 1) / data.n_cols;
+ // Difference between the reconstructed data and the original data.
+ diff = outputLayer - data;
+
+ double wL2SquaredNorm;
+
+ // Calculate squared L2-norms of w1 and w2.
+ wL2SquaredNorm = arma::accu(parameters.submat(0, 0, l3-1, l2-1) %
+ parameters.submat(0, 0, l3-1, l2-1));
+
+ double sumOfSquaresError, weightDecay, klDivergence, cost;
+
+ // Calculate the reconstruction error, the regularization cost and the KL
+ // divergence cost terms. 'sumOfSquaresError' is the average squared l2-norm
+ // of the reconstructed data difference. 'weightDecay' is the squared l2-norm
+ // of the weights w1 and w2. 'klDivergence' is the cost of the average hidden
+ // layer activations deviating from the sparsity parameter 'rho'. It is given
+ // by the following formula:
+ // KL = sum_over_hSize(rho*log(rho/rhoCap) + (1-rho)*log((1-rho)/(1-rhoCap)))
+ sumOfSquaresError = 0.5 * arma::accu(diff % diff) / data.n_cols;
+ weightDecay = 0.5 * lambda * wL2SquaredNorm;
+ klDivergence = beta * arma::accu(rho * arma::log(rho / rhoCap) + (1 - rho) *
+ arma::log((1 - rho) / (1 - rhoCap)));
+
+ // The cost is the sum of the terms calculated above.
+ cost = sumOfSquaresError + weightDecay + klDivergence;
+
+ return cost;
+}
+
+/** Calculates and stores the gradient values given a set of parameters.
+ */
+void SparseAutoencoderFunction::Gradient(const arma::mat& parameters,
+ arma::mat& gradient) const
+{
+ // Performs a feedforward pass of the neural network, and computes the
+ // activations of the output layer as in the Evaluate() method. It uses the
+ // Backpropagation algorithm to calculate the delta values at each layer,
+ // except for the input layer. The delta values are then used with input layer
+ // and hidden layer activations to get the parameter gradients.
+
+ // Compute the limits for the parameters w1, w2, b1 and b2.
+ const size_t l1 = hiddenSize;
+ const size_t l2 = visibleSize;
+ const size_t l3 = 2*hiddenSize;
+
+ // w1, w2, b1 and b2 are not extracted separately, 'parameters' is directly
+ // used in their place to avoid copying data. The following representations
+ // are used:
+ // w1 <- parameters.submat(0, 0, l1-1, l2-1)
+ // w2 <- parameters.submat(l1, 0, l3-1, l2-1).t()
+ // b1 <- parameters.submat(0, l2, l1-1, l2)
+ // b2 <- parameters.submat(l3, 0, l3, l2-1).t()
+
+ arma::mat hiddenLayer, outputLayer;
+
+ // Compute activations of the hidden and output layers.
+ hiddenLayer = Sigmoid(parameters.submat(0, 0, l1-1, l2-1) * data +
+ arma::repmat(parameters.submat(0, l2, l1-1, l2), 1, data.n_cols));
+
+ outputLayer = Sigmoid(parameters.submat(l1, 0, l3-1, l2-1).t() * hiddenLayer +
+ arma::repmat(parameters.submat(l3, 0, l3, l2-1).t(), 1, data.n_cols));
+
+ arma::mat rhoCap, diff;
+
+ // Average activations of the hidden layer.
+ rhoCap = arma::sum(hiddenLayer, 1) / data.n_cols;
+ // Difference between the reconstructed data and the original data.
+ diff = outputLayer - data;
+
+ arma::mat klDivGrad, delOut, delHid;
+
+ // The delta vector for the output layer is given by diff * f'(z), where z is
+ // the preactivation and f is the activation function. The derivative of the
+ // sigmoid function turns out to be f(z) * (1 - f(z)). For every other layer
+ // in the neural network which comes before the output layer, the delta values
+ // are given by del_n = w_n' * del_(n+1) * f'(z_n). Since our cost function also
+ // includes the KL divergence term, we adjust for that in the formula below.
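+ // As a dimensional sanity check (with m = data.n_cols): delOut has size
+ // visibleSize x m and delHid has size hiddenSize x m, matching outputLayer
+ // and hiddenLayer respectively.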
+ klDivGrad = beta * (-(rho / rhoCap) + (1 - rho) / (1 - rhoCap));
+ delOut = diff % outputLayer % (1 - outputLayer);
+ delHid = (parameters.submat(l1, 0, l3-1, l2-1) * delOut +
+ arma::repmat(klDivGrad, 1, data.n_cols)) % hiddenLayer % (1-hiddenLayer);
+
+ gradient.zeros(2*hiddenSize + 1, visibleSize + 1);
+
+ // Compute the gradient values using the activations and the delta values. The
+ // formula also accounts for the regularization terms in the objective
+ // function.
+ gradient.submat(0, 0, l1-1, l2-1) = delHid * data.t() / data.n_cols + lambda *
+ parameters.submat(0, 0, l1-1, l2-1);
+ gradient.submat(l1, 0, l3-1, l2-1) = (delOut * hiddenLayer.t() / data.n_cols +
+ lambda * parameters.submat(l1, 0, l3-1, l2-1).t()).t();
+ gradient.submat(0, l2, l1-1, l2) = arma::sum(delHid, 1) / data.n_cols;
+ gradient.submat(l3, 0, l3, l2-1) = (arma::sum(delOut, 1) / data.n_cols).t();
+}
diff --git a/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.hpp b/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.hpp
new file mode 100644
index 0000000..540119e
--- /dev/null
+++ b/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.hpp
@@ -0,0 +1,141 @@
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace nn {
+
+/**
+ * This is a class for the Sparse Autoencoder objective function. It can be used
+ * to create learning models like Self-Taught Learning, Stacked Autoencoders,
+ * Conditional Random Fields, etc.
+ */
+class SparseAutoencoderFunction
+{
+ public:
+
+ SparseAutoencoderFunction(const arma::mat& data,
+ const size_t visibleSize,
+ const size_t hiddenSize,
+ const double lambda = 0.0001,
+ const double beta = 3,
+ const double rho = 0.01);
+
+ //! Initializes the parameters of the model to suitable values.
+ const arma::mat InitializeWeights();
+
+ /**
+ * Evaluates the objective function of the Sparse Autoencoder model using the
+ * given parameters. The cost function has terms for the reconstruction
+ * error, regularization cost and the sparsity cost. The objective function
+ * takes a low value when the model is able to reconstruct the data well
+ * using weights which are low in value and when the average activations of
+ * neurons in the hidden layer agree well with the sparsity parameter 'rho'.
+ *
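+ * A minimal usage sketch (variable names here are only illustrative):
+ * @code
+ * SparseAutoencoderFunction saf(data, vSize, hSize);
+ * const double objective = saf.Evaluate(saf.GetInitialPoint());
+ * @endcode
+ *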
+ * @param parameters Current values of the model parameters.
+ */
+ double Evaluate(const arma::mat& parameters) const;
+
+ /**
+ * Evaluates the gradient of the objective function at the given set of
+ * parameters. The function performs a feedforward pass and computes the error
+ * in reconstructing the data points. It then uses the backpropagation
+ * algorithm to compute the gradient values.
+ *
+ * @param parameters Current values of the model parameters.
+ * @param gradient Matrix where the gradient values are stored.
+ */
+ void Gradient(const arma::mat& parameters, arma::mat& gradient) const;
+
+ /**
+ * Returns the elementwise sigmoid of the passed matrix, where the sigmoid
+ * function of a real number 'x' is [1 / (1 + exp(-x))].
+ *
+ * @param x Matrix of real values for which we require the sigmoid activation.
+ */
+ arma::mat Sigmoid(const arma::mat& x) const
+ {
+ return (1.0 / (1 + arma::exp(-x)));
+ }
+
+ //! Return the initial point for the optimization.
+ const arma::mat& GetInitialPoint() const { return initialPoint; }
+
+ //! Sets size of the visible layer.
+ void VisibleSize(const size_t visible)
+ {
+ this->visibleSize = visible;
+ }
+
+ //! Gets size of the visible layer.
+ size_t VisibleSize() const
+ {
+ return visibleSize;
+ }
+
+ //! Sets size of the hidden layer.
+ void HiddenSize(const size_t hidden)
+ {
+ this->hiddenSize = hidden;
+ }
+
+ //! Gets the size of the hidden layer.
+ size_t HiddenSize() const
+ {
+ return hiddenSize;
+ }
+
+ //! Sets the L2-regularization parameter.
+ void Lambda(const double l)
+ {
+ this->lambda = l;
+ }
+
+ //! Gets the L2-regularization parameter.
+ double Lambda() const
+ {
+ return lambda;
+ }
+
+ //! Sets the KL divergence parameter.
+ void Beta(const double b)
+ {
+ this->beta = b;
+ }
+
+ //! Gets the KL divergence parameter.
+ double Beta() const
+ {
+ return beta;
+ }
+
+ //! Sets the sparsity parameter.
+ void Rho(const double r)
+ {
+ this->rho = r;
+ }
+
+ //! Gets the sparsity parameter.
+ double Rho() const
+ {
+ return rho;
+ }
+
+ private:
+
+ //! The matrix of data points.
+ const arma::mat& data;
+ //! Initial parameter matrix.
+ arma::mat initialPoint;
+ //! Size of the visible layer.
+ size_t visibleSize;
+ //! Size of the hidden layer.
+ size_t hiddenSize;
+ //! L2-regularization parameter.
+ double lambda;
+ //! KL divergence parameter.
+ double beta;
+ //! Sparsity parameter.
+ double rho;
+};
+
+}; // namespace nn
+}; // namespace mlpack
diff --git a/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_impl.hpp b/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_impl.hpp
new file mode 100644
index 0000000..67e1bfe
--- /dev/null
+++ b/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_impl.hpp
@@ -0,0 +1,62 @@
+namespace mlpack {
+namespace nn {
+
+template<template<typename> class OptimizerType>
+SparseAutoencoder<OptimizerType>::SparseAutoencoder(const arma::mat& data,
+ const size_t visibleSize,
+ const size_t hiddenSize,
+ double lambda,
+ double beta,
+ double rho) :
+ visibleSize(visibleSize),
+ hiddenSize(hiddenSize),
+ lambda(lambda),
+ beta(beta),
+ rho(rho)
+{
+ SparseAutoencoderFunction encoderFunction(data, visibleSize, hiddenSize,
+ lambda, beta, rho);
+ OptimizerType<SparseAutoencoderFunction> optimizer(encoderFunction);
+
+ parameters = encoderFunction.GetInitialPoint();
+
+ // Train the model.
+ Timer::Start("sparse_autoencoder_optimization");
+ const double out = optimizer.Optimize(parameters);
+ Timer::Stop("sparse_autoencoder_optimization");
+
+ Log::Info << "SparseAutoencoder::SparseAutoencoder(): final objective of "
+ << "trained model is " << out << "." << std::endl;
+}
+
+template<template<typename> class OptimizerType>
+SparseAutoencoder<OptimizerType>::SparseAutoencoder(
+ OptimizerType<SparseAutoencoderFunction> &optimizer) :
+ parameters(optimizer.Function().GetInitialPoint()),
+ visibleSize(optimizer.Function().VisibleSize()),
+ hiddenSize(optimizer.Function().HiddenSize()),
+ lambda(optimizer.Function().Lambda()),
+ beta(optimizer.Function().Beta()),
+ rho(optimizer.Function().Rho())
+{
+ Timer::Start("sparse_autoencoder_optimization");
+ const double out = optimizer.Optimize(parameters);
+ Timer::Stop("sparse_autoencoder_optimization");
+
+ Log::Info << "SparseAutoencoder::SparseAutoencoder(): final objective of "
+ << "trained model is " << out << "." << std::endl;
+}
+
+template<template<typename> class OptimizerType>
+void SparseAutoencoder<OptimizerType>::GetNewFeatures(arma::mat& data,
+ arma::mat& features)
+{
+ const size_t l1 = hiddenSize;
+ const size_t l2 = visibleSize;
+
+ features = Sigmoid(parameters.submat(0, 0, l1-1, l2-1) * data +
+ arma::repmat(parameters.submat(0, l2, l1-1, l2), 1, data.n_cols));
+}
+
+}; // namespace nn
+}; // namespace mlpack