[mlpack-git] master: Properly handle the rnn structure. (a18e71c)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Feb 27 15:51:54 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/594fd9f61d1280152c758559b4fc60bf0c827cca...45f682337b1daa4c82797f950e16a605fe4971bd

>---------------------------------------------------------------

commit a18e71c3c60e7fd0f6a31f77d734067870390362
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Fri Feb 27 21:11:59 2015 +0100

    Properly handle the rnn structure.


>---------------------------------------------------------------

a18e71c3c60e7fd0f6a31f77d734067870390362
 .../methods/ann/connections/self_connection.hpp    |  27 +-
 src/mlpack/methods/ann/rnn.hpp                     | 434 +++++++++++++++------
 2 files changed, 338 insertions(+), 123 deletions(-)

diff --git a/src/mlpack/methods/ann/connections/self_connection.hpp b/src/mlpack/methods/ann/connections/self_connection.hpp
index fb3e226..3b67a8a 100644
--- a/src/mlpack/methods/ann/connections/self_connection.hpp
+++ b/src/mlpack/methods/ann/connections/self_connection.hpp
@@ -18,7 +18,7 @@ namespace ann /** Artificial Neural Network. */ {
 /**
  * Implementation of the self connection class. The self connection connects
  * every neuron from the input layer with the output layer in a multiplicative
- * way.
+ * way, except the elements on the main diagonal.
  *
  * @tparam InputLayerType Type of the connected input layer.
  * @tparam OutputLayerType Type of the connected output layer.
@@ -39,7 +39,7 @@ class SelfConnection
 {
  public:
   /**
-   * Create the FullConnection object using the specified input layer, output
+   * Create the SelfConnection object using the specified input layer, output
    * layer, optimizer and weight initialize rule.
    *
    * @param InputLayerType The input layer which is connected with the output
@@ -54,20 +54,26 @@ class SelfConnection
                  OutputLayerType& outputLayer,
                  OptimizerType& optimizer,
                  WeightInitRule weightInitRule = WeightInitRule()) :
-      inputLayer(inputLayer), outputLayer(outputLayer), optimizer(optimizer)
+      inputLayer(inputLayer),
+      outputLayer(outputLayer),
+      optimizer(optimizer),
+      connection(1 - arma::eye<MatType>(inputLayer.OutputSize(),
+          inputLayer.OutputSize()))
   {
-    weightInitRule.Initialize(weights, outputLayer.OutputSize(), 1);
+    weightInitRule.Initialize(weights, outputLayer.InputSize(),
+        inputLayer.OutputSize());
   }
 
   /**
    * Ordinary feed forward pass of a neural network, evaluating the function
    * f(x) by propagating the activity forward through f.
    *
-   * @param input Input data used for evaluating the specified activity function.
+   * @param input Input data used for evaluating the specified activity
+   * function.
    */
   void FeedForward(const VecType& input)
   {
-    outputLayer.InputActivation() += (weights % input);
+    outputLayer.InputActivation() += (weights % connection) * input;
   }
 
   /**
@@ -79,9 +85,7 @@ class SelfConnection
    */
   void FeedBackward(const VecType& error)
   {
-    // Calculating the delta using the partial derivative of the error with
-    // respect to a weight.
-    delta = (weights.t() * error);
+    delta = (weights % connection).t() * error;
   }
 
   /*
@@ -91,7 +95,7 @@ class SelfConnection
    */
   void Gradient(MatType& gradient)
   {
-    gradient = outputLayer.Delta() % inputLayer.InputActivation();
+    gradient = outputLayer.Delta() * inputLayer.InputActivation().t();
   }
 
   //! Get the weights.
@@ -134,6 +138,9 @@ class SelfConnection
 
   //! Locally-stored detla object that holds the calculated delta.
   VecType delta;
+
+  //! Locally-stored connection mask (ones except on the main diagonal).
+  MatType connection;
 }; // class SelfConnection
 
 //! Connection traits for the self connection.
diff --git a/src/mlpack/methods/ann/rnn.hpp b/src/mlpack/methods/ann/rnn.hpp
index c616aeb..a4a085b 100644
--- a/src/mlpack/methods/ann/rnn.hpp
+++ b/src/mlpack/methods/ann/rnn.hpp
@@ -48,7 +48,8 @@ class RNN
      * @param outputLayer The outputlayer used to evaluate the network.
      */
     RNN(const ConnectionTypes& network, OutputLayerType& outputLayer) :
-        network(network), outputLayer(outputLayer)
+        network(network), err(0),  trainError(0), seqNum(0),
+        outputLayer(outputLayer)
     {
       // Nothing to do here.
     }
@@ -71,31 +72,66 @@ class RNN
     {
       // Initialize the activation storage only once.
       if (!activations.size())
-        InitLayer(network, input, target);
+      {
+        InitLayer(network, input);
+      }
+      else
+      {
+        // Expand the activation storage to handle sequences of
+        // different length.
+        if (activations[0].n_cols < input.n_elem)
+        {
+          for (size_t i = 0; i < activations.size(); i++)
+          {
+            activations[i].insert_cols(activations[i].n_cols,
+                arma::zeros<MatType>(activations[i].n_rows,
+                input.n_elem - activations[i].n_cols));
+          }
+        }
+      }
 
-      // Reset the overall error.
-      err = 0;
-      error = MatType(target.n_elem, input.n_rows);
+      seqLen = input.n_rows / inputSize;
+      seqOutput = outputSize < target.n_elem ? true : false;
+      error = MatType(outputSize, outputSize < target.n_elem ? seqLen : 1);
 
       // Iterate through the input sequence and perform the feed forward pass.
-      for (seqNum = 0; seqNum < input.n_rows; seqNum++)
+      for (seqNum = 0; seqNum < seqLen; seqNum++)
       {
-        // Reset the network by zeroing the layer activations and set the input
-        // activation.
+        // Reset the network by zeroing the layer activations.
         ResetActivations(network);
+
+        // Set the current input activation.
         std::get<0>(std::get<0>(
-            network)).InputLayer().InputActivation() = input(seqNum);
+            network)).InputLayer().InputActivation() = input.submat(
+            seqNum * inputSize, 0, (seqNum + 1) * inputSize - 1, 0);
 
-        arma::colvec seqError = error.unsafe_col(seqNum);
-        FeedForward(network, target, seqError);
+        // Perform the forward pass and calculate the output error.
+        FeedForward(network);
+        if (seqOutput)
+        {
+          arma::colvec seqError = error.unsafe_col(seqNum);
+          arma::colvec seqTarget = target.subvec(seqNum * outputSize,
+              (seqNum + 1) * outputSize - 1);
 
-        // Save the network activation for the backward pass.
-        if (seqNum < (input.n_rows - 1))
+          OutputError(network, seqTarget, seqError);
+        }
+
+        // Save the network activation for the backward/forward pass and update
+        // the recurrent connections.
+        if (seqNum < (input.n_rows / inputSize - 1))
         {
           layerNum = 0;
           SaveActivations(network);
         }
       }
+
+      // Calculate the error only once for a non-sequence input.
+      if (!seqOutput)
+      {
+        seqNum = 0;
+        arma::colvec seqError = error.unsafe_col(seqNum);
+        OutputError(network, target, seqError);
+      }
     }
 
     /**
@@ -106,19 +142,20 @@ class RNN
      */
     void FeedBackward(const MatType& error)
     {
-      // Reset the network gradients by zeroing the storage.
-      for (size_t i = 0; i < gradients.size(); ++i)
-        gradients[i].zeros();
-
       // Reset the network deltas by zeroing the storage.
-      for (size_t i = 0; i < delta.size(); ++i)
-        delta[i].zeros();
+      for (size_t i = 0; i < delta.size(); i++)
+          delta[i].zeros();
 
       // Iterate through the input sequence and perform the feed backward pass.
-      for (seqNum = error.n_cols - 1; seqNum >= 0; seqNum--)
+      for (seqNum = seqLen - 1; seqNum >= 0; seqNum--)
       {
         gradientNum = 0;
-        FeedBackward(network, error.unsafe_col(seqNum));
+        deltaNum = 0;
+
+        // Perform the backward pass and update the gradient storage.
+        arma::colvec seqError = error.unsafe_col(seqOutput ? seqNum : 0);
+        FeedBackward(network, seqError);
+        UpdateGradients(network);
 
         // Load the network activation for the upcoming backward pass.
         if (seqNum > 0)
@@ -126,6 +163,7 @@ class RNN
           layerNum = 0;
           LoadActivations(network);
         }
+        else if (seqNum == 0) break;
       }
     }
 
@@ -137,10 +175,79 @@ class RNN
     {
       gradientNum = 0;
       ApplyGradients(network);
+
+      // Reset the overall error.
+      err = 0;
+      trainError = 0;
+      seqNum = 0;
+    }
+
+    /**
+     * Evaluate the network using the given input. The output activation is
+     * stored into the output parameter.
+     *
+     * @param input Input data used to evaluate the network.
+     * @param output Output data used to store the output activation
+     * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
+     */
+    template <typename VecType>
+    void Predict(const VecType& input, VecType& output)
+    {
+      seqLen = input.n_rows / inputSize;
+
+      // Iterate through the input sequence and perform the feed forward pass.
+      for (seqNum = 0; seqNum < seqLen; seqNum++)
+      {
+        // Reset the network by zeroing the layer activations.
+        ResetActivations(network);
+
+        // Set the current input activation.
+        std::get<0>(std::get<0>(
+            network)).InputLayer().InputActivation() = input.submat(
+            seqNum * inputSize, 0, (seqNum + 1) * inputSize - 1, 0);
+
+        // Perform the forward pass and store the output prediction.
+        FeedForward(network);
+        if (seqOutput)
+        {
+          arma::colvec targetCol;
+          OutputPrediction(network, targetCol);
+          output = arma::join_cols(output, targetCol);
+        }
+
+        // Save the network activation for the backward/forward pass and update
+        // the recurrent connections.
+        if (seqNum < (input.n_rows / inputSize - 1))
+        {
+          layerNum = 0;
+          SaveActivations(network);
+        }
+      }
+
+      if (!seqOutput)
+        OutputPrediction(network, output);
+    }
+
+    /**
+     * Evaluate the trained network using the given input and compare the output
+     * with the given target vector.
+     *
+     * @param input Input data used to evaluate the trained network.
+     * @param target Target data used to calculate the network error.
+     * @param error The calculated error of the output layer.
+     * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
+     */
+    template <typename VecType>
+    double Evaluate(const MatType& input,
+                     const VecType& target,
+                     MatType& error)
+    {
+      FeedForward(input, target, error);
+      return err;
     }
 
     //! Get the error of the network.
-    double Error() const { return err; }
+    double Error() const { return trainError; }
 
   private:
     /**
@@ -179,6 +286,17 @@ class RNN
     {
       std::get<I>(t).OutputLayer().InputActivation().zeros(
           std::get<I>(t).OutputLayer().InputSize());
+
+      // Reset the recurrent connection only at the beginning of a new sequence.
+      if (seqNum == 0 && (ConnectionTraits<typename std::remove_reference<
+          decltype(std::get<I>(t))>::type>::IsSelfConnection ||
+          ConnectionTraits<typename std::remove_reference<decltype(
+          std::get<I>(t))>::type>::IsFullselfConnection))
+      {
+        std::get<I>(t).InputLayer().InputActivation().zeros(
+          std::get<I>(t).InputLayer().InputSize());
+      }
+
       Reset<I + 1, Tp...>(t);
     }
 
@@ -192,38 +310,13 @@ class RNN
      * connections, and one for the general case which peels off the first type
      * and recurses, as usual with variadic function templates.
      */
-    template<size_t I = 0,
-             typename TargetVecType,
-             typename ErrorVecType,
-             typename... Tp>
+    template<size_t I = 0, typename... Tp>
     typename std::enable_if<I == sizeof...(Tp), void>::type
-    FeedForward(std::tuple<Tp...>& t,
-                TargetVecType& target,
-                ErrorVecType& error)
-    {
-      // Calculate and store the output error.
-      outputLayer.calculateError(std::get<0>(
-          std::get<I - 1>(t)).OutputLayer().InputActivation(), target,
-          error);
+    FeedForward(std::tuple<Tp...>& /* unused */) { }
 
-      // Save the output activation for the upcoming feed backward pass.
-      activations.back().unsafe_col(seqNum) = std::get<0>(
-          std::get<I - 1>(t)).OutputLayer().InputActivation();
-
-      // Masures the network's performance with the specified performance
-      // function.
-      err = PerformanceFunction::error(std::get<0>(
-          std::get<I - 1>(t)).OutputLayer().InputActivation(), target);
-    }
-
-    template<size_t I = 0,
-            typename TargetVecType,
-            typename ErrorVecType,
-            typename... Tp>
+    template<size_t I = 0, typename... Tp>
     typename std::enable_if<I < sizeof...(Tp), void>::type
-    FeedForward(std::tuple<Tp...>& t,
-                TargetVecType& target,
-                ErrorVecType& error)
+    FeedForward(std::tuple<Tp...>& t)
     {
       Forward(std::get<I>(t));
 
@@ -232,7 +325,7 @@ class RNN
           std::get<0>(std::get<I>(t)).OutputLayer().InputActivation(),
           std::get<0>(std::get<I>(t)).OutputLayer().InputActivation());
 
-      FeedForward<I + 1, TargetVecType, ErrorVecType, Tp...>(t, target, error);
+      FeedForward<I + 1, Tp...>(t);
     }
 
     /**
@@ -254,6 +347,45 @@ class RNN
       Forward<I + 1, Tp...>(t);
     }
 
+    /*
+     * Calculate the output error and update the overall error.
+     */
+    template<typename VecType, typename... Tp>
+    void OutputError(std::tuple<Tp...>& t,
+                     const VecType& target,
+                     VecType& error)
+    {
+       // Calculate and store the output error.
+      outputLayer.calculateError(std::get<0>(
+          std::get<sizeof...(Tp) - 1>(t)).OutputLayer().InputActivation(),
+          target, error);
+
+      // Save the output activation for the upcoming feed backward pass.
+      activations.back().unsafe_col(seqNum) = std::get<0>(
+          std::get<sizeof...(Tp) - 1>(t)).OutputLayer().InputActivation();
+
+      // Measures the network's performance with the specified performance
+      // function.
+      err = PerformanceFunction::error(std::get<0>(
+          std::get<sizeof...(Tp) - 1>(t)).OutputLayer().InputActivation(),
+          target);
+
+      // Update the overall training error.
+      trainError += err;
+    }
+
+    /*
+     * Calculate and store the output activation.
+     */
+    template<typename VecType, typename... Tp>
+    void OutputPrediction(std::tuple<Tp...>& t, VecType& output)
+    {
+       // Calculate and store the output prediction.
+      outputLayer.outputClass(std::get<0>(
+          std::get<sizeof...(Tp) - 1>(t)).OutputLayer().InputActivation(),
+          output);
+    }
+
     /**
      * Run a single iteration of the feed backward algorithm, using the given
      * error of the output layer. Note that we iterate backward through the
@@ -265,11 +397,11 @@ class RNN
      * and recurses, as usual with variadic function templates.
      */
     template<size_t I = 0, typename VecType, typename... Tp>
-    typename std::enable_if<I == sizeof...(Tp), void>::type
+    typename std::enable_if<I == sizeof...(Tp) + 1, void>::type
     FeedBackward(std::tuple<Tp...>& /* unused */, VecType& /* unused */) { }
 
     template<size_t I = 1, typename VecType, typename... Tp>
-    typename std::enable_if<I < sizeof...(Tp), void>::type
+    typename std::enable_if<I < sizeof...(Tp) + 1, void>::type
     FeedBackward(std::tuple<Tp...>& t, VecType& error)
     {
       // Distinguish between the output layer and the other layer. In case of
@@ -282,20 +414,10 @@ class RNN
         std::get<0>(std::get<sizeof...(Tp) - I>(t)).OutputLayer().FeedBackward(
             activations.back().unsafe_col(seqNum), error,
             std::get<0>(std::get<sizeof...(Tp) - I>(t)).OutputLayer().Delta());
-
-        // Save the delta for the upcoming feed backward pass.
-        delta.back() += std::get<0>(
-            std::get<sizeof...(Tp) - I>(t)).OutputLayer().Delta();
-
-        // Save the gradient to update the weights at the end.
-        gradients.back() += std::get<0>(
-            std::get<sizeof...(Tp) - I>(t)).OutputLayer().Delta() *
-            std::get<0>(
-            std::get<sizeof...(Tp) - I>(t)).InputLayer().InputActivation().t();
       }
 
-      Backward(std::get<sizeof...(Tp) - I>(t), delta[delta.size() - I]);
-      UpdateGradients(std::get<sizeof...(Tp) - I - 1>(t));
+      Backward(std::get<sizeof...(Tp) - I>(t), std::get<0>(std::get<
+          sizeof...(Tp) - I>(t)).OutputLayer().Delta(), I, sizeof...(Tp));
 
       FeedBackward<I + 1, VecType, Tp...>(t, error);
     }
@@ -310,28 +432,71 @@ class RNN
      */
     template<size_t I = 0, typename VecType, typename... Tp>
     typename std::enable_if<I == sizeof...(Tp), void>::type
-    Backward(std::tuple<Tp...>& /* unused */, VecType& /* unused */) { }
+    Backward(std::tuple<Tp...>& /* unused */,
+             VecType& /* unused */,
+             const size_t /* unused */,
+             const size_t /* unused */) { }
 
     template<size_t I = 0, typename VecType, typename... Tp>
     typename std::enable_if<I < sizeof...(Tp), void>::type
-    Backward(std::tuple<Tp...>& t, VecType& error)
+    Backward(std::tuple<Tp...>& t,
+             VecType& error,
+             const size_t layer,
+             const size_t layerNum)
     {
       std::get<I>(t).FeedBackward(error);
 
+      // Update the recurrent delta.
+      if (ConnectionTraits<typename std::remove_reference<decltype(
+          std::get<I>(t))>::type>::IsSelfConnection)
+      {
+        std::get<I>(t).FeedBackward(delta[deltaNum]);
+        delta[deltaNum++] = std::get<I>(t).Delta();
+      }
+
       // We calculate the delta only for non bias layer and self connections.
       if (!(ConnectionTraits<typename std::remove_reference<decltype(
             std::get<I>(t))>::type>::IsSelfConnection ||
         LayerTraits<typename std::remove_reference<decltype(
             std::get<I>(t).InputLayer())>::type>::IsBiasLayer ||
         ConnectionTraits<typename std::remove_reference<decltype(
-            std::get<I>(t))>::type>::IsFullselfConnection))
+            std::get<I>(t))>::type>::IsFullselfConnection) && layer < layerNum)
       {
+        // Sum up the stored delta for recurrent connections.
+        if (recurrentLayer[layer])
+          std::get<I>(t).Delta() += delta[deltaNum];
+
+        // Perform the backward pass.
         std::get<I>(t).InputLayer().FeedBackward(
             std::get<I>(t).InputLayer().InputActivation(),
             std::get<I>(t).Delta(), std::get<I>(t).InputLayer().Delta());
+
+        // Update the delta storage for the next backward pass.
+        if (recurrentLayer[layer])
+          delta[deltaNum] = std::get<I>(t).InputLayer().Delta();
       }
 
-      Backward<I + 1, VecType, Tp...>(t, error);
+      Backward<I + 1, VecType, Tp...>(t, error, layer, layerNum);
+    }
+
+    /**
+     * Helper function to update the gradient storage.
+     *
+     * enable_if (SFINAE) is used to select between two template overloads of
+     * the function - one for when I is equal to the size of the tuple of
+     * connections, and one for the general case which peels off the first type
+     * and recurses, as usual with variadic function templates.
+     */
+    template<size_t I = 0, typename... Tp>
+    typename std::enable_if<I == sizeof...(Tp), void>::type
+    UpdateGradients(std::tuple<Tp...>& /* unused */) { }
+
+    template<size_t I = 0, typename... Tp>
+    typename std::enable_if<I < sizeof...(Tp), void>::type
+    UpdateGradients(std::tuple<Tp...>& t)
+    {
+      Gradients(std::get<I>(t));
+      UpdateGradients<I + 1, Tp...>(t);
     }
 
     /**
@@ -343,16 +508,17 @@ class RNN
      */
     template<size_t I = 0, typename... Tp>
     typename std::enable_if<I == sizeof...(Tp), void>::type
-    UpdateGradients(std::tuple<Tp...>& /* unused */) { }
+    Gradients(std::tuple<Tp...>& /* unused */) { }
 
     template<size_t I = 0, typename... Tp>
     typename std::enable_if<I < sizeof...(Tp), void>::type
-    UpdateGradients(std::tuple<Tp...>& t)
+    Gradients(std::tuple<Tp...>& t)
     {
-      gradients[gradientNum++] += std::get<I>(t).OutputLayer().Delta() *
-          std::get<I>(t).InputLayer().InputActivation().t();
+      MatType gradient;
+      std::get<I>(t).Gradient(gradient);
+      gradients[gradientNum++] += gradient;
 
-      UpdateGradients<I + 1, Tp...>(t);
+      Gradients<I + 1, Tp...>(t);
     }
 
     /**
@@ -365,14 +531,14 @@ class RNN
      * and recurses, as usual with variadic function templates.
      */
     template<size_t I = 0, typename... Tp>
-    typename std::enable_if<I == sizeof...(Tp) - 1, void>::type
+    typename std::enable_if<I == sizeof...(Tp), void>::type
     ApplyGradients(std::tuple<Tp...>& /* unused */) { }
 
     template<size_t I = 0, typename... Tp>
-    typename std::enable_if<I < sizeof...(Tp) - 1, void>::type
+    typename std::enable_if<I < sizeof...(Tp), void>::type
     ApplyGradients(std::tuple<Tp...>& t)
     {
-      Gradients(std::get<I>(t));
+      Apply(std::get<I>(t));
       ApplyGradients<I + 1, Tp...>(t);
     }
 
@@ -387,16 +553,19 @@ class RNN
      */
     template<size_t I = 0, typename... Tp>
     typename std::enable_if<I == sizeof...(Tp), void>::type
-    Gradients(std::tuple<Tp...>& /* unused */) { }
+    Apply(std::tuple<Tp...>& /* unused */) { }
 
     template<size_t I = 0, typename... Tp>
     typename std::enable_if<I < sizeof...(Tp), void>::type
-    Gradients(std::tuple<Tp...>& t)
+    Apply(std::tuple<Tp...>& t)
     {
       std::get<I>(t).Optimzer().UpdateWeights(std::get<I>(t).Weights(),
-          gradients[gradientNum++], err);
+          gradients[gradientNum], trainError);
 
-      Gradients<I + 1, Tp...>(t);
+      // Reset the gradient storage.
+      gradients[gradientNum++].zeros();
+
+      Apply<I + 1, Tp...>(t);
     }
 
     /**
@@ -408,21 +577,49 @@ class RNN
      * connections, and one for the general case which peels off the first type
      * and recurses, as usual with variadic function templates.
      */
-    template<size_t I = 0, typename VecType, typename... Tp>
+    template<size_t I = 0, typename... Tp>
     typename std::enable_if<I == sizeof...(Tp), void>::type
-    InitLayer(std::tuple<Tp...>& /* unused */,
-              const MatType& input,
-              const VecType& target)
+    InitLayer(std::tuple<Tp...>& t, const MatType& input)
     {
-      activations.push_back(new MatType(target.n_elem, input.n_elem));
+      recurrentLayer.push_back(false);
+      outputSize = std::get<0>(std::get<I - 1>(t)).OutputLayer().OutputSize();
+      activations.push_back(new MatType(outputSize, input.n_elem));
     }
 
-    template<size_t I = 0, typename VecType, typename... Tp>
+    template<size_t I = 0, typename... Tp>
     typename std::enable_if<I < sizeof...(Tp), void>::type
-    InitLayer(std::tuple<Tp...>& t, const MatType& input, const VecType& target)
+    InitLayer(std::tuple<Tp...>& t, const MatType& input)
     {
-      Layer(std::get<I>(t), input);
-      InitLayer<I + 1, VecType, Tp...>(t, input, target);
+      if (I == 0)
+        inputSize = std::get<0>(std::get<I>(t)).InputLayer().InputSize();
+
+      recurrentLayer.push_back(false);
+      Recurrent(std::get<sizeof...(Tp) - I - 1>(t));
+
+      Layer(std::get<I>(t), input, I);
+      InitLayer<I + 1, Tp...>(t, input);
+    }
+
+    template<size_t I = 0, typename... Tp>
+    typename std::enable_if<I == sizeof...(Tp), void>::type
+    Recurrent(std::tuple<Tp...>& /* unusded */) { }
+
+    template<size_t I = 0, typename... Tp>
+    typename std::enable_if<I < sizeof...(Tp), void>::type
+    Recurrent(std::tuple<Tp...>& t)
+    {
+      if (ConnectionTraits<typename std::remove_reference<decltype(
+              std::get<I>(t))>::type>::IsSelfConnection ||
+          ConnectionTraits<typename std::remove_reference<decltype(
+            std::get<I>(t))>::type>::IsFullselfConnection)
+      {
+        recurrentLayer.back() = true;
+        delta.push_back(new VecTypeDelta(std::get<I>(t).Weights().n_rows));
+      }
+      else
+      {
+        Recurrent<I + 1, Tp...>(t);
+      }
     }
 
     /**
@@ -436,30 +633,21 @@ class RNN
      */
     template<size_t I = 0, typename VecType, typename... Tp>
     typename std::enable_if<I == sizeof...(Tp), void>::type
-    Layer(std::tuple<Tp...>& /* unused */, const VecType& /* unused */) { }
+    Layer(std::tuple<Tp...>& /* unusded */,
+          const VecType& /* unused */,
+          const size_t /* unsued */) { }
 
     template<size_t I = 0, typename VecType, typename... Tp>
     typename std::enable_if<I < sizeof...(Tp), void>::type
-    Layer(std::tuple<Tp...>& t, const VecType& input)
+    Layer(std::tuple<Tp...>& t, const VecType& input, const size_t layer)
     {
       activations.push_back(new MatType(
         std::get<I>(t).InputLayer().OutputSize(), input.n_elem));
 
       gradients.push_back(new MatType(std::get<I>(t).Weights().n_rows,
-          std::get<I>(t).Weights().n_cols));
+          std::get<I>(t).Weights().n_cols, arma::fill::zeros));
 
-      // We calculate the delta only for non bias layer and self connections.
-      if (!(ConnectionTraits<typename std::remove_reference<decltype(
-              std::get<I>(t))>::type>::IsSelfConnection ||
-          LayerTraits<typename std::remove_reference<decltype(
-              std::get<I>(t).InputLayer())>::type>::IsBiasLayer ||
-          ConnectionTraits<typename std::remove_reference<decltype(
-            std::get<I>(t))>::type>::IsFullselfConnection))
-      {
-        delta.push_back(new VecTypeDelta(std::get<I>(t).Weights().n_rows));
-      }
-
-      Layer<I + 1, VecType, Tp...>(t, input);
+      Layer<I + 1, VecType, Tp...>(t, input, layer);
     }
 
     /**
@@ -554,20 +742,23 @@ class RNN
       Save<I + 1, Tp...>(t);
     }
 
+    //! The layer we are using to build the network.
+    ConnectionTypes network;
+
     //! The current error of the network.
     double err;
 
+    //! The current training error of the network.
+    double trainError;
+
     //! The activation storage we are using to perform the feed backward pass.
     boost::ptr_vector<MatType> activations;
 
     //! The gradient storage we are using to perform the feed backward pass.
     boost::ptr_vector<MatType> gradients;
 
-    //! The detla storage we are using to perform the feed backward pass.
-    boost::ptr_vector<VecTypeDelta> delta;
-
     //! The index of the current sequence number.
-    long int seqNum;
+    size_t seqNum;
 
     //! The index of the currently activate layer.
     size_t layerNum;
@@ -575,11 +766,29 @@ class RNN
     //! The index of the currently activate gradient.
     size_t gradientNum;
 
-    //! The layer we are using to build the network.
-    ConnectionTypes network;
+    //! The index of the currently active delta.
+    size_t deltaNum;
+
+    //! Locally stored network output size.
+    size_t outputSize;
+
+    //! Locally stored network input size.
+    size_t inputSize;
+
+    //! Locally stored parameter that indicates if the target is a sequence.
+    bool seqOutput;
 
     //! The outputlayer used to evaluate the network
     OutputLayerType& outputLayer;
+
+    //! Locally stored number of samples in one input sequence.
+    size_t seqLen;
+
+    //! The recurrentLayer storage we are using to perform the backward pass.
+    std::vector<bool> recurrentLayer;
+
+    //! The delta storage we are using to perform the feed backward pass.
+    boost::ptr_vector<VecTypeDelta> delta;
 }; // class RNN
 
 //! Network traits for the FFNN network.
@@ -599,4 +808,3 @@ class NetworkTraits<RNN<ConnectionTypes, OutputLayerType, PerformanceFunction> >
 }; // namespace mlpack
 
 #endif
-



More information about the mlpack-git mailing list