[mlpack-git] master: Add implementation of the LSTMLayer class. (48b6278)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:09:49 EST 2015

Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40


commit 48b62783bed8f79db644392bfe1e6af3048e2287
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Fri Jan 2 18:17:27 2015 +0100

    Add implementation of the LSTMLayer class.


 src/mlpack/methods/ann/layer/lstm_layer.hpp | 394 ++++++++++++++++++++++++++++
 1 file changed, 394 insertions(+)

diff --git a/src/mlpack/methods/ann/layer/lstm_layer.hpp b/src/mlpack/methods/ann/layer/lstm_layer.hpp
new file mode 100644
index 0000000..02d5c42
--- /dev/null
+++ b/src/mlpack/methods/ann/layer/lstm_layer.hpp
@@ -0,0 +1,394 @@
+/**
+ * @file lstm_layer.hpp
+ * @author Marcus Edel
+ *
+ * Definition of the LSTMLayer class, which implements an LSTM network
+ * layer.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/ann/layer/layer_traits.hpp>
+#include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
+#include <mlpack/methods/ann/activation_functions/tanh_function.hpp>
+#include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
+#include <mlpack/methods/ann/optimizer/rpropp.hpp>
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+/**
+ * An implementation of an LSTM network layer.
+ *
+ * This class allows specification of the type of the activation functions used
+ * for the gates, the cell input and the cell output, and also of the rule used
+ * to initialize the peephole weights and the optimizer used to update them.
+ *
+ * The layer processes one time step per FeedForward() call and stores the
+ * per-step gate values as columns of matrices with seqLen columns.
+ * FeedBackward() visits those columns in reverse order. NOTE(review): the code
+ * assumes seqLen FeedForward() calls are followed by seqLen FeedBackward()
+ * calls; this protocol is not enforced here -- confirm against the caller.
+ *
+ * @tparam GateActivationFunction Activation function used for the gates.
+ * @tparam StateActivationFunction Activation function applied to the cell
+ *     input.
+ * @tparam OutputActivationFunction Activation function applied to the cell
+ *     state to produce the output.
+ * @tparam WeightInitRule Rule used to initialize the peephole weights.
+ * @tparam OptimizerType Optimizer used to update the peephole weights.
+ * @tparam MatType Type of data (arma::mat or arma::sp_mat).
+ * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
+ */
+template <
+    class GateActivationFunction = LogisticFunction,
+    class StateActivationFunction = TanhFunction,
+    class OutputActivationFunction = TanhFunction,
+    class WeightInitRule = NguyenWidrowInitialization<>,
+    typename OptimizerType = RPROPp<>,
+    typename MatType = arma::mat,
+    typename VecType = arma::colvec
+>
+class LSTMLayer
+{
+ public:
+  /**
+   * Create the LSTMLayer object using the specified parameters.
+   *
+   * @param layerSize The number of memory cells.
+   * @param seqLen The length of the input sequence.
+   * @param peepholes The flag used to indicate if peephole connections should
+   *     be used (Default: true).
+   * @param weightInitRule The weight initialize rule used to initialize the
+   *     peephole connection matrices.
+   */
+  LSTMLayer(const size_t layerSize,
+            const size_t seqLen,
+            const bool peepholes = true,
+            WeightInitRule weightInitRule = WeightInitRule()) :
+      inputActivations(arma::zeros<VecType>(layerSize * 4)),
+      layerSize(layerSize),
+      seqLen(seqLen),
+      inGate(arma::zeros<MatType>(layerSize, seqLen)),
+      inGateAct(arma::zeros<MatType>(layerSize, seqLen)),
+      inGateError(arma::zeros<MatType>(layerSize, seqLen)),
+      outGate(arma::zeros<MatType>(layerSize, seqLen)),
+      outGateAct(arma::zeros<MatType>(layerSize, seqLen)),
+      outGateError(arma::zeros<MatType>(layerSize, seqLen)),
+      forgetGate(arma::zeros<MatType>(layerSize, seqLen)),
+      forgetGateAct(arma::zeros<MatType>(layerSize, seqLen)),
+      forgetGateError(arma::zeros<MatType>(layerSize, seqLen)),
+      state(arma::zeros<MatType>(layerSize, seqLen)),
+      stateError(arma::zeros<MatType>(layerSize, seqLen)),
+      cellAct(arma::zeros<MatType>(layerSize, seqLen)),
+      offset(0),
+      peepholes(peepholes)
+  {
+    if (peepholes)
+    {
+      // Each gate gets its own layerSize x 1 peephole weights, derivative
+      // accumulator, gradient and optimizer instance.
+      weightInitRule.Initialize(inGatePeepholeWeights, layerSize, 1);
+      inGatePeepholeDerivatives = arma::zeros<VecType>(layerSize);
+      inGatePeepholeGradient = arma::zeros<MatType>(layerSize, 1);
+      inGatePeepholeOptimizer = std::auto_ptr<OptimizerType>(
+          new OptimizerType(1, 2));
+      weightInitRule.Initialize(forgetGatePeepholeWeights, layerSize, 1);
+      forgetGatePeepholeDerivatives = arma::zeros<VecType>(layerSize);
+      forgetGatePeepholeGradient = arma::zeros<MatType>(layerSize, 1);
+      forgetGatePeepholeOptimizer = std::auto_ptr<OptimizerType>(
+          new OptimizerType(1, 2));
+      weightInitRule.Initialize(outGatePeepholeWeights, layerSize, 1);
+      outGatePeepholeDerivatives = arma::zeros<VecType>(layerSize);
+      outGatePeepholeGradient = arma::zeros<MatType>(layerSize, 1);
+      outGatePeepholeOptimizer = std::auto_ptr<OptimizerType>(
+          new OptimizerType(1, 2));
+    }
+  }
+  // No user-defined destructor: the optimizer objects are owned by the
+  // std::auto_ptr members and are deleted automatically.
+  /**
+   * Ordinary feed forward pass of a neural network, evaluating the function
+   * f(x) by propagating the activity forward through f. Each call processes
+   * one time step and then advances the internal sequence offset.
+   *
+   * @param inputActivation Input data used for evaluating the specified
+   *     activity function. It holds layerSize * 4 values, laid out as
+   *     [inGate | forgetGate | cell input | outGate].
+   * @param outputActivation Datatype to store the resulting output activation.
+   */
+  void FeedForward(const VecType& inputActivation, VecType& outputActivation)
+  {
+    // Split up the input activation into its gate parts (inGate, forgetGate,
+    // outGate); the cell input part (third block) is used further below.
+    inGate.col(offset) = inputActivation.subvec(0, layerSize - 1);
+    forgetGate.col(offset) = inputActivation.subvec(
+        layerSize, (layerSize * 2) - 1);
+    outGate.col(offset) = inputActivation.subvec(
+        layerSize * 3, (layerSize * 4) - 1);
+    // The input and forget gates peek at the state of the previous time step;
+    // the state of the current step has not been computed at this point.
+    if (peepholes && offset > 0)
+    {
+      inGate.col(offset) += inGatePeepholeWeights % state.col(offset - 1);
+      forgetGate.col(offset) += forgetGatePeepholeWeights %
+          state.col(offset - 1);
+    }
+    // unsafe_col() returns vectors that use the memory of the stored matrices,
+    // so the activation functions write straight into the current column.
+    VecType inGateActivation = inGateAct.unsafe_col(offset);
+    GateActivationFunction::fn(inGate.unsafe_col(offset), inGateActivation);
+    VecType forgetGateActivation = forgetGateAct.unsafe_col(offset);
+    GateActivationFunction::fn(forgetGate.unsafe_col(offset),
+        forgetGateActivation);
+    VecType cellActivation = cellAct.unsafe_col(offset);
+    StateActivationFunction::fn(inputActivation.subvec(layerSize * 2,
+        (layerSize * 3) - 1), cellActivation);
+    state.col(offset) = inGateAct.col(offset) % cellActivation;
+    if (offset > 0)
+      state.col(offset) += forgetGateAct.col(offset) % state.col(offset - 1);
+    // The output gate peeks at the freshly computed state of this step.
+    if (peepholes)
+      outGate.col(offset) += outGatePeepholeWeights % state.col(offset);
+    VecType outGateActivation = outGateAct.unsafe_col(offset);
+    GateActivationFunction::fn(outGate.unsafe_col(offset), outGateActivation);
+    OutputActivationFunction::fn(state.unsafe_col(offset), outputActivation);
+    outputActivation = outGateAct.col(offset) % outputActivation;
+    offset = (offset + 1) % seqLen;
+  }
+  /**
+   * Ordinary feed backward pass of a neural network, calculating the function
+   * f(x) by propagating x backwards through f, using the results from the feed
+   * forward pass. Time steps are visited in reverse order. After the last step
+   * of a sequence, the peephole weights are updated.
+   *
+   * @param inputActivation Input data used for calculating the function f(x)
+   *     (unused).
+   * @param error The backpropagated error.
+   * @param delta The calculated delta (layerSize * 4 values, laid out as
+   *     [inGate | forgetGate | cell input | outGate]).
+   */
+  void FeedBackward(const VecType& /* unused */,
+                    const VecType& error,
+                    VecType& delta)
+  {
+    size_t queryOffset = seqLen - offset - 1;
+    VecType outGateDerivative;
+    GateActivationFunction::deriv(outGateAct.unsafe_col(queryOffset),
+        outGateDerivative);
+    // Squash the state with the same function the forward pass applied to it
+    // (OutputActivationFunction).
+    VecType stateActivation;
+    OutputActivationFunction::fn(state.unsafe_col(queryOffset),
+        stateActivation);
+    outGateError.col(queryOffset) = outGateDerivative % error % stateActivation;
+    VecType stateDerivative;
+    OutputActivationFunction::deriv(stateActivation, stateDerivative);
+    stateError.col(queryOffset) = error % outGateAct.col(queryOffset) %
+        stateDerivative;
+    // Error flowing back from the next time step.
+    if (queryOffset < (seqLen - 1))
+    {
+      stateError.col(queryOffset) += stateError.col(queryOffset + 1) %
+          forgetGateAct.col(queryOffset + 1);
+      if (peepholes)
+      {
+        stateError.col(queryOffset) += inGateError.col(queryOffset + 1) %
+            inGatePeepholeWeights;
+        stateError.col(queryOffset) += forgetGateError.col(queryOffset + 1) %
+            forgetGatePeepholeWeights;
+      }
+    }
+    if (peepholes)
+    {
+      stateError.col(queryOffset) += outGateError.col(queryOffset) %
+          outGatePeepholeWeights;
+    }
+    VecType cellDerivative;
+    StateActivationFunction::deriv(cellAct.col(queryOffset), cellDerivative);
+    VecType cellError = inGateAct.col(queryOffset) % cellDerivative %
+        stateError.col(queryOffset);
+    // The first step has no previous state, so its forget gate error stays 0.
+    if (queryOffset > 0)
+    {
+      VecType forgetGateDerivative;
+      GateActivationFunction::deriv(forgetGateAct.col(queryOffset),
+          forgetGateDerivative);
+      forgetGateError.col(queryOffset) = forgetGateDerivative %
+          stateError.col(queryOffset) % state.col(queryOffset - 1);
+    }
+    VecType inGateDerivative;
+    GateActivationFunction::deriv(inGateAct.col(queryOffset), inGateDerivative);
+    inGateError.col(queryOffset) = inGateDerivative %
+        stateError.col(queryOffset) % cellAct.col(queryOffset);
+    if (peepholes)
+    {
+      // Accumulate the peephole derivatives over the sequence. This mirrors
+      // the forward pass: the output gate peeks at the current state, and the
+      // input/forget gates peek at the previous one.
+      outGatePeepholeDerivatives += outGateError.col(queryOffset) %
+          state.col(queryOffset);
+      if (queryOffset > 0)
+      {
+        inGatePeepholeDerivatives += inGateError.col(queryOffset) %
+            state.col(queryOffset - 1);
+        forgetGatePeepholeDerivatives += forgetGateError.col(queryOffset) %
+            state.col(queryOffset - 1);
+      }
+    }
+    delta = arma::zeros<VecType>(layerSize * 4);
+    delta.subvec(0, layerSize - 1) = inGateError.col(queryOffset);
+    delta.subvec(layerSize, (layerSize * 2) - 1) =
+        forgetGateError.col(queryOffset);
+    delta.subvec(layerSize * 2, (layerSize * 3) - 1) = cellError;
+    delta.subvec(layerSize * 3, (layerSize * 4) - 1) =
+        outGateError.col(queryOffset);
+    offset = (offset + 1) % seqLen;
+    // The whole sequence has been processed: update the peephole weights with
+    // the accumulated derivatives, then reset the accumulators.
+    if (peepholes && offset == 0)
+    {
+      // NOTE(review): the gradients use the same sign convention as the delta
+      // handed back to the network; confirm that this matches OptimizerType.
+      inGatePeepholeGradient = inGatePeepholeDerivatives;
+      forgetGatePeepholeGradient = forgetGatePeepholeDerivatives;
+      outGatePeepholeGradient = outGatePeepholeDerivatives;
+      inGatePeepholeOptimizer->UpdateWeights(inGatePeepholeWeights,
+          inGatePeepholeGradient, 0);
+      forgetGatePeepholeOptimizer->UpdateWeights(forgetGatePeepholeWeights,
+          forgetGatePeepholeGradient, 0);
+      outGatePeepholeOptimizer->UpdateWeights(outGatePeepholeWeights,
+          outGatePeepholeGradient, 0);
+      inGatePeepholeGradient.zeros();
+      forgetGatePeepholeGradient.zeros();
+      outGatePeepholeGradient.zeros();
+      inGatePeepholeDerivatives.zeros();
+      forgetGatePeepholeDerivatives.zeros();
+      outGatePeepholeDerivatives.zeros();
+    }
+  }
+  //! Get the input activations.
+  const VecType& InputActivation() const { return inputActivations; }
+  //! Modify the input activations.
+  VecType& InputActivation() { return inputActivations; }
+  //! Get input size.
+  size_t InputSize() const { return layerSize * 4; }
+  //! Get output size.
+  size_t OutputSize() const { return layerSize; }
+  //! Modify the output size (NOTE: does not resize the stored matrices).
+  size_t& OutputSize() { return layerSize; }
+  //! Get the delta.
+  const VecType& Delta() const { return delta; }
+  //! Modify the delta.
+  VecType& Delta() { return delta; }
+ private:
+  //! Locally-stored input activation object.
+  VecType inputActivations;
+  //! Locally-stored delta object.
+  VecType delta;
+  //! Locally-stored number of memory cells.
+  size_t layerSize;
+  //! Locally-stored length of the input sequence.
+  size_t seqLen;
+  //! Locally-stored ingate object (one column per time step).
+  MatType inGate;
+  //! Locally-stored ingate activation object.
+  MatType inGateAct;
+  //! Locally-stored ingate error object.
+  MatType inGateError;
+  //! Locally-stored outgate object.
+  MatType outGate;
+  //! Locally-stored outgate activation object.
+  MatType outGateAct;
+  //! Locally-stored outgate error object.
+  MatType outGateError;
+  //! Locally-stored forget object.
+  MatType forgetGate;
+  //! Locally-stored forget activation object.
+  MatType forgetGateAct;
+  //! Locally-stored forget error object.
+  MatType forgetGateError;
+  //! Locally-stored state object.
+  MatType state;
+  //! Locally-stored state error object.
+  MatType stateError;
+  //! Locally-stored cell activation object.
+  MatType cellAct;
+  //! Locally-stored sequence offset (index of the current time step).
+  size_t offset;
+  //! Locally-stored peephole indication flag.
+  const bool peepholes;
+  //! Locally-stored peephole ingate weights.
+  MatType inGatePeepholeWeights;
+  //! Locally-stored peephole ingate derivatives.
+  VecType inGatePeepholeDerivatives;
+  //! Locally-stored peephole ingate gradients.
+  MatType inGatePeepholeGradient;
+  //! Locally-stored ingate peephole optimizer object (null without peepholes).
+  //! NOTE(review): std::auto_ptr is deprecated in C++11 and removed in C++17;
+  //! switch to std::unique_ptr once the code base requires C++11.
+  std::auto_ptr<OptimizerType> inGatePeepholeOptimizer;
+  //! Locally-stored peephole forget weights.
+  MatType forgetGatePeepholeWeights;
+  //! Locally-stored peephole forget derivatives.
+  VecType forgetGatePeepholeDerivatives;
+  //! Locally-stored peephole forget gradients.
+  MatType forgetGatePeepholeGradient;
+  //! Locally-stored forget peephole optimizer object (null without peepholes).
+  std::auto_ptr<OptimizerType> forgetGatePeepholeOptimizer;
+  //! Locally-stored peephole outgate weights.
+  MatType outGatePeepholeWeights;
+  //! Locally-stored peephole outgate derivatives.
+  VecType outGatePeepholeDerivatives;
+  //! Locally-stored peephole outgate gradients.
+  MatType outGatePeepholeGradient;
+  //! Locally-stored outgate peephole optimizer object (null without peepholes).
+  std::auto_ptr<OptimizerType> outGatePeepholeOptimizer;
+}; // class LSTMLayer
+}; // namespace ann
+}; // namespace mlpack

More information about the mlpack-git mailing list