[mlpack-svn] r16432 - mlpack/trunk/src/mlpack/methods/sparse_autoencoder

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Apr 16 14:30:27 EDT 2014


Author: rcurtin
Date: Wed Apr 16 14:30:27 2014
New Revision: 16432

Log:
Add sparse autoencoder contribution by Siddharth.  This is the version given in
ticket #345 as Changes3.tar.gz.


Added:
   mlpack/trunk/src/mlpack/methods/sparse_autoencoder/
   mlpack/trunk/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder.hpp
   mlpack/trunk/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.cpp
   mlpack/trunk/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.hpp
   mlpack/trunk/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_impl.hpp

Added: mlpack/trunk/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder.hpp	Wed Apr 16 14:30:27 2014
@@ -0,0 +1,188 @@
+#ifndef __MLPACK_METHODS_SPARSE_AUTOENCODER_SPARSE_AUTOENCODER_HPP
+#define __MLPACK_METHODS_SPARSE_AUTOENCODER_SPARSE_AUTOENCODER_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
+
+#include "sparse_autoencoder_function.hpp"
+
+/**
+  * Sparse Autoencoder is a neural network whose aim is to learn compressed
+  * representations of the data, typically for dimensionality reduction, with
+  * a constraint on the activity of the neurons in the network. Sparse
+  * Autoencoders can be stacked together to learn a hierarchy of features, which
+  * provide a better representation of the data for classification. This is a
+  * method used in the recently developed field of Deep Learning. More technical
+  * details about the model can be found on the following webpage: 
+  * http://deeplearning.stanford.edu/wiki/index.php/UFLDL_Tutorial
+  *
+  * An example of how to use the interface is shown below:
+  * @code
+  * arma::mat data; // Data matrix.
+  * const size_t vSize = 64; // Size of visible layer, depends on the data.
+  * const size_t hSize = 25; // Size of hidden layer, depends on requirements.
+  *
+  * // Train the model using default options.
+  * SparseAutoencoder encoder1(data, vSize, hSize);
+  *
+  * const size_t numBasis = 5; // Parameter required for L-BFGS algorithm.
+  * const size_t numIterations = 100; // Maximum number of iterations.
+  *
+  * // Use an instantiated optimizer for the training.
+  * SparseAutoencoderFunction saf(data, vSize, hSize);
+  * L_BFGS<SparseAutoencoderFunction> optimizer(saf, numBasis, numIterations);
+  * SparseAutoencoder<L_BFGS> encoder2(optimizer);
+  *
+  * arma::mat features1, features2; // Matrices for storing new representations.
+  *
+  * // Get new representations from the trained models.
+  * encoder1.GetNewFeatures(data, features1);
+  * encoder2.GetNewFeatures(data, features2);
+  * @endcode
+  */
+
+namespace mlpack {
+namespace nn {
+
+template<
+  template<typename> class OptimizerType = mlpack::optimization::L_BFGS
+>
+class SparseAutoencoder
+{
+ public:
+
+  /**
+   * Construct the Sparse Autoencoder model with the given training data. This
+   * will train the model. The parameters 'lambda', 'beta' and 'rho' can be set
+   * optionally. Changing these parameters will affect the amount of
+   * regularization and the sparsity of the model.
+   *
+   * @param data Input data with each column as one example.
+   * @param visibleSize Size of input vector expected at the visible layer.
+   * @param hiddenSize Size of input vector expected at the hidden layer.
+   * @param lambda L2-regularization parameter.
+   * @param beta KL divergence parameter.
+   * @param rho Sparsity parameter.
+   */
+  SparseAutoencoder(const arma::mat& data,
+                    const size_t visibleSize,
+                    const size_t hiddenSize,
+                    const double lambda = 0.0001,
+                    const double beta = 3,
+                    const double rho = 0.01);
+  
+  /**
+   * Construct the Sparse Autoencoder model with the given training data. This
+   * will train the model. This overload takes an already instantiated optimizer
+   * and uses it to train the model. The optimizer should hold an instantiated
+   * SparseAutoencoderFunction object for the function to operate upon. This
+   * option should be preferred when the optimizer options are to be changed.
+   *
+   * @param optimizer Instantiated optimizer with instantiated error function.
+   */
+  SparseAutoencoder(OptimizerType<SparseAutoencoderFunction>& optimizer);
+  
+  /**
+   * Transforms the provided data into a more meaningful and compact
+   * representation. The function performs a feedforward computation using the
+   * learned weights, and returns the hidden layer activations.
+   *
+   * @param data Matrix of the provided data.
+   * @param features The hidden layer representation of the provided data.
+   */
+  void GetNewFeatures(arma::mat& data, arma::mat& features);
+  
+  /**
+   * Returns the elementwise sigmoid of the passed matrix, where the sigmoid
+   * function of a real number 'x' is [1 / (1 + exp(-x))].
+   * 
+   * @param x Matrix of real values for which we require the sigmoid activation.
+   */
+  arma::mat Sigmoid(const arma::mat& x) const
+  {
+    return (1.0 / (1 + arma::exp(-x)));
+  }
+
+  //! Sets size of the visible layer.
+  void VisibleSize(const size_t visible)
+  {
+    this->visibleSize = visible;
+  }
+  
+  //! Gets size of the visible layer.
+  size_t VisibleSize() const
+  {
+    return visibleSize;
+  }
+
+  //! Sets size of the hidden layer.
+  void HiddenSize(const size_t hidden)
+  {
+    this->hiddenSize = hidden;
+  }
+
+  //! Gets the size of the hidden layer.  
+  size_t HiddenSize() const
+  {
+    return hiddenSize;
+  }
+  
+  //! Sets the L2-regularization parameter.
+  void Lambda(const double l)
+  {
+    this->lambda = l;
+  }
+  
+  //! Gets the L2-regularization parameter.
+  double Lambda() const
+  {
+    return lambda;
+  }
+  
+  //! Sets the KL divergence parameter.
+  void Beta(const double b)
+  {
+    this->beta = b;
+  }
+  
+  //! Gets the KL divergence parameter.
+  double Beta() const
+  {
+    return beta;
+  }
+  
+  //! Sets the sparsity parameter.
+  void Rho(const double r)
+  {
+    this->rho = r;
+  }
+  
+  //! Gets the sparsity parameter.
+  double Rho() const
+  {
+    return rho;
+  }
+
+ private:
+
+  //! Parameters after optimization.
+  arma::mat parameters;
+  //! Size of the visible layer.
+  size_t visibleSize;
+  //! Size of the hidden layer.
+  size_t hiddenSize;
+  //! L2-regularization parameter.
+  double lambda;
+  //! KL divergence parameter.
+  double beta;
+  //! Sparsity parameter.
+  double rho;
+};
+
+}; // namespace nn
+}; // namespace mlpack
+
+// Include implementation.
+#include "sparse_autoencoder_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.cpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.cpp	Wed Apr 16 14:30:27 2014
@@ -0,0 +1,200 @@
+#include "sparse_autoencoder_function.hpp"
+
+using namespace mlpack;
+using namespace mlpack::nn;
+using namespace std;
+
+SparseAutoencoderFunction::SparseAutoencoderFunction(const arma::mat& data,
+                                                     const size_t visibleSize,
+                                                     const size_t hiddenSize,
+                                                     const double lambda,
+                                                     const double beta,
+                                                     const double rho) :
+    data(data),
+    visibleSize(visibleSize),
+    hiddenSize(hiddenSize),
+    lambda(lambda),
+    beta(beta),
+    rho(rho)
+{
+  // Initialize the parameters to suitable values.
+  initialPoint = InitializeWeights();
+}
+
+/** Initializes the parameter weights if the initial point is not passed to the
+  * constructor. The weights w1, w2 are initialized randomly in the range
+  * [-r, r], where 'r' is chosen based on the sizes of the visible and hidden
+  * layers. The biases b1, b2 are initialized to 0.
+  */
+const arma::mat SparseAutoencoderFunction::InitializeWeights()
+{
+  // The module uses a matrix to store the parameters; its structure looks like:
+  //          vSize   1
+  //       |        |  |
+  //  hSize|   w1   |b1|
+  //       |________|__|
+  //       |        |  |
+  //  hSize|   w2'  |  |
+  //       |________|__|
+  //      1|   b2'  |  |
+  //
+  // There are (hiddenSize + 1) empty cells in the matrix, but that is small
+  // compared to the total matrix size. The above structure allows for smooth
+  // matrix operations without making the code too ugly.
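+  //
+  // For example, with visibleSize = 64 and hiddenSize = 25 (the sizes used in
+  // the example in sparse_autoencoder.hpp), the parameter matrix would be
+  // (2 * 25 + 1) x (64 + 1), i.e. 51 x 65.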
+
+  arma::mat parameters;
+  parameters.zeros(2*hiddenSize + 1, visibleSize + 1);
+  
+  // Initialize w1 and w2 to random values in the range [0, 1].
+  parameters.submat(0, 0, 2*hiddenSize - 1, visibleSize - 1).randu();
+  
+  // Decide the parameter 'r' depending on the size of the visible and hidden
+  // layers. The formula used is r = sqrt(6) / sqrt(vSize + hSize + 1).
+  const double range = sqrt(6) / sqrt(visibleSize + hiddenSize + 1);
+  
+  // Shift range of w1 and w2 values from [0, 1] to [-r, r].
+  parameters.submat(0, 0, 2*hiddenSize - 1, visibleSize - 1) = 2 * range *
+      (parameters.submat(0, 0, 2*hiddenSize - 1, visibleSize - 1) - 0.5);
+  
+  return parameters;
+}
+
+/** Evaluates the objective function given the parameters. 
+  */
+double SparseAutoencoderFunction::Evaluate(const arma::mat& parameters) const
+{
+  // The objective function is the average squared reconstruction error of the
+  // network. w1 and b1 are the weights and biases associated with the hidden
+  // layer, whereas w2 and b2 are associated with the output layer.
+  // f(w1, w2, b1, b2) =
+  //     sum((data - sigmoid(w2 * sigmoid(w1 * data + b1) + b2))^2) / (2 * m),
+  // where 'm' is the number of training examples.
+  // The cost also takes into account the regularization and KL divergence terms
+  // to control the parameter weights and sparsity of the model respectively.
+
+  // Compute the limits for the parameters w1, w2, b1 and b2.
+  const size_t l1 = hiddenSize;
+  const size_t l2 = visibleSize;
+  const size_t l3 = 2*hiddenSize;
+  
+  // w1, w2, b1 and b2 are not extracted separately, 'parameters' is directly
+  // used in their place to avoid copying data. The following representations
+  // are used:
+  // w1 <- parameters.submat(0, 0, l1-1, l2-1)
+  // w2 <- parameters.submat(l1, 0, l3-1, l2-1).t()
+  // b1 <- parameters.submat(0, l2, l1-1, l2)
+  // b2 <- parameters.submat(l3, 0, l3, l2-1).t()
+  
+  arma::mat hiddenLayer, outputLayer;
+  
+  // Compute activations of the hidden and output layers.
+  hiddenLayer = Sigmoid(parameters.submat(0, 0, l1-1, l2-1) * data +
+      arma::repmat(parameters.submat(0, l2, l1-1, l2), 1, data.n_cols));
+  
+  outputLayer = Sigmoid(parameters.submat(l1, 0, l3-1, l2-1).t() * hiddenLayer +
+      arma::repmat(parameters.submat(l3, 0, l3, l2-1).t(), 1, data.n_cols));
+  
+  arma::mat rhoCap, diff;
+  
+  // Average activations of the hidden layer.
+  rhoCap = arma::sum(hiddenLayer, 1) / data.n_cols;
+  // Difference between the reconstructed data and the original data.
+  diff = outputLayer - data;
+  
+  double wL2SquaredNorm;
+  
+  // Calculate squared L2-norms of w1 and w2.
+  wL2SquaredNorm = arma::accu(parameters.submat(0, 0, l3-1, l2-1) %
+      parameters.submat(0, 0, l3-1, l2-1));
+  
+  double sumOfSquaresError, weightDecay, klDivergence, cost;
+  
+  // Calculate the reconstruction error, the regularization cost and the KL
+  // divergence cost terms. 'sumOfSquaresError' is the average squared L2-norm
+  // of the reconstruction error. 'weightDecay' is the squared L2-norm of the
+  // weights w1 and w2. 'klDivergence' is the cost of the hidden layer
+  // activations not being low. It is given by the following formula:
+  // KL = sum_over_hSize(rho*log(rho/rhoCap) + (1-rho)*log((1-rho)/(1-rhoCap)))
+  sumOfSquaresError = 0.5 * arma::accu(diff % diff) / data.n_cols;
+  weightDecay = 0.5 * lambda * wL2SquaredNorm;
+  klDivergence = beta * arma::accu(rho * arma::log(rho / rhoCap) + (1 - rho) *
+      arma::log((1 - rho) / (1 - rhoCap)));
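+
+  // Sanity check: if every average activation rhoCap equals rho, each term is
+  // rho * log(1) + (1 - rho) * log(1) = 0, and the sparsity penalty vanishes.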
+  
+  // The cost is the sum of the terms calculated above.
+  cost = sumOfSquaresError + weightDecay + klDivergence;
+  
+  return cost;
+}
+
+/** Calculates and stores the gradient values given a set of parameters.
+  */
+void SparseAutoencoderFunction::Gradient(const arma::mat& parameters,
+                                         arma::mat& gradient) const
+{
+  // Performs a feedforward pass of the neural network, and computes the
+  // activations of the output layer as in the Evaluate() method. It uses the
+  // Backpropagation algorithm to calculate the delta values at each layer,
+  // except for the input layer. The delta values are then used with the input
+  // layer and hidden layer activations to get the parameter gradients.
+
+  // Compute the limits for the parameters w1, w2, b1 and b2.
+  const size_t l1 = hiddenSize;
+  const size_t l2 = visibleSize;
+  const size_t l3 = 2*hiddenSize;
+  
+  // w1, w2, b1 and b2 are not extracted separately, 'parameters' is directly
+  // used in their place to avoid copying data. The following representations
+  // are used:
+  // w1 <- parameters.submat(0, 0, l1-1, l2-1)
+  // w2 <- parameters.submat(l1, 0, l3-1, l2-1).t()
+  // b1 <- parameters.submat(0, l2, l1-1, l2)
+  // b2 <- parameters.submat(l3, 0, l3, l2-1).t()
+  
+  arma::mat hiddenLayer, outputLayer;
+  
+  // Compute activations of the hidden and output layers.
+  hiddenLayer = Sigmoid(parameters.submat(0, 0, l1-1, l2-1) * data +
+      arma::repmat(parameters.submat(0, l2, l1-1, l2), 1, data.n_cols));
+  
+  outputLayer = Sigmoid(parameters.submat(l1, 0, l3-1, l2-1).t() * hiddenLayer +
+      arma::repmat(parameters.submat(l3, 0, l3, l2-1).t(), 1, data.n_cols));
+  
+  arma::mat rhoCap, diff;
+  
+  // Average activations of the hidden layer.
+  rhoCap = arma::sum(hiddenLayer, 1) / data.n_cols;
+  // Difference between the reconstructed data and the original data.
+  diff = outputLayer - data;
+  
+  arma::mat klDivGrad, delOut, delHid;
+  
+  // The delta vector for the output layer is given by diff % f'(z), where z is
+  // the preactivation and f is the activation function, applied elementwise.
+  // The derivative of the sigmoid turns out to be f(z) * (1 - f(z)). For every
+  // other layer that comes before the output layer, the delta values are given
+  // by del_n = (w_n' * del_(n+1)) % f'(z_n). Since our cost function also
+  // includes the KL divergence term, we adjust for that in the formula below.
+  klDivGrad = beta * (-(rho / rhoCap) + (1 - rho) / (1 - rhoCap));
+  delOut = diff % outputLayer % (1 - outputLayer);
+  delHid = (parameters.submat(l1, 0, l3-1, l2-1) * delOut +
+      arma::repmat(klDivGrad, 1, data.n_cols)) % hiddenLayer % (1-hiddenLayer);
+            
+  gradient.zeros(2*hiddenSize + 1, visibleSize + 1);
+  
+  // Compute the gradient values using the activations and the delta values. The
+  // formula also accounts for the regularization terms in the objective
+  // function.
+  gradient.submat(0, 0, l1-1, l2-1) = delHid * data.t() / data.n_cols + lambda *
+      parameters.submat(0, 0, l1-1, l2-1);
+  gradient.submat(l1, 0, l3-1, l2-1) = (delOut * hiddenLayer.t() / data.n_cols +
+      lambda * parameters.submat(l1, 0, l3-1, l2-1).t()).t();
+  gradient.submat(0, l2, l1-1, l2) = arma::sum(delHid, 1) / data.n_cols;
+  gradient.submat(l3, 0, l3, l2-1) = (arma::sum(delOut, 1) / data.n_cols).t();
+}

Added: mlpack/trunk/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_function.hpp	Wed Apr 16 14:30:27 2014
@@ -0,0 +1,147 @@
+#ifndef __MLPACK_METHODS_SPARSE_AUTOENCODER_SPARSE_AUTOENCODER_FUNCTION_HPP
+#define __MLPACK_METHODS_SPARSE_AUTOENCODER_SPARSE_AUTOENCODER_FUNCTION_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace nn {
+
+/**
+ * This is a class for the Sparse Autoencoder objective function. It can be used
+ * to create learning models like Self-Taught Learning, Stacked Autoencoders,
+ * Conditional Random Fields, etc.
+ */
+class SparseAutoencoderFunction
+{
+ public:
+ 
+  SparseAutoencoderFunction(const arma::mat& data,
+                            const size_t visibleSize,
+                            const size_t hiddenSize,
+                            const double lambda = 0.0001,
+                            const double beta = 3,
+                            const double rho = 0.01);
+  
+  //! Initializes the parameters of the model to suitable values.
+  const arma::mat InitializeWeights();
+  
+  /**
+   * Evaluates the objective function of the Sparse Autoencoder model using the
+   * given parameters. The cost function has terms for the reconstruction
+   * error, regularization cost and the sparsity cost. The objective function
+   * takes a low value when the model is able to reconstruct the data well
+   * using weights which are low in value and when the average activations of
+   * the neurons in the hidden layer agree well with the sparsity parameter
+   * 'rho'.
+   *
+   * @param parameters Current values of the model parameters.
+   */
+  double Evaluate(const arma::mat& parameters) const;
+  
+  /**
+   * Evaluates the gradient values of the parameters given the current set of
+   * parameters. The function performs a feedforward pass and computes the error 
+   * in reconstructing the data points. It then uses the backpropagation 
+   * algorithm to compute the gradient values.
+   *
+   * @param parameters Current values of the model parameters.
+   * @param gradient Matrix where the gradient values are stored.
+   */
+  void Gradient(const arma::mat& parameters, arma::mat& gradient) const;
+  
+  /**
+   * Returns the elementwise sigmoid of the passed matrix, where the sigmoid
+   * function of a real number 'x' is [1 / (1 + exp(-x))].
+   * 
+   * @param x Matrix of real values for which we require the sigmoid activation.
+   */
+  arma::mat Sigmoid(const arma::mat& x) const
+  {
+    return (1.0 / (1 + arma::exp(-x)));
+  }
+  
+  //! Return the initial point for the optimization.
+  const arma::mat& GetInitialPoint() const { return initialPoint; }
+
+  //! Sets size of the visible layer.
+  void VisibleSize(const size_t visible)
+  {
+    this->visibleSize = visible;
+  }
+  
+  //! Gets size of the visible layer.
+  size_t VisibleSize() const
+  {
+    return visibleSize;
+  }
+
+  //! Sets size of the hidden layer.
+  void HiddenSize(const size_t hidden)
+  {
+    this->hiddenSize = hidden;
+  }
+
+  //! Gets the size of the hidden layer.  
+  size_t HiddenSize() const
+  {
+    return hiddenSize;
+  }
+  
+  //! Sets the L2-regularization parameter.
+  void Lambda(const double l)
+  {
+    this->lambda = l;
+  }
+  
+  //! Gets the L2-regularization parameter.
+  double Lambda() const
+  {
+    return lambda;
+  }
+  
+  //! Sets the KL divergence parameter.
+  void Beta(const double b)
+  {
+    this->beta = b;
+  }
+  
+  //! Gets the KL divergence parameter.
+  double Beta() const
+  {
+    return beta;
+  }
+  
+  //! Sets the sparsity parameter.
+  void Rho(const double r)
+  {
+    this->rho = r;
+  }
+  
+  //! Gets the sparsity parameter.
+  double Rho() const
+  {
+    return rho;
+  }
+
+ private:
+ 
+  //! The matrix of data points.
+  const arma::mat& data;
+  //! Initial parameter matrix.
+  arma::mat initialPoint;
+  //! Size of the visible layer.
+  size_t visibleSize;
+  //! Size of the hidden layer.
+  size_t hiddenSize;
+  //! L2-regularization parameter.
+  double lambda;
+  //! KL divergence parameter.
+  double beta;
+  //! Sparsity parameter.
+  double rho;
+};
+
+}; // namespace nn
+}; // namespace mlpack
+
+#endif
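
The function class can also be handed straight to the optimizer when only the
raw parameter matrix is needed; a minimal sketch (lbfgs.hpp included as in
sparse_autoencoder.hpp, 'data' already loaded, hypothetical sizes):

  mlpack::nn::SparseAutoencoderFunction saf(data, 64, 25);
  mlpack::optimization::L_BFGS<mlpack::nn::SparseAutoencoderFunction>
      lbfgs(saf, 5, 100);

  arma::mat parameters = saf.GetInitialPoint();
  const double objective = lbfgs.Optimize(parameters);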

Added: mlpack/trunk/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_impl.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/sparse_autoencoder/sparse_autoencoder_impl.hpp	Wed Apr 16 14:30:27 2014
@@ -0,0 +1,70 @@
+#ifndef __MLPACK_METHODS_SPARSE_AUTOENCODER_SPARSE_AUTOENCODER_IMPL_HPP
+#define __MLPACK_METHODS_SPARSE_AUTOENCODER_SPARSE_AUTOENCODER_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "sparse_autoencoder.hpp"
+
+namespace mlpack {
+namespace nn {
+
+template<template<typename> class OptimizerType>
+SparseAutoencoder<OptimizerType>::SparseAutoencoder(const arma::mat& data,
+                                                    const size_t visibleSize,
+                                                    const size_t hiddenSize,
+                                                    const double lambda,
+                                                    const double beta,
+                                                    const double rho) :
+    visibleSize(visibleSize),
+    hiddenSize(hiddenSize),
+    lambda(lambda),
+    beta(beta),
+    rho(rho)
+{
+  SparseAutoencoderFunction encoderFunction(data, visibleSize, hiddenSize,
+                                            lambda, beta, rho);
+  OptimizerType<SparseAutoencoderFunction> optimizer(encoderFunction);
+  
+  parameters = encoderFunction.GetInitialPoint();
+  
+  // Train the model.
+  Timer::Start("sparse_autoencoder_optimization");
+  const double out = optimizer.Optimize(parameters);
+  Timer::Stop("sparse_autoencoder_optimization");
+
+  Log::Info << "SparseAutoencoder::SparseAutoencoder(): final objective of "
+      << "trained model is " << out << "." << std::endl;
+}
+
+template<template<typename> class OptimizerType>
+SparseAutoencoder<OptimizerType>::SparseAutoencoder(
+    OptimizerType<SparseAutoencoderFunction>& optimizer) :
+    parameters(optimizer.Function().GetInitialPoint()),
+    visibleSize(optimizer.Function().VisibleSize()),
+    hiddenSize(optimizer.Function().HiddenSize()),
+    lambda(optimizer.Function().Lambda()),
+    beta(optimizer.Function().Beta()),
+    rho(optimizer.Function().Rho())
+{
+  Timer::Start("sparse_autoencoder_optimization");
+  const double out = optimizer.Optimize(parameters);
+  Timer::Stop("sparse_autoencoder_optimization");
+
+  Log::Info << "SparseAutoencoder::SparseAutoencoder(): final objective of "
+      << "trained model is " << out << "." << std::endl;
+}
+
+template<template<typename> class OptimizerType>
+void SparseAutoencoder<OptimizerType>::GetNewFeatures(arma::mat& data,
+                                                      arma::mat& features)
+{
+  const size_t l1 = hiddenSize;
+  const size_t l2 = visibleSize;
+
+  features = Sigmoid(parameters.submat(0, 0, l1-1, l2-1) * data +
+      arma::repmat(parameters.submat(0, l2, l1-1, l2), 1, data.n_cols));
+}
+
+}; // namespace nn
+}; // namespace mlpack
+
+#endif
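
Since the documentation in sparse_autoencoder.hpp mentions that sparse
autoencoders can be stacked, here is a sketch of how GetNewFeatures() makes
that possible (sizes are hypothetical; 'data' is assumed to be loaded):

  using namespace mlpack::nn;

  // First layer: learn a 25-dimensional representation of the raw data.
  SparseAutoencoder<> encoder1(data, 64, 25);
  arma::mat features1;
  encoder1.GetNewFeatures(data, features1);

  // Second layer: train on the first layer's hidden representation.
  SparseAutoencoder<> encoder2(features1, 25, 10);
  arma::mat features2;
  encoder2.GetNewFeatures(features1, features2);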


