[mlpack-git] master: Refactor LSTM layer for new network API. (d77572e)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Nov 13 12:45:57 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4e83dc9cc4dcdc315d2cceee32b23ebab114c2...7388de71d5398103ee3a0b32b4026902a40a67b3

>---------------------------------------------------------------

commit d77572e5845a3643cc54b6c1a94efad18e144d86
Author: marcus <marcus.edel at fu-berlin.de>
Date:   Mon Nov 9 17:42:15 2015 +0100

    Refactor LSTM layer for new network API.


>---------------------------------------------------------------

d77572e5845a3643cc54b6c1a94efad18e144d86
 src/mlpack/methods/ann/layer/lstm_layer.hpp | 449 ++++++++++++----------------
 1 file changed, 194 insertions(+), 255 deletions(-)

diff --git a/src/mlpack/methods/ann/layer/lstm_layer.hpp b/src/mlpack/methods/ann/layer/lstm_layer.hpp
index b452ddc..327976b 100644
--- a/src/mlpack/methods/ann/layer/lstm_layer.hpp
+++ b/src/mlpack/methods/ann/layer/lstm_layer.hpp
@@ -10,10 +10,8 @@
 
 #include <mlpack/core.hpp>
 #include <mlpack/methods/ann/layer/layer_traits.hpp>
-#include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
-#include <mlpack/methods/ann/activation_functions/tanh_function.hpp>
-#include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
-#include <mlpack/methods/ann/optimizer/irpropp.hpp>
+#include <mlpack/methods/ann/init_rules/random_init.hpp>
+#include <mlpack/methods/ann/optimizer/rmsprop.hpp>
 
 namespace mlpack {
 namespace ann /** Artificial Neural Network. */ {
@@ -25,21 +23,27 @@ namespace ann /** Artificial Neural Network. */ {
  * for the gates and cells and also of the type of the function used to
  * initialize and update the peephole weights.
  *
+ * @tparam OptimizerType Type of the optimizer used to update the weights.
  * @tparam GateActivationFunction Activation function used for the gates.
  * @tparam StateActivationFunction Activation function used for the state.
  * @tparam OutputActivationFunction Activation function used for the output.
  * @tparam WeightInitRule Rule used to initialize the weight matrix.
- * @tparam MatType Type of data (arma::mat or arma::sp_mat).
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
+ * @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
+ *         arma::sp_mat or arma::cube).
+ * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
+ *         arma::sp_mat or arma::cube).
+ * @tparam PeepholeDataType Type of the peephole data (weights, derivatives and
+ *         gradients).
  */
 template <
+    template<typename, typename> class OptimizerType = mlpack::ann::RMSPROP,
     class GateActivationFunction = LogisticFunction,
     class StateActivationFunction = TanhFunction,
     class OutputActivationFunction = TanhFunction,
-    class WeightInitRule = NguyenWidrowInitialization,
-    typename OptimizerType = iRPROPp<>,
-    typename MatType = arma::mat,
-    typename VecType = arma::colvec
+    class WeightInitRule = RandomInitialization,
+    typename InputDataType = arma::mat,
+    typename OutputDataType = arma::mat,
+    typename PeepholeDataType = arma::cube
 >
 class LSTMLayer
 {
@@ -47,117 +51,96 @@ class LSTMLayer
   /**
    * Create the LSTMLayer object using the specified parameters.
    *
-   * @param layerSize The number of memory cells.
-   * @param layerSize The length of the input sequence.
+   * @param outSize The number of output units.
    * @param peepholes The flag used to indicate if peephole connections should
-   * be used (Default: true).
-   * @param WeightInitRule The weight initialize rule used to initialize the
-   * peephole connection matrix.
+   *        be used (Default: false).
+   * @param WeightInitRule The weight initialization rule used to initialize the
+   *        weight matrix.
    */
-  LSTMLayer(const size_t layerSize,
-            const size_t seqLen = 1,
+  LSTMLayer(const size_t outSize,
             const bool peepholes = false,
             WeightInitRule weightInitRule = WeightInitRule()) :
-      inputActivations(arma::zeros<VecType>(layerSize * 4)),
-      layerSize(layerSize),
-      seqLen(seqLen),
-      inGate(arma::zeros<MatType>(layerSize, seqLen)),
-      inGateAct(arma::zeros<MatType>(layerSize, seqLen)),
-      inGateError(arma::zeros<MatType>(layerSize, seqLen)),
-      outGate(arma::zeros<MatType>(layerSize, seqLen)),
-      outGateAct(arma::zeros<MatType>(layerSize, seqLen)),
-      outGateError(arma::zeros<MatType>(layerSize, seqLen)),
-      forgetGate(arma::zeros<MatType>(layerSize, seqLen)),
-      forgetGateAct(arma::zeros<MatType>(layerSize, seqLen)),
-      forgetGateError(arma::zeros<MatType>(layerSize, seqLen)),
-      state(arma::zeros<MatType>(layerSize, seqLen)),
-      stateError(arma::zeros<MatType>(layerSize, seqLen)),
-      cellAct(arma::zeros<MatType>(layerSize, seqLen)),
+      outSize(outSize),
+      peepholes(peepholes),
+      seqLen(1),
       offset(0),
-      peepholes(peepholes)
+      optimizer(new OptimizerType<LSTMLayer<OptimizerType,
+                                            GateActivationFunction,
+                                            StateActivationFunction,
+                                            OutputActivationFunction,
+                                            WeightInitRule,
+                                            InputDataType,
+                                            OutputDataType,
+                                            PeepholeDataType>,
+                                            PeepholeDataType>(*this)),
+      ownsOptimizer(true)
   {
     if (peepholes)
     {
-      weightInitRule.Initialize(inGatePeepholeWeights, layerSize, 1);
-      inGatePeepholeDerivatives = arma::zeros<VecType>(layerSize);
-      inGatePeepholeOptimizer = std::unique_ptr<OptimizerType>(
-      new OptimizerType(1, layerSize));
-
-      weightInitRule.Initialize(forgetGatePeepholeWeights, layerSize, 1);
-      forgetGatePeepholeDerivatives = arma::zeros<VecType>(layerSize);
-      forgetGatePeepholeOptimizer = std::unique_ptr<OptimizerType>(
-      new OptimizerType(1, layerSize));
-
-      weightInitRule.Initialize(outGatePeepholeWeights, layerSize, 1);
-      outGatePeepholeDerivatives = arma::zeros<VecType>(layerSize);
-      outGatePeepholeOptimizer = std::unique_ptr<OptimizerType>(
-      new OptimizerType(1, layerSize));
+      weightInitRule.Initialize(peepholeWeights, outSize, 1, 3);
+      peepholeDerivatives = PeepholeDataType(outSize, 1, 3);
+      peepholeGradient = PeepholeDataType(outSize, 1, 3);
     }
   }
 
+  /**
+   * Delete the LSTMLayer object and its optimizer.
+   */
   ~LSTMLayer()
   {
-    OptimizerType* inGatePeepholePtr = inGatePeepholeOptimizer.release();
-    delete inGatePeepholePtr;
-
-    OptimizerType* forgetGatePeepholePtr = forgetGatePeepholeOptimizer.release();
-    delete forgetGatePeepholePtr;
-
-    OptimizerType* outGatePeepholePtr = outGatePeepholeOptimizer.release();
-    delete outGatePeepholePtr;
+    if (ownsOptimizer)
+      delete optimizer;
   }
 
   /**
    * Ordinary feed forward pass of a neural network, evaluating the function
    * f(x) by propagating the activity forward through f.
    *
-   * @param inputActivation Input data used for evaluating the specified
-   * activity function.
-   * @param outputActivation Datatype to store the resulting output activation.
+   * @param input Input data used for evaluating the specified function.
+   * @param output Resulting output activation.
    */
-  void FeedForward(const VecType& inputActivation, VecType& outputActivation)
+  template<typename eT>
+  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output)
   {
     if (inGate.n_cols < seqLen)
     {
-      inGate = arma::zeros<MatType>(layerSize, seqLen);
-      inGateAct = arma::zeros<MatType>(layerSize, seqLen);
-      inGateError = arma::zeros<MatType>(layerSize, seqLen);
-      outGate = arma::zeros<MatType>(layerSize, seqLen);
-      outGateAct = arma::zeros<MatType>(layerSize, seqLen);
-      outGateError = arma::zeros<MatType>(layerSize, seqLen);
-      forgetGate = arma::zeros<MatType>(layerSize, seqLen);
-      forgetGateAct = arma::zeros<MatType>(layerSize, seqLen);
-      forgetGateError = arma::zeros<MatType>(layerSize, seqLen);
-      state = arma::zeros<MatType>(layerSize, seqLen);
-      stateError = arma::zeros<MatType>(layerSize, seqLen);
-      cellAct = arma::zeros<MatType>(layerSize, seqLen);
+      inGate = arma::zeros<InputDataType>(outSize, seqLen);
+      inGateAct = arma::zeros<InputDataType>(outSize, seqLen);
+      inGateError = arma::zeros<InputDataType>(outSize, seqLen);
+      outGate = arma::zeros<InputDataType>(outSize, seqLen);
+      outGateAct = arma::zeros<InputDataType>(outSize, seqLen);
+      outGateError = arma::zeros<InputDataType>(outSize, seqLen);
+      forgetGate = arma::zeros<InputDataType>(outSize, seqLen);
+      forgetGateAct = arma::zeros<InputDataType>(outSize, seqLen);
+      forgetGateError = arma::zeros<InputDataType>(outSize, seqLen);
+      state = arma::zeros<InputDataType>(outSize, seqLen);
+      stateError = arma::zeros<InputDataType>(outSize, seqLen);
+      cellAct = arma::zeros<InputDataType>(outSize, seqLen);
     }
 
     // Split up the inputactivation into the 3 parts (inGate, forgetGate,
     // outGate).
-    inGate.col(offset) = inputActivation.subvec(0, layerSize - 1);
-    forgetGate.col(offset) = inputActivation.subvec(
-        layerSize, (layerSize * 2) - 1);
-    outGate.col(offset) = inputActivation.subvec(
-        layerSize * 3, (layerSize * 4) - 1);
+    inGate.col(offset) = input.submat(0, 0, outSize - 1, 0);
+    forgetGate.col(offset) = input.submat(outSize, 0, (outSize * 2) - 1, 0);
+    outGate.col(offset) = input.submat(outSize * 3, 0, (outSize * 4) - 1, 0);
 
     if (peepholes && offset > 0)
     {
-      inGate.col(offset) += inGatePeepholeWeights % state.col(offset - 1);
-      forgetGate.col(offset) += forgetGatePeepholeWeights %
+      inGate.col(offset) += peepholeWeights.slice(0) % state.col(offset - 1);
+      forgetGate.col(offset) += peepholeWeights.slice(1) %
           state.col(offset - 1);
     }
 
-    VecType inGateActivation = inGateAct.unsafe_col(offset);
+    arma::Col<eT> inGateActivation = inGateAct.unsafe_col(offset);
     GateActivationFunction::fn(inGate.unsafe_col(offset), inGateActivation);
 
-    VecType forgetGateActivation = forgetGateAct.unsafe_col(offset);
+    arma::Col<eT> forgetGateActivation = forgetGateAct.unsafe_col(offset);
     GateActivationFunction::fn(forgetGate.unsafe_col(offset),
         forgetGateActivation);
 
-    VecType cellActivation = cellAct.unsafe_col(offset);
-    StateActivationFunction::fn(inputActivation.subvec(layerSize * 2,
-        (layerSize * 3) - 1), cellActivation);
+    arma::Col<eT> cellActivation = cellAct.unsafe_col(offset);
+    StateActivationFunction::fn(input.submat(outSize * 2, 0,
+        (outSize * 3) - 1, 0), cellActivation);
 
     state.col(offset) = inGateAct.col(offset) % cellActivation;
 
@@ -165,13 +148,13 @@ class LSTMLayer
       state.col(offset) += forgetGateAct.col(offset) % state.col(offset - 1);
 
     if (peepholes)
-      outGate.col(offset) += outGatePeepholeWeights % state.col(offset);
+      outGate.col(offset) += peepholeWeights.slice(2) % state.col(offset);
 
-    VecType outGateActivation = outGateAct.unsafe_col(offset);
+    arma::Col<eT> outGateActivation = outGateAct.unsafe_col(offset);
     GateActivationFunction::fn(outGate.unsafe_col(offset), outGateActivation);
 
-    OutputActivationFunction::fn(state.unsafe_col(offset), outputActivation);
-    outputActivation = outGateAct.col(offset) % outputActivation;
+    OutputActivationFunction::fn(state.unsafe_col(offset), output);
+    output = outGateAct.col(offset) % output;
 
     offset = (offset + 1) % seqLen;
   }
@@ -181,30 +164,30 @@ class LSTMLayer
    * f(x) by propagating x backwards trough f. Using the results from the feed
    * forward pass.
    *
-   * @param inputActivation Input data used for calculating the function f(x).
-   * @param error The backpropagated error.
-   * @param delta The calculating delta using the partial derivative of the
-   * error with respect to a weight.
+   * @param input The propagated input activation.
+   * @param gy The backpropagated error.
+   * @param g The calculated gradient.
    */
-  void FeedBackward(const VecType& /* unused */,
-                    const VecType& error,
-                    VecType& delta)
+  template<typename InputType, typename eT>
+  void Backward(const InputType& /* unused */,
+                const arma::Mat<eT>& gy,
+                arma::Mat<eT>& g)
   {
     size_t queryOffset = seqLen - offset - 1;
 
-    VecType outGateDerivative;
+    arma::Col<eT> outGateDerivative;
     GateActivationFunction::deriv(outGateAct.unsafe_col(queryOffset),
         outGateDerivative);
 
-    VecType stateActivation;
+    arma::Col<eT> stateActivation;
     StateActivationFunction::fn(state.unsafe_col(queryOffset), stateActivation);
 
-    outGateError.col(queryOffset) = outGateDerivative % error % stateActivation;
+    outGateError.col(queryOffset) = outGateDerivative % gy % stateActivation;
 
-    VecType stateDerivative;
+    arma::Col<eT> stateDerivative;
     StateActivationFunction::deriv(stateActivation, stateDerivative);
 
-    stateError.col(queryOffset) = error % outGateAct.col(queryOffset) %
+    stateError.col(queryOffset) = gy % outGateAct.col(queryOffset) %
         stateDerivative;
 
     if (queryOffset < (seqLen - 1))
@@ -215,27 +198,27 @@ class LSTMLayer
       if (peepholes)
       {
         stateError.col(queryOffset) += inGateError.col(queryOffset + 1) %
-            inGatePeepholeWeights;
+            peepholeWeights.slice(0);
         stateError.col(queryOffset) += forgetGateError.col(queryOffset + 1) %
-            forgetGatePeepholeWeights;
+            peepholeWeights.slice(1);
       }
     }
 
     if (peepholes)
     {
       stateError.col(queryOffset) += outGateError.col(queryOffset) %
-          outGatePeepholeWeights;
+          peepholeWeights.slice(2);
     }
 
-    VecType cellDerivative;
+    arma::Col<eT> cellDerivative;
     StateActivationFunction::deriv(cellAct.col(queryOffset), cellDerivative);
 
-    VecType cellError = inGateAct.col(queryOffset) % cellDerivative %
+    arma::Col<eT> cellError = inGateAct.col(queryOffset) % cellDerivative %
         stateError.col(queryOffset);
 
     if (queryOffset > 0)
     {
-      VecType forgetGateDerivative;
+      arma::Col<eT> forgetGateDerivative;
       GateActivationFunction::deriv(forgetGateAct.col(queryOffset),
           forgetGateDerivative);
 
@@ -243,7 +226,7 @@ class LSTMLayer
           stateError.col(queryOffset) % state.col(queryOffset - 1);
     }
 
-    VecType inGateDerivative;
+    arma::Col<eT> inGateDerivative;
     GateActivationFunction::deriv(inGateAct.col(queryOffset), inGateDerivative);
 
     inGateError.col(queryOffset) = inGateDerivative %
@@ -251,232 +234,188 @@ class LSTMLayer
 
     if (peepholes)
     {
-      outGateDerivative += outGateError.col(queryOffset) %
+      peepholeDerivatives.slice(2) += outGateError.col(queryOffset) %
           state.col(queryOffset);
+
       if (queryOffset > 0)
       {
-        inGatePeepholeDerivatives += inGateError.col(queryOffset) %
+        peepholeDerivatives.slice(0) += inGateError.col(queryOffset) %
             state.col(queryOffset - 1);
-        forgetGatePeepholeDerivatives += forgetGateError.col(queryOffset) %
+        peepholeDerivatives.slice(1) += forgetGateError.col(queryOffset) %
             state.col(queryOffset - 1);
       }
     }
 
-    delta = arma::zeros<VecType>(layerSize * 4);
-    delta.subvec(0, layerSize - 1) = inGateError.col(queryOffset);
-    delta.subvec(layerSize, (layerSize * 2) - 1) =
+    g = arma::zeros<arma::Mat<eT> >(outSize * 4, 1);
+    g.submat(0, 0, outSize - 1, 0) = inGateError.col(queryOffset);
+    g.submat(outSize, 0, (outSize * 2) - 1, 0) =
         forgetGateError.col(queryOffset);
-    delta.subvec(layerSize * 2, (layerSize * 3) - 1) = cellError;
-    delta.subvec(layerSize * 3, (layerSize * 4) - 1) =
+    g.submat(outSize * 2, 0, (outSize * 3) - 1, 0) = cellError;
+    g.submat(outSize * 3, 0, (outSize * 4) - 1, 0) =
         outGateError.col(queryOffset);
 
     offset = (offset + 1) % seqLen;
 
     if (peepholes && offset == 0)
     {
-      inGatePeepholeGradient = (inGatePeepholeWeights.t() *
-          (inGateError.col(queryOffset) % inGatePeepholeDerivatives)) *
-          inGate.col(queryOffset).t();
-
-      forgetGatePeepholeGradient = (forgetGatePeepholeWeights.t() *
-          (forgetGateError.col(queryOffset) % forgetGatePeepholeDerivatives)) *
-          forgetGate.col(queryOffset).t();
-
-      outGatePeepholeGradient = (outGatePeepholeWeights.t() *
-          (outGateError.col(queryOffset) % outGatePeepholeDerivatives)) *
-          outGate.col(queryOffset).t();
-
-      inGatePeepholeOptimizer->UpdateWeights(inGatePeepholeWeights,
-          inGatePeepholeGradient.t(), 0);
-
-      forgetGatePeepholeOptimizer->UpdateWeights(forgetGatePeepholeWeights,
-          forgetGatePeepholeGradient.t(), 0);
-
-      outGatePeepholeOptimizer->UpdateWeights(outGatePeepholeWeights,
-          outGatePeepholeGradient.t(), 0);
-
-      inGatePeepholeDerivatives.zeros();
-      forgetGatePeepholeDerivatives.zeros();
-      outGatePeepholeDerivatives.zeros();
+      peepholeGradient.slice(0) = arma::trans((peepholeWeights.slice(0).t() *
+          (inGateError.col(queryOffset) % peepholeDerivatives.slice(0))) *
+          inGate.col(queryOffset).t());
+
+      peepholeGradient.slice(1) = arma::trans((peepholeWeights.slice(1).t() *
+          (forgetGateError.col(queryOffset) % peepholeDerivatives.slice(1))) *
+          forgetGate.col(queryOffset).t());
+
+      peepholeGradient.slice(2) = arma::trans((peepholeWeights.slice(2).t() *
+          (outGateError.col(queryOffset) % peepholeDerivatives.slice(2))) *
+          outGate.col(queryOffset).t());
+
+      optimizer->Update();
+      optimizer->Optimize();
+      optimizer->Reset();
+      peepholeDerivatives.zeros();
     }
   }
 
-  //! Get the input activations.
-  const VecType& InputActivation() const { return inputActivations; }
-  //! Modify the input activations.
-  VecType& InputActivation() { return inputActivations; }
-
-  //! Get input size.
-  size_t InputSize() const { return layerSize * 4; }
+  //! Get the peephole weights.
+  PeepholeDataType& Weights() const { return peepholeWeights; }
+  //! Modify the peephole weights.
+  PeepholeDataType& Weights() { return peepholeWeights; }
 
-  //! Get output size.
-  size_t OutputSize() const { return layerSize; }
-  //! Modify the output size.
-  size_t& OutputSize() { return layerSize; }
+  //! Get the input parameter.
+  InputDataType& InputParameter() const {return inputParameter; }
+  //! Modify the input parameter.
+  InputDataType& InputParameter() { return inputParameter; }
 
-  //! Get the number of output maps.
-  size_t OutputMaps() const { return 1; }
+  //! Get the output parameter.
+  OutputDataType& OutputParameter() const {return outputParameter; }
+  //! Modify the output parameter.
+  OutputDataType& OutputParameter() { return outputParameter; }
 
-  //! Get the number of layer slices.
-  size_t LayerSlices() const { return 1; }
-
-  //! Get the number of layer rows.
-  size_t LayerRows() const { return layerSize; }
-
-  //! Get the number of layer columns.
-  size_t LayerCols() const { return 1; }
-
-  //! Get the detla.
-  VecType& Delta() const { return delta; }
+  //! Get the delta.
+  OutputDataType& Delta() const {return delta; }
   //! Modify the delta.
-  VecType& Delta() { return delta; }
+  OutputDataType& Delta() { return delta; }
+
+  //! Get the peephole gradient.
+  PeepholeDataType& Gradient() const {return peepholeGradient; }
+  //! Modify the peephole gradient.
+  PeepholeDataType& Gradient() { return peepholeGradient; }
 
   //! Get the sequence length.
   size_t SeqLen() const { return seqLen; }
   //! Modify the sequence length.
   size_t& SeqLen() { return seqLen; }
 
-  //! Get the InGate peephole weights..
-  MatType& InGatePeepholeWeights() const { return inGatePeepholeWeights; }
-  //! Modify the InGate peephole weights..
-  MatType& InGatePeepholeWeights() { return inGatePeepholeWeights; }
+ private:
+  //! Locally-stored number of output units.
+  const size_t outSize;
 
-  //! Get the InGate peephole weights..
-  MatType& ForgetGatePeepholeWeights() const {
-    return forgetGatePeepholeWeights; }
-  //! Modify the InGate peephole weights..
-  MatType& ForgetGatePeepholeWeights() { return forgetGatePeepholeWeights; }
+  //! Locally-stored peephole indication flag.
+  const bool peepholes;
 
-  //! Get the InGate peephole weights..
-  MatType& OutGatePeepholeWeights() const { return outGatePeepholeWeights; }
-  //! Modify the InGate peephole weights..
-  MatType& OutGatePeepholeWeights() { return outGatePeepholeWeights; }
+  //! Locally-stored length of the the input sequence.
+  size_t seqLen;
 
-  //! The the value of the deterministic parameter.
-  bool Deterministic() const {return deterministic; }
-  //! Modify the value of the deterministic parameter.
-  bool& Deterministic() {return deterministic; }
+  //! Locally-stored sequence offset.
+  size_t offset;
 
- private:
-  //! Locally-stored input activation object.
-  VecType inputActivations;
+  //! Locally-stored pointer to the optimzer object.
+  OptimizerType<LSTMLayer<OptimizerType,
+                          GateActivationFunction,
+                          StateActivationFunction,
+                          OutputActivationFunction,
+                          WeightInitRule,
+                          InputDataType,
+                          OutputDataType,
+                          PeepholeDataType>, PeepholeDataType>* optimizer;
+
+  //! Parameter that indicates if the class owns a optimizer object.
+  bool ownsOptimizer;
 
   //! Locally-stored delta object.
-  VecType delta;
+  OutputDataType delta;
 
-  //! Locally-stored number of memory cells.
-  size_t layerSize;
+  //! Locally-stored gradient object.
+  OutputDataType gradient;
 
-  //! Locally-stored length of the the input sequence.
-  size_t seqLen;
+  //! Locally-stored input parameter object.
+  InputDataType inputParameter;
+
+  //! Locally-stored output parameter object.
+  OutputDataType outputParameter;
 
   //! Locally-stored ingate object.
-  MatType inGate;
+  InputDataType inGate;
 
   //! Locally-stored ingate activation object.
-  MatType inGateAct;
+  InputDataType inGateAct;
 
   //! Locally-stored ingate error object.
-  MatType inGateError;
+  InputDataType inGateError;
 
   //! Locally-stored outgate object.
-  MatType outGate;
+  InputDataType outGate;
 
   //! Locally-stored outgate activation object.
-  MatType outGateAct;
+  InputDataType outGateAct;
 
   //! Locally-stored outgate error object.
-  MatType outGateError;
+  InputDataType outGateError;
 
   //! Locally-stored forget object.
-  MatType forgetGate;
+  InputDataType forgetGate;
 
   //! Locally-stored forget activation object.
-  MatType forgetGateAct;
+  InputDataType forgetGateAct;
 
   //! Locally-stored forget error object.
-  MatType forgetGateError;
+  InputDataType forgetGateError;
 
   //! Locally-stored state object.
-  MatType state;
+  InputDataType state;
 
   //! Locally-stored state erro object.
-  MatType stateError;
+  InputDataType stateError;
 
   //! Locally-stored cell activation object.
-  MatType cellAct;
-
-  //! Locally-stored sequence offset.
-  size_t offset;
-
-  //! Locally-stored peephole indication flag.
-  const bool peepholes;
+  InputDataType cellAct;
 
-  //! Locally-stored peephole ingate weights.
-  MatType inGatePeepholeWeights;
+  //! Locally-stored peephole weight object.
+  PeepholeDataType peepholeWeights;
 
-  //! Locally-stored peephole ingate derivatives.
-  VecType inGatePeepholeDerivatives;
+  //! Locally-stored derivatives object.
+  PeepholeDataType peepholeDerivatives;
 
-  //! Locally-stored peephole ingate gradients.
-  MatType inGatePeepholeGradient;
-
-  //! Locally-stored ingate peephole optimzer object.
-  std::unique_ptr<OptimizerType> inGatePeepholeOptimizer;
-
-  //! Locally-stored peephole forget weights.
-  MatType forgetGatePeepholeWeights;
-
-  //! Locally-stored peephole forget derivatives.
-  VecType forgetGatePeepholeDerivatives;
-
-  //! Locally-stored peephole forget gradients.
-  MatType forgetGatePeepholeGradient;
-
-  //! Locally-stored forget peephole optimzer object.
-  std::unique_ptr<OptimizerType> forgetGatePeepholeOptimizer;
-
-  //! Locally-stored peephole outgate weights.
-  MatType outGatePeepholeWeights;
-
-  //! Locally-stored peephole outgate derivatives.
-  VecType outGatePeepholeDerivatives;
-
-  //! Locally-stored peephole outgate gradients.
-  MatType outGatePeepholeGradient;
-
-  //! Locally-stored outgate peephole optimzer object.
-  std::unique_ptr<OptimizerType> outGatePeepholeOptimizer;
-
-  //! Locally-stored deterministic parameter.
-  bool deterministic;
+  //! Locally-stored peephole gradient object.
+  PeepholeDataType peepholeGradient;
 }; // class LSTMLayer
 
-//! Layer traits for the bias layer.
+//! Layer traits for the lstm layer.
 template<
+    template<typename, typename> class OptimizerType,
     class GateActivationFunction,
     class StateActivationFunction,
     class OutputActivationFunction,
     class WeightInitRule,
-    typename OptimizerType,
-    typename MatType,
-    typename VecType
->
-class LayerTraits<
-    LSTMLayer<GateActivationFunction,
-    StateActivationFunction,
-    OutputActivationFunction,
-    WeightInitRule,
-    OptimizerType,
-    MatType,
-    VecType>
+    typename InputDataType,
+    typename OutputDataType,
+    typename PeepholeDataType
 >
+class LayerTraits<LSTMLayer<OptimizerType,
+                            GateActivationFunction,
+                            StateActivationFunction,
+                            WeightInitRule,
+                            InputDataType,
+                            OutputDataType,
+                            PeepholeDataType> >
 {
  public:
   static const bool IsBinary = false;
   static const bool IsOutputLayer = false;
   static const bool IsBiasLayer = false;
   static const bool IsLSTMLayer = true;
+  static const bool IsConnection = false;
 };
 
 }; // namespace ann



More information about the mlpack-git mailing list