[mlpack-git] master: Adjust the rnn class and LSTM layer to handle sequences of different lengths. (757c92a)

gitdub at big.cc.gt.atl.ga.us
Sat Mar 7 08:10:13 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/3e4a3c8ba42e113e0cdebd73bbfa1f6dea9d7010...757c92a1596ef28f5bc924fbec031fb24b98c781

>---------------------------------------------------------------

commit 757c92a1596ef28f5bc924fbec031fb24b98c781
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Sat Mar 7 14:10:06 2015 +0100

    Adjust the rnn class and LSTM layer to handle sequences of different lengths.


>---------------------------------------------------------------

757c92a1596ef28f5bc924fbec031fb24b98c781
 src/mlpack/core/util/arma_traits.hpp               |  0
 src/mlpack/methods/ann/layer/bias_layer.hpp        |  5 +-
 .../ann/layer/binary_classification_layer.hpp      |  6 +--
 src/mlpack/methods/ann/layer/layer_traits.hpp      |  5 ++
 src/mlpack/methods/ann/layer/lstm_layer.hpp        | 56 +++++++++++++++++++++-
 src/mlpack/methods/ann/rnn.hpp                     | 34 ++++++++++++-
 6 files changed, 96 insertions(+), 10 deletions(-)
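
In short: LayerTraits gains an IsLSTMLayer flag, RNN::Reset() uses that flag
to push the current sequence length into every LSTM layer through the new
SeqLen() accessor, and LSTMLayer::FeedForward() grows its gate buffers on
demand whenever a sequence is longer than any seen before. A minimal,
self-contained sketch of the grow-on-demand idea in plain Armadillo
(hypothetical names, not the actual mlpack members):

    #include <armadillo>

    int main()
    {
      const arma::uword layerSize = 8;
      arma::mat inGate;  // Gate activations, one column per time step.

      // Sequences of length 3, 7 and 5: the buffer only ever grows.
      for (const arma::uword seqLen : { 3, 7, 5 })
      {
        if (inGate.n_cols < seqLen)
          inGate = arma::zeros<arma::mat>(layerSize, seqLen);

        // ... feed the seqLen time steps through the layer here ...
      }

      return 0;
    }

Shorter sequences simply reuse the existing columns, so the common case does
no reallocation at all.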

diff --git a/src/mlpack/methods/ann/layer/bias_layer.hpp b/src/mlpack/methods/ann/layer/bias_layer.hpp
index dde95dd..d982b65 100644
--- a/src/mlpack/methods/ann/layer/bias_layer.hpp
+++ b/src/mlpack/methods/ann/layer/bias_layer.hpp
@@ -110,8 +110,8 @@ class BiasLayer
 }; // class BiasLayer
 
 //! Layer traits for the bias layer.
-template<>
-class LayerTraits<BiasLayer<> >
+template<typename ActivationFunction, typename MatType, typename VecType>
+class LayerTraits<BiasLayer<ActivationFunction, MatType, VecType> >
 {
  public:
   /**
@@ -120,6 +120,7 @@ class LayerTraits<BiasLayer<> >
   static const bool IsBinary = false;
   static const bool IsOutputLayer = false;
   static const bool IsBiasLayer = true;
+  static const bool IsLSTMLayer = false;
 };
 
 }; // namespace ann
diff --git a/src/mlpack/methods/ann/layer/binary_classification_layer.hpp b/src/mlpack/methods/ann/layer/binary_classification_layer.hpp
index 109895d..ecd6064 100644
--- a/src/mlpack/methods/ann/layer/binary_classification_layer.hpp
+++ b/src/mlpack/methods/ann/layer/binary_classification_layer.hpp
@@ -66,16 +66,14 @@ class BinaryClassificationLayer
 }; // class BinaryClassificationLayer
 
 //! Layer traits for the binary classification layer.
-template <
-    typename MatType,
-    typename VecType
->
+template <typename MatType, typename VecType>
 class LayerTraits<BinaryClassificationLayer<MatType, VecType> >
 {
  public:
   static const bool IsBinary = true;
   static const bool IsOutputLayer = true;
   static const bool IsBiasLayer = false;
+  static const bool IsLSTMLayer = false;
 };
 
 }; // namespace ann
diff --git a/src/mlpack/methods/ann/layer/layer_traits.hpp b/src/mlpack/methods/ann/layer/layer_traits.hpp
index b414b05..52ee1af 100644
--- a/src/mlpack/methods/ann/layer/layer_traits.hpp
+++ b/src/mlpack/methods/ann/layer/layer_traits.hpp
@@ -36,6 +36,11 @@ class LayerTraits
    * This is true if the layer is a bias layer.
    */
   static const bool IsBiasLayer = false;
+
+  /**
+   * This is true if the layer is an LSTM layer.
+   */
+  static const bool IsLSTMLayer = false;
 };
 
 }; // namespace ann
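
The bias and binary classification traits above were also widened from a full
specialization of the default instantiation (LayerTraits<BiasLayer<> >) to
partial specializations over all template parameters, so every instantiation
of a layer picks up the correct traits. A standalone sketch of the pattern,
with a hypothetical MyLSTMLayer instead of the real classes:

    #include <iostream>

    // Primary template: conservative defaults, as in layer_traits.hpp.
    template<typename LayerType>
    class LayerTraits
    {
     public:
      static const bool IsLSTMLayer = false;
    };

    // A hypothetical layer template with one type parameter.
    template<typename MatType>
    class MyLSTMLayer { /* ... */ };

    // Partial specialization: matches every instantiation of MyLSTMLayer,
    // not only MyLSTMLayer<> with its default arguments.
    template<typename MatType>
    class LayerTraits<MyLSTMLayer<MatType> >
    {
     public:
      static const bool IsLSTMLayer = true;
    };

    int main()
    {
      std::cout << LayerTraits<MyLSTMLayer<float> >::IsLSTMLayer << "\n";  // 1
      std::cout << LayerTraits<int>::IsLSTMLayer << "\n";                  // 0
      return 0;
    }
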
diff --git a/src/mlpack/methods/ann/layer/lstm_layer.hpp b/src/mlpack/methods/ann/layer/lstm_layer.hpp
index 02d5c42..bc8dfaf 100644
--- a/src/mlpack/methods/ann/layer/lstm_layer.hpp
+++ b/src/mlpack/methods/ann/layer/lstm_layer.hpp
@@ -55,8 +55,8 @@ class LSTMLayer
    * peephole connection matrix.
    */
   LSTMLayer(const size_t layerSize,
-            const size_t seqLen,
-            const bool peepholes = true,
+            const size_t seqLen = 1,
+            const bool peepholes = false,
             WeightInitRule weightInitRule = WeightInitRule()) :
       inputActivations(arma::zeros<VecType>(layerSize * 4)),
       layerSize(layerSize),
@@ -120,6 +120,22 @@ class LSTMLayer
    */
   void FeedForward(const VecType& inputActivation, VecType& outputActivation)
   {
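+    // Grow the gate buffers on demand when a longer sequence arrives.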
+    if (inGate.n_cols < seqLen)
+    {
+      inGate = arma::zeros<MatType>(layerSize, seqLen);
+      inGateAct = arma::zeros<MatType>(layerSize, seqLen);
+      inGateError = arma::zeros<MatType>(layerSize, seqLen);
+      outGate = arma::zeros<MatType>(layerSize, seqLen);
+      outGateAct = arma::zeros<MatType>(layerSize, seqLen);
+      outGateError = arma::zeros<MatType>(layerSize, seqLen);
+      forgetGate = arma::zeros<MatType>(layerSize, seqLen);
+      forgetGateAct = arma::zeros<MatType>(layerSize, seqLen);
+      forgetGateError = arma::zeros<MatType>(layerSize, seqLen);
+      state = arma::zeros<MatType>(layerSize, seqLen);
+      stateError = arma::zeros<MatType>(layerSize, seqLen);
+      cellAct = arma::zeros<MatType>(layerSize, seqLen);
+    }
+
     // Split up the input activation into the 3 parts (inGate, forgetGate,
     // outGate).
     inGate.col(offset) = inputActivation.subvec(0, layerSize - 1);
@@ -296,6 +312,12 @@ class LSTMLayer
   VecType& Delta() const { return delta; }
   //! Modify the delta.
   VecType& Delta() { return delta; }
+
+  //! Get the sequence length.
+  size_t SeqLen() const { return seqLen; }
+  //! Modify the sequence length.
+  size_t& SeqLen() { return seqLen; }
+
  private:
   //! Locally-stored input activation object.
   VecType inputActivations;
@@ -388,6 +410,36 @@ class LSTMLayer
   std::auto_ptr<OptimizerType> outGatePeepholeOptimizer;
 }; // class LSTMLayer
 
+//! Layer traits for the LSTM layer.
+template<
+    class GateActivationFunction,
+    class StateActivationFunction,
+    class OutputActivationFunction,
+    class WeightInitRule,
+    typename OptimizerType,
+    typename MatType,
+    typename VecType
+>
+class LayerTraits<
+    LSTMLayer<GateActivationFunction,
+    StateActivationFunction,
+    OutputActivationFunction,
+    WeightInitRule,
+    OptimizerType,
+    MatType,
+    VecType>
+>
+{
+ public:
+  /**
+   * If true, then the layer is binary.
+   */
+  static const bool IsBinary = false;
+  static const bool IsOutputLayer = false;
+  static const bool IsBiasLayer = false;
+  static const bool IsLSTMLayer = true;
+};
+
 }; // namespace ann
 }; // namespace mlpack
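
For intuition: each gate matrix in the layer holds one column per time step,
and an internal offset tracks the current step, so FeedForward() can reach
the previous step's state one column to the left. A tiny sketch of that
column-per-step indexing (hypothetical names, not the actual member layout):

    #include <armadillo>

    int main()
    {
      const arma::uword layerSize = 4, seqLen = 3;
      arma::mat state = arma::zeros<arma::mat>(layerSize, seqLen);
      arma::uword offset = 0;

      for (arma::uword t = 0; t < seqLen; ++t)
      {
        const arma::vec input = arma::randu<arma::vec>(layerSize);

        // Column `offset` holds the current time step; the previous step's
        // state sits one column to the left.
        if (offset > 0)
          state.col(offset) = state.col(offset - 1) + input;
        else
          state.col(offset) = input;

        offset = (offset + 1) % seqLen;  // Wrap around for the next sequence.
      }

      state.print("state");
      return 0;
    }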
 
diff --git a/src/mlpack/methods/ann/rnn.hpp b/src/mlpack/methods/ann/rnn.hpp
index a8dcc6d..7d4c16f 100644
--- a/src/mlpack/methods/ann/rnn.hpp
+++ b/src/mlpack/methods/ann/rnn.hpp
@@ -284,6 +284,9 @@ class RNN
     typename std::enable_if<I < sizeof...(Tp), void>::type
     Reset(std::tuple<Tp...>& t)
     {
+      Parameter<I, typename std::remove_reference<
+          decltype(std::get<I>(t).InputLayer())>::type, Tp...>(t);
+
       std::get<I>(t).OutputLayer().InputActivation().zeros(
           std::get<I>(t).OutputLayer().InputSize());
 
@@ -301,6 +304,31 @@ class RNN
     }
 
     /**
+     * Do nothing for layers that do not have a SeqLen() function.
+     *
+     * enable_if (SFINAE) is used to determine whether the given layer type
+     * provides the SeqLen() function.
+     */
+    template<size_t I, typename LayerType, typename... Tp>
+    typename std::enable_if<
+        LayerTraits<LayerType>::IsLSTMLayer == false, void>::type
+    Parameter(std::tuple<Tp...>& /* unused */) { }
+
+    /**
+     * Update the sequence length for a specific layer.
+     *
+     * enable_if (SFINAE) is used to determine whether the given layer type
+     * provides the SeqLen() function.
+     */
+    template<size_t I, typename LayerType, typename... Tp>
+    typename std::enable_if<
+        LayerTraits<LayerType>::IsLSTMLayer == true, void>::type
+    Parameter(std::tuple<Tp...>& t)
+    {
+      std::get<I>(t).InputLayer().SeqLen() = seqLen;
+    }
+
+    /**
      * Run a single iteration of the feed forward algorithm, using the given
      * input and target vector, updating the resulting error into the error
      * vector.
@@ -448,7 +476,9 @@ class RNN
 
       // Update the recurrent delta.
       if (ConnectionTraits<typename std::remove_reference<decltype(
-          std::get<I>(t))>::type>::IsSelfConnection)
+          std::get<I>(t))>::type>::IsSelfConnection ||
+          ConnectionTraits<typename std::remove_reference<decltype(
+          std::get<I>(t))>::type>::IsFullselfConnection)
       {
         std::get<I>(t).FeedBackward(delta[deltaNum]);
         delta[deltaNum++] = std::get<I>(t).Delta();
@@ -464,7 +494,7 @@ class RNN
       {
         // Sum up the stored delta for recurrent connections.
         if (recurrentLayer[layer])
-          std::get<I>(t).Delta() += delta[deltaNum];
+          std::get<I>(t).Delta() += delta[deltaNum].subvec(
+              0, std::get<I>(t).InputLayer().OutputSize() - 1);
 
         // Perform the backward pass.
         std::get<I>(t).InputLayer().FeedBackward(


