[mlpack-git] master: Adjust RNN class so that it works with the mlpack optimizers; Add Train() method regarding the design guidelines. (f6c27ed)

gitdub at mlpack.org gitdub at mlpack.org
Fri Feb 19 08:33:35 EST 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f6dd2f7a9752a7db8ec284a938b3e84a13d0bfb2...6205f3e0b62b56452b2a4afc4da24fce5b21e72f

>---------------------------------------------------------------

commit f6c27ed38633f732d1e8dd34e29e4be79b9d9f87
Author: marcus <marcus.edel at fu-berlin.de>
Date:   Fri Feb 19 14:33:35 2016 +0100

    Adjust RNN class so that it works with the mlpack optimizers; Add Train() method regarding the design guidelines.


>---------------------------------------------------------------

f6c27ed38633f732d1e8dd34e29e4be79b9d9f87
 src/mlpack/methods/ann/rnn.hpp                     | 624 ++++++++++-----------
 .../methods/ann/{ffn_impl.hpp => rnn_impl.hpp}     | 167 ++++--
 2 files changed, 405 insertions(+), 386 deletions(-)

diff --git a/src/mlpack/methods/ann/rnn.hpp b/src/mlpack/methods/ann/rnn.hpp
index 7a22763..2c6210c 100644
--- a/src/mlpack/methods/ann/rnn.hpp
+++ b/src/mlpack/methods/ann/rnn.hpp
@@ -2,7 +2,7 @@
  * @file rnn.hpp
  * @author Marcus Edel
  *
- * Definition of the RNN class, which implements feed forward neural networks.
+ * Definition of the RNN class, which implements recurrent neural networks.
  */
 #ifndef __MLPACK_METHODS_ANN_RNN_HPP
 #define __MLPACK_METHODS_ANN_RNN_HPP
@@ -11,122 +11,214 @@
 
 #include <boost/ptr_container/ptr_vector.hpp> 
 
-#include <mlpack/methods/ann/network_traits.hpp>
+#include <mlpack/methods/ann/network_util.hpp>
 #include <mlpack/methods/ann/layer/layer_traits.hpp>
+#include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
 #include <mlpack/methods/ann/performance_functions/cee_function.hpp>
+#include <mlpack/core/optimizers/sgd/sgd.hpp>
 
 namespace mlpack {
 namespace ann /** Artificial Neural Network. */ {
 
 /**
- * An implementation of a standard feed forward network.
+ * Implementation of a standard recurrent neural network.
  *
  * @tparam LayerTypes Contains all layer modules used to construct the network.
  * @tparam OutputLayerType The outputlayer type used to evaluate the network.
- * @tparam PerformanceFunction Performance strategy used to claculate the error.
+ * @tparam InitializationRuleType Rule used to initialize the weight matrix.
+ * @tparam PerformanceFunction Performance strategy used to calculate the error.
  */
 template <
   typename LayerTypes,
   typename OutputLayerType,
-  class PerformanceFunction = CrossEntropyErrorFunction<>,
-  typename MatType = arma::mat
+  typename InitializationRuleType = NguyenWidrowInitialization,
+  class PerformanceFunction = CrossEntropyErrorFunction<>
 >
 class RNN
 {
  public:
+  //! Convenience typedef for the internal model construction.
+  using NetworkType = RNN<LayerTypes,
+                          OutputLayerType,
+                          InitializationRuleType,
+                          PerformanceFunction>;
+
   /**
-   * Construct the RNN object, which will construct a recurrent neural
-   * network with the specified layers.
+   * Create the RNN object with the given predictors and responses set (this is
+   * the set that is used to train the network) and the given optimizer.
+   * Optionally, specify which initialization rule and performance function
+   * should be used.
    *
-   * @param network The network modules used to construct the network.
-   * @param outputLayer The outputlayer used to evaluate the network.
-   * @param performanceFunction Performance strategy used to claculate the error.
+   * @param network Network modules used to construct the network.
+   * @param outputLayer Output layer used to evaluate the network.
+   * @param predictors Input training variables.
+   * @param responses Outputs resulting from input training variables.
+   * @param optimizer Instantiated optimizer used to train the model.
+   * @param initializeRule Optional instantiated InitializationRule object
+   *        for initializing the network parameters.
+   * @param performanceFunction Optional instantiated PerformanceFunction
+   *        object used to calculate the error.
    */
-  RNN(const LayerTypes& network, OutputLayerType& outputLayer,
-      PerformanceFunction performanceFunction = PerformanceFunction()) :
-      network(network),
-      outputLayer(outputLayer),
-      performanceFunction(std::move(performanceFunction)),
-      trainError(0),
-      inputSize(0),
-      outputSize(0)
-  {
-    // Nothing to do here.
-  }
+  template<typename LayerType,
+           typename OutputType,
+           template<typename> class OptimizerType>
+  RNN(LayerType &&network,
+      OutputType &&outputLayer,
+      const arma::mat& predictors,
+      const arma::mat& responses,
+      OptimizerType<NetworkType>& optimizer,
+      InitializationRuleType initializeRule = InitializationRuleType(),
+      PerformanceFunction performanceFunction = PerformanceFunction());
 
   /**
-   * Run a single iteration of the feed forward algorithm, using the given
-   * input and target vector, store the calculated error into the error
-   * parameter.
+   * Create the RNN object with the given predictors and responses set (this is
+   * the set that is used to train the network). Optionally, specify which
+   * initialization rule and performance function should be used.
    *
-   * @param input Input data used to evaluate the network.
-   * @param target Target data used to calculate the network error.
-   * @param error The calulated error of the output layer.
+   * @param network Network modules used to construct the network.
+   * @param outputLayer Output layer used to evaluate the network.
+   * @param predictors Input training variables.
+   * @param responses Outputs resulting from input training variables.
+   * @param initializeRule Optional instantiated InitializationRule object
+   *        for initializing the network parameters.
+   * @param performanceFunction Optional instantiated PerformanceFunction
+   *        object used to calculate the error.
    */
-  template <typename InputType, typename TargetType, typename ErrorType>
-  void FeedForward(const InputType& input,
-                   const TargetType& target,
-                   ErrorType& error)
-  {
-    deterministic = false;
-    trainError += Evaluate(input, target, error);
-  }
+  template<typename LayerType, typename OutputType>
+  RNN(LayerType &&network,
+      OutputType &&outputLayer,
+      const arma::mat& predictors,
+      const arma::mat& responses,
+      InitializationRuleType initializeRule = InitializationRuleType(),
+      PerformanceFunction performanceFunction = PerformanceFunction());
 
   /**
-   * Run a single iteration of the feed backward algorithm, using the given
-   * error of the output layer.
+   * Create the RNN object with an empty predictors and responses set and the
+   * default optimizer. Make sure to call Train(predictors, responses) when
+   * training.
    *
-   * @param error The calulated error of the output layer.
+   * @param network Network modules used to construct the network.
+   * @param outputLayer Output layer used to evaluate the network.
+   * @param initializeRule Optional instantiated InitializationRule object
+   *        for initializing the network parameters.
+   * @param performanceFunction Optional instantiated PerformanceFunction
+   *        object used to calculate the error.
    */
-  template <typename InputType, typename ErrorType>
-  void FeedBackward(const InputType& input, const ErrorType& error)
-  {
-    // Iterate through the input sequence and perform the feed backward pass.
-    for (seqNum = seqLen - 1; seqNum >= 0; seqNum--)
-    {
-      // Load the network activation for the upcoming backward pass.
-        LoadActivations(input.rows(seqNum * inputSize, (seqNum + 1) *
-            inputSize - 1), network);
+  template<typename LayerType, typename OutputType>
+  RNN(LayerType &&network,
+      OutputType &&outputLayer,
+      InitializationRuleType initializeRule = InitializationRuleType(),
+      PerformanceFunction performanceFunction = PerformanceFunction());
 
-      // Perform the backward pass.
-      if (seqOutput)
-      {
-        ErrorType seqError = error.unsafe_col(seqNum);
-        Backward(seqError, network);
-      }
-      else
-      {
-        Backward(error, network);
-      }
+  /**
+   * Train the recurrent neural network on the given input data. By default, the
+   * SGD optimization algorithm is used, but others can be specified
+   * (such as mlpack::optimization::RMSprop).
+   *
+   * This will use the existing model parameters as a starting point for the
+   * optimization. If this is not what you want, then you should access the
+   * parameters vector directly with Parameters() and modify it as desired.
+   *
+   * @tparam OptimizerType Type of optimizer to use to train the model.
+   * @param predictors Input training variables.
+   * @param responses Outputs resulting from input training variables.
+   */
+  template<
+      template<typename> class OptimizerType = mlpack::optimization::SGD
+  >
+  void Train(const arma::mat& predictors, const arma::mat& responses);
 
-      // Link the parameters and update the gradients.
-      LinkParameter(network);
-      UpdateGradients<>(network);
+  /**
+   * Train the recurrent neural network with the given instantiated optimizer.
+   * Using this overload allows configuring the instantiated optimizer before
+   * training is performed.
+   *
+   * This will use the existing model parameters as a starting point for the
+   * optimization. If this is not what you want, then you should access the
+   * parameters vector directly with Parameters() and modify it as desired.
+   *
+   * @param optimizer Instantiated optimizer used to train the model.
+   */
+  template<
+      template<typename> class OptimizerType = mlpack::optimization::SGD
+  >
+  void Train(OptimizerType<NetworkType>& optimizer);
 
-      if (seqNum == 0) break;
-    }
-  }
+  /**
+   * Train the recurrent neural network on the given input data using the given
+   * optimizer.
+   *
+   * This will use the existing model parameters as a starting point for the
+   * optimization. If this is not what you want, then you should access the
+   * parameters vector directly with Parameters() and modify it as desired.
+   *
+   * @tparam OptimizerType Type of optimizer to use to train the model.
+   * @param predictors Input training variables.
+   * @param responses Outputs resulting from input training variables.
+   * @param optimizer Instantiated optimizer used to train the model.
+   */
+  template<
+      template<typename> class OptimizerType = mlpack::optimization::SGD
+  >
+  void Train(const arma::mat& predictors,
+             const arma::mat& responses,
+             OptimizerType<NetworkType>& optimizer);
 
   /**
-   * Update the weights using the layer defined optimizer.
+   * Predict the responses to a given set of predictors. The responses will
+   * reflect the output of the given output layer as returned by the
+   * OutputClass() function.
+   *
+   * @param predictors Input predictors (non-const; TODO confirm unmodified).
+   * @param responses Matrix to store the predicted responses in.
    */
-  void ApplyGradients()
-  {
-    ApplyGradients<>(network);
+  void Predict(arma::mat& predictors, arma::mat& responses);
 
-    // Reset the overall error.
-    trainError = 0;
-  }
+  /**
+   * Evaluate the recurrent neural network with the given parameters. This
+   * function is usually called by the optimizer to train the model.
+   *
+   * @param parameters Matrix of the model parameters.
+   * @param i Index of point to use for objective function evaluation.
+   * @param deterministic Whether to evaluate in testing (true) or training
+   *        (false) mode; note that some layers act differently in each mode.
+   */
+  double Evaluate(const arma::mat& parameters,
+                  const size_t i,
+                  const bool deterministic = false);
 
   /**
-   * Evaluate the network using the given input. The output activation is
-   * stored into the output parameter.
+   * Evaluate the gradient of the recurrent neural network with the given
+   * parameters, and with respect to only one point in the dataset. This is
+   * useful for optimizers such as SGD, which require a separable objective
+   * function.
    *
-   * @param input Input data used to evaluate the network.
-   * @param output Output data used to store the output activation
+   * @param parameters Matrix of the model parameters to be optimized.
+   * @param i Index of the point to use for the gradient evaluation.
+   * @param gradient Matrix to output gradient into.
+   */
+  void Gradient(const arma::mat& parameters,
+                const size_t i,
+                arma::mat& gradient);
+
+  //! Return the number of separable functions (the number of predictor points).
+  size_t NumFunctions() const { return numFunctions; }
+
+  //! Return the initial point for the optimization.
+  const arma::mat& Parameters() const { return parameter; }
+  //! Modify the initial point for the optimization.
+  arma::mat& Parameters() { return parameter; }
+
+  //! Serialize the model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
+ private:
+  /*
+   * Predict the response of the given input matrix.
    */
   template <typename DataType>
-  void Predict(const DataType& input, DataType& output)
+  void SinglePredict(const DataType& input, DataType& output)
   {
     deterministic = true;
     seqLen = input.n_rows / inputSize;
@@ -155,68 +247,8 @@ class RNN
   }
 
   /**
-   * Evaluate the network using the given input and compare the output with the
-   * given target vector.
-   *
-   * @param input Input data used to evaluate the trained network.
-   * @param target Target data used to calculate the network error.
-   * @param error The calulated error of the output layer.
-   */
-  template <typename InputType, typename TargetType, typename ErrorType>
-  double Evaluate(const InputType& input,
-                  const TargetType& target,
-                  ErrorType& error)
-  {
-    // Initialize the activation storage only once.
-    if (activations.empty())
-      InitLayer(input, target, network);
-
-    double networkError = 0;
-    seqLen = input.n_rows / inputSize;
-    deterministic = false;
-    ResetParameter(network);
-
-    error = ErrorType(outputSize, outputSize < target.n_elem ? seqLen : 1);
-
-    // Iterate through the input sequence and perform the feed forward pass.
-    for (seqNum = 0; seqNum < seqLen; seqNum++)
-    {
-      // Perform the forward pass and save the activations.
-      Forward(input.rows(seqNum * inputSize, (seqNum + 1) * inputSize - 1),
-          network);
-      SaveActivations(network);
-
-      // Retrieve output error of the subsequence.
-      if (seqOutput)
-      {
-        arma::mat seqError = error.unsafe_col(seqNum);
-        arma::mat seqTarget = target.submat(seqNum * outputSize, 0,
-            (seqNum + 1) * outputSize - 1, 0);
-        networkError += OutputError(seqTarget, seqError, network);
-      }
-    }
-
-    // Retrieve output error of the complete sequence.
-    if (!seqOutput)
-      return OutputError(target, error, network);
-
-    return networkError;
-  }
-
-  //! Get the error of the network.
-  double Error() const
-  {
-    return trainError;
-  }
-
- private:
-  /**
    * Reset the network by clearing the layer activations and by setting the
    * layer status.
-   *
-   * enable_if (SFINAE) is used to iterate through the network. The general
-   * case peels off the first type and recurses, as usual with
-   * variadic function templates.
    */
   template<size_t I = 0, typename... Tp>
   typename std::enable_if<I == sizeof...(Tp), void>::type
@@ -227,14 +259,14 @@ class RNN
 
   template<size_t I = 0, typename... Tp>
   typename std::enable_if<I < sizeof...(Tp), void>::type
-  ResetParameter(std::tuple<Tp...>& t)
+  ResetParameter(std::tuple<Tp...>& network)
   {
-    ResetDeterministic(std::get<I>(t));
-    ResetSeqLen(std::get<I>(t));
-    ResetRecurrent(std::get<I>(t), std::get<I>(t).InputParameter());
-    std::get<I>(t).Delta().zeros();
+    ResetDeterministic(std::get<I>(network));
+    ResetSeqLen(std::get<I>(network));
+    ResetRecurrent(std::get<I>(network), std::get<I>(network).InputParameter());
+    std::get<I>(network).Delta().zeros();
 
-    ResetParameter<I + 1, Tp...>(t);
+    ResetParameter<I + 1, Tp...>(network);
   }
 
   /**
@@ -244,9 +276,9 @@ class RNN
   template<typename T>
   typename std::enable_if<
       HasDeterministicCheck<T, bool&(T::*)(void)>::value, void>::type
-  ResetDeterministic(T& t)
+  ResetDeterministic(T& layer)
   {
-    t.Deterministic() = deterministic;
+    layer.Deterministic() = deterministic;
   }
 
   template<typename T>
@@ -261,9 +293,9 @@ class RNN
   template<typename T>
   typename std::enable_if<
       HasSeqLenCheck<T, size_t&(T::*)(void)>::value, void>::type
-  ResetSeqLen(T& t)
+  ResetSeqLen(T& layer)
   {
-    t.SeqLen() = seqLen;
+    layer.SeqLen() = seqLen;
   }
 
   template<typename T>
@@ -278,9 +310,9 @@ class RNN
   template<typename T, typename P>
   typename std::enable_if<
       HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
-  ResetRecurrent(T& t, P& /* unused */)
+  ResetRecurrent(T& layer, P& /* unused */)
   {
-    t.RecurrentParameter().zeros();
+    layer.RecurrentParameter().zeros();
   }
 
   template<typename T, typename P>
@@ -293,10 +325,6 @@ class RNN
 
   /**
    * Initialize the network by setting the input size and output size.
-   *
-   * enable_if (SFINAE) is used to iterate through the network. The general
-   * case peels off the first type and recurses, as usual with
-   * variadic function templates.
    */
   template<size_t I = 0, typename InputDataType, typename TargetDataType,
       typename... Tp>
@@ -313,12 +341,13 @@ class RNN
   typename std::enable_if<I < sizeof...(Tp) - 1, void>::type
   InitLayer(const InputDataType& input,
             const TargetDataType& target,
-            std::tuple<Tp...>& t)
+            std::tuple<Tp...>& network)
   {
-    Init(std::get<I>(t), std::get<I>(t).OutputParameter(),
-       std::get<I + 1>(t).Delta());
+    Init(std::get<I>(network), std::get<I>(network).OutputParameter(),
+       std::get<I + 1>(network).Delta());
     
-    InitLayer<I + 1, InputDataType, TargetDataType, Tp...>(input, target, t);
+    InitLayer<I + 1, InputDataType, TargetDataType, Tp...>(input, target,
+        network);
   }
 
   /**
@@ -328,13 +357,13 @@ class RNN
   template<typename T, typename P, typename D>
   typename std::enable_if<
       HasGradientCheck<T, void(T::*)(const D&, P&)>::value, void>::type
-  Init(T& t, P& /* unused */, D& /* unused */)
+  Init(T& layer, P& /* unused */, D& /* unused */)
   {
     // Initialize the input size only once.
     if (!inputSize)
-      inputSize = t.Weights().n_cols;
+      inputSize = layer.Weights().n_cols;
 
-    outputSize = t.Weights().n_rows;
+    outputSize = layer.Weights().n_rows;
   }
 
   template<typename T, typename P, typename D>
@@ -347,10 +376,6 @@ class RNN
 
   /**
    * Save the network layer activations.
-   *
-   * enable_if (SFINAE) is used to iterate through the network layer.
-   * The general case peels off the first type and recurses, as usual with
-   * variadic function templates.
    */
   template<size_t I = 0, typename... Tp>
   typename std::enable_if<I == sizeof...(Tp), void>::type
@@ -361,10 +386,10 @@ class RNN
 
   template<size_t I = 0, typename... Tp>
   typename std::enable_if<I < sizeof...(Tp), void>::type
-  SaveActivations(std::tuple<Tp...>& t)
+  SaveActivations(std::tuple<Tp...>& network)
   {
-    Save(I, std::get<I>(t), std::get<I>(t).InputParameter());
-    SaveActivations<I + 1, Tp...>(t);
+    Save(I, std::get<I>(network), std::get<I>(network).InputParameter());
+    SaveActivations<I + 1, Tp...>(network);
   }
 
   /**
@@ -374,45 +399,47 @@ class RNN
   template<typename T, typename P>
   typename std::enable_if<
       HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
-  Save(const size_t layerNumber, T& t, P& /* unused */)
+  Save(const size_t layerNumber, T& layer, P& /* unused */)
   {
     if (activations.size() == layerNumber)
-      activations.push_back(new MatType(t.RecurrentParameter().n_rows, seqLen));
+    {
+      activations.push_back(new arma::mat(layer.RecurrentParameter().n_rows,
+          seqLen));
+    }
 
-    activations[layerNumber].unsafe_col(seqNum) = t.RecurrentParameter();
+    activations[layerNumber].unsafe_col(seqNum) = layer.RecurrentParameter();
   }
 
   template<typename T, typename P>
   typename std::enable_if<
       !HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
-  Save(const size_t layerNumber, T& t, P& /* unused */)
+  Save(const size_t layerNumber, T& layer, P& /* unused */)
   {
     if (activations.size() == layerNumber)
-      activations.push_back(new MatType(t.OutputParameter().n_rows, seqLen));
+    {
+      activations.push_back(new arma::mat(layer.OutputParameter().n_rows,
+          seqLen));
+    }
 
-    activations[layerNumber].unsafe_col(seqNum) = t.OutputParameter();
+    activations[layerNumber].unsafe_col(seqNum) = layer.OutputParameter();
   }
 
   /**
    * Load the network layer activations.
-   *
-   * enable_if (SFINAE) is used to iterate through the network layer.
-   * The general case peels off the first type and recurses, as usual with
-   * variadic function templates.
    */
   template<size_t I = 0, typename DataType, typename... Tp>
   typename std::enable_if<I == sizeof...(Tp), void>::type
-  LoadActivations(DataType& input, std::tuple<Tp...>& t)
+  LoadActivations(DataType& input, std::tuple<Tp...>& network)
   {
-    std::get<0>(t).InputParameter() = input;
+    std::get<0>(network).InputParameter() = input;
   }
 
   template<size_t I = 0, typename DataType, typename... Tp>
   typename std::enable_if<I < sizeof...(Tp), void>::type
-  LoadActivations(DataType& input, std::tuple<Tp...>& t)
+  LoadActivations(DataType& input, std::tuple<Tp...>& network)
   {
-    Load(I, std::get<I>(t), std::get<I>(t).InputParameter());
-    LoadActivations<I + 1, DataType, Tp...>(input, t);
+    Load(I, std::get<I>(network), std::get<I>(network).InputParameter());
+    LoadActivations<I + 1, DataType, Tp...>(input, network);
   }
 
   /**
@@ -422,37 +449,32 @@ class RNN
   template<typename T, typename P>
   typename std::enable_if<
       HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
-  Load(const size_t layerNumber, T& t, P& /* unused */)
+  Load(const size_t layerNumber, T& layer, P& /* unused */)
   {
-    t.RecurrentParameter() = activations[layerNumber].unsafe_col(seqNum);
+    layer.RecurrentParameter() = activations[layerNumber].unsafe_col(seqNum);
   }
 
   template<typename T, typename P>
   typename std::enable_if<
       !HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
-  Load(const size_t layerNumber, T& t, P& /* unused */)
+  Load(const size_t layerNumber, T& layer, P& /* unused */)
   {
-    t.OutputParameter() = activations[layerNumber].unsafe_col(seqNum);
+    layer.OutputParameter() = activations[layerNumber].unsafe_col(seqNum);
   }
 
   /**
    * Run a single iteration of the feed forward algorithm, using the given
    * input and target vector, store the calculated error into the error
    * vector.
-   *
-   * enable_if (SFINAE) is used to select between two template overloads of
-   * the get function - one for when I is equal the size of the tuple of
-   * layer, and one for the general case which peels off the first type
-   * and recurses, as usual with variadic function templates.
    */
   template<size_t I = 0, typename DataType, typename... Tp>
-  void Forward(const DataType& input, std::tuple<Tp...>& t)
+  void Forward(const DataType& input, std::tuple<Tp...>& network)
   {
-    std::get<I>(t).InputParameter() = input;
-    std::get<I>(t).Forward(std::get<I>(t).InputParameter(),
-        std::get<I>(t).OutputParameter());
+    std::get<I>(network).InputParameter() = input;
+    std::get<I>(network).Forward(std::get<I>(network).InputParameter(),
+        std::get<I>(network).OutputParameter());
 
-    ForwardTail<I + 1, Tp...>(t);
+    ForwardTail<I + 1, Tp...>(network);
   }
 
   template<size_t I = 1, typename... Tp>
@@ -461,20 +483,16 @@ class RNN
 
   template<size_t I = 1, typename... Tp>
   typename std::enable_if<I < sizeof...(Tp), void>::type
-  ForwardTail(std::tuple<Tp...>& t)
+  ForwardTail(std::tuple<Tp...>& network)
   {
-    std::get<I>(t).Forward(std::get<I - 1>(t).OutputParameter(),
-        std::get<I>(t).OutputParameter());
+    std::get<I>(network).Forward(std::get<I - 1>(network).OutputParameter(),
+        std::get<I>(network).OutputParameter());
 
-    ForwardTail<I + 1, Tp...>(t);
+    ForwardTail<I + 1, Tp...>(network);
   }
 
   /**
    * Link the calculated activation with the correct layer.
-   *
-   * enable_if (SFINAE) is used to iterate through the network. The general
-   * case peels off the first type and recurses, as usual with
-   * variadic function templates.
    */
   template<size_t I = 1, typename... Tp>
   typename std::enable_if<I == sizeof...(Tp), void>::type
@@ -482,23 +500,20 @@ class RNN
 
   template<size_t I = 1, typename... Tp>
   typename std::enable_if<I < sizeof...(Tp), void>::type
-  LinkParameter(std::tuple<Tp...>& t)
+  LinkParameter(std::tuple<Tp...>& network)
   {
     if (!LayerTraits<typename std::remove_reference<
-        decltype(std::get<I>(t))>::type>::IsBiasLayer)
+        decltype(std::get<I>(network))>::type>::IsBiasLayer)
     {
-      std::get<I>(t).InputParameter() = std::get<I - 1>(t).OutputParameter();
+      std::get<I>(network).InputParameter() = std::get<I - 1>(
+          network).OutputParameter();
     }
 
-    LinkParameter<I + 1, Tp...>(t);
+    LinkParameter<I + 1, Tp...>(network);
   }
 
   /**
    * Link the calculated activation with the correct recurrent layer.
-   *
-   * enable_if (SFINAE) is used to iterate through the network. The general
-   * case peels off the first type and recurses, as usual with
-   * variadic function templates.
    */
   template<size_t I = 0, typename... Tp>
   typename std::enable_if<I == (sizeof...(Tp) - 1), void>::type
@@ -506,11 +521,11 @@ class RNN
 
   template<size_t I = 0, typename... Tp>
   typename std::enable_if<I < (sizeof...(Tp) - 1), void>::type
-  LinkRecurrent(std::tuple<Tp...>& t)
+  LinkRecurrent(std::tuple<Tp...>& network)
   {
-    UpdateRecurrent(std::get<I>(t), std::get<I>(t).InputParameter(),
-        std::get<I + 1>(t).OutputParameter());
-    LinkRecurrent<I + 1, Tp...>(t);
+    UpdateRecurrent(std::get<I>(network), std::get<I>(network).InputParameter(),
+        std::get<I + 1>(network).OutputParameter());
+    LinkRecurrent<I + 1, Tp...>(network);
   }
 
   /**
@@ -520,9 +535,9 @@ class RNN
   template<typename T, typename P, typename D>
   typename std::enable_if<
       HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
-  UpdateRecurrent(T& t, P& /* unused */, D& output)
+  UpdateRecurrent(T& layer, P& /* unused */, D& output)
   {
-    t.RecurrentParameter() = output;
+    layer.RecurrentParameter() = output;
   }
 
   template<typename T, typename P, typename D>
@@ -539,36 +554,31 @@ class RNN
   template<typename DataType, typename ErrorType, typename... Tp>
   double OutputError(const DataType& target,
                      ErrorType& error,
-                     const std::tuple<Tp...>& t)
+                     const std::tuple<Tp...>& network)
   {
     // Calculate and store the output error.
     outputLayer.CalculateError(
-        std::get<sizeof...(Tp) - 1>(t).OutputParameter(), target, error);
+        std::get<sizeof...(Tp) - 1>(network).OutputParameter(), target, error);
 
     // Masures the network's performance with the specified performance
     // function.
-    return performanceFunction.Error(network, target, error);
+    return performanceFunc.Error(network, target, error);
   }
 
   /**
    * Run a single iteration of the feed backward algorithm, using the given
    * error of the output layer. Note that we iterate backward through the
    * layer modules.
-   *
-   * enable_if (SFINAE) is used to select between two template overloads of
-   * the get function - one for when I is equal the size of the tuple of
-   * layer, and one for the general case which peels off the first type
-   * and recurses, as usual with variadic function templates.
    */
   template<size_t I = 1, typename DataType, typename... Tp>
   typename std::enable_if<I < (sizeof...(Tp) - 1), void>::type
-  Backward(const DataType& error, std::tuple<Tp ...>& t)
+  Backward(DataType& error, std::tuple<Tp ...>& network)
   {
-    std::get<sizeof...(Tp) - I>(t).Backward(
-        std::get<sizeof...(Tp) - I>(t).OutputParameter(), error,
-        std::get<sizeof...(Tp) - I>(t).Delta());
+    std::get<sizeof...(Tp) - I>(network).Backward(
+        std::get<sizeof...(Tp) - I>(network).OutputParameter(), error,
+        std::get<sizeof...(Tp) - I>(network).Delta());
 
-    BackwardTail<I + 1, DataType, Tp...>(error, t);
+    BackwardTail<I + 1, DataType, Tp...>(error, network);
   }
 
   template<size_t I = 1, typename DataType, typename... Tp>
@@ -580,18 +590,18 @@ class RNN
 
   template<size_t I = 1, typename DataType, typename... Tp>
   typename std::enable_if<I < (sizeof...(Tp)), void>::type
-  BackwardTail(const DataType& error, std::tuple<Tp...>& t)
+  BackwardTail(const DataType& error, std::tuple<Tp...>& network)
   {
-    BackwardRecurrent(std::get<sizeof...(Tp) - I - 1>(t),
-        std::get<sizeof...(Tp) - I - 1>(t).InputParameter(),
-        std::get<sizeof...(Tp) - I + 1>(t).Delta());
+    BackwardRecurrent(std::get<sizeof...(Tp) - I - 1>(network),
+        std::get<sizeof...(Tp) - I - 1>(network).InputParameter(),
+        std::get<sizeof...(Tp) - I + 1>(network).Delta());
     
-    std::get<sizeof...(Tp) - I>(t).Backward(
-        std::get<sizeof...(Tp) - I>(t).OutputParameter(),
-        std::get<sizeof...(Tp) - I + 1>(t).Delta(),
-        std::get<sizeof...(Tp) - I>(t).Delta());
+    std::get<sizeof...(Tp) - I>(network).Backward(
+        std::get<sizeof...(Tp) - I>(network).OutputParameter(),
+        std::get<sizeof...(Tp) - I + 1>(network).Delta(),
+        std::get<sizeof...(Tp) - I>(network).Delta());
 
-    BackwardTail<I + 1, DataType, Tp...>(error, t);
+    BackwardTail<I + 1, DataType, Tp...>(error, network);
   }
 
   /*
@@ -600,10 +610,10 @@ class RNN
   template<typename T, typename P, typename D>
   typename std::enable_if<
       HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
-  BackwardRecurrent(T& t, P& /* unused */, D& delta)
+  BackwardRecurrent(T& layer, P& /* unused */, D& delta)
   {
-    if (!t.Delta().is_empty())
-      delta += t.Delta();
+    if (!layer.Delta().is_empty())
+      delta += layer.Delta();
   }
 
   template<typename T, typename P, typename D>
@@ -617,31 +627,29 @@ class RNN
   /**
    * Iterate through all layer modules and update the the gradient using the
    * layer defined optimizer.
-   *
-   * enable_if (SFINAE) is used to iterate through the network layer.
-   * The general case peels off the first type and recurses, as usual with
-   * variadic function templates.
    */
   template<size_t I = 0, size_t Max = std::tuple_size<LayerTypes>::value - 2,
       typename... Tp>
   typename std::enable_if<I == Max, void>::type
-  UpdateGradients(std::tuple<Tp...>& t)
+  UpdateGradients(std::tuple<Tp...>& network)
   {
-    Update(std::get<I>(t), std::get<I>(t).OutputParameter(),
-        std::get<I + 1>(t).Delta(), std::get<I + 1>(t),
-        std::get<I + 1>(t).InputParameter(), std::get<I + 1>(t).Delta());
+    Update(std::get<I>(network), std::get<I>(network).OutputParameter(),
+        std::get<I + 1>(network).Delta(), std::get<I + 1>(network),
+        std::get<I + 1>(network).InputParameter(),
+        std::get<I + 1>(network).Delta());
   }
 
   template<size_t I = 0, size_t Max = std::tuple_size<LayerTypes>::value - 2,
       typename... Tp>
   typename std::enable_if<I < Max, void>::type
-  UpdateGradients(std::tuple<Tp...>& t)
+  UpdateGradients(std::tuple<Tp...>& network)
   {
-    Update(std::get<I>(t), std::get<I>(t).OutputParameter(),
-        std::get<I + 1>(t).Delta(), std::get<I + 1>(t),
-        std::get<I + 1>(t).InputParameter(), std::get<I + 2>(t).Delta());
+    Update(std::get<I>(network), std::get<I>(network).OutputParameter(),
+        std::get<I + 1>(network).Delta(), std::get<I + 1>(network),
+        std::get<I + 1>(network).InputParameter(),
+        std::get<I + 2>(network).Delta());
 
-    UpdateGradients<I + 1, Max, Tp...>(t);
+    UpdateGradients<I + 1, Max, Tp...>(network);
   }
 
   template<typename T1, typename P1, typename D1, typename T2, typename P2,
@@ -649,11 +657,10 @@ class RNN
   typename std::enable_if<
       HasGradientCheck<T1, void(T1::*)(const D1&, P1&)>::value &&
       HasRecurrentParameterCheck<T2, P2&(T2::*)()>::value, void>::type
-  Update(T1& t1, P1& /* unused */, D1& /* unused */, T2& /* unused */,
+  Update(T1& layer, P1& /* unused */, D1& /* unused */, T2& /* unused */,
          P2& /* unused */, D2& delta2)
   {
-    t1.Gradient(delta2, t1.Gradient());
-    t1.Optimizer().Update();
+    layer.Gradient(delta2, layer.Gradient());
   }
 
   template<typename T1, typename P1, typename D1, typename T2, typename P2,
@@ -674,78 +681,43 @@ class RNN
   typename std::enable_if<
       HasGradientCheck<T1, void(T1::*)(const D1&, P1&)>::value &&
       !HasRecurrentParameterCheck<T2, P2&(T2::*)()>::value, void>::type
-  Update(T1& t1, P1& /* unused */, D1& delta1, T2& /* unused */,
+  Update(T1& layer, P1& /* unused */, D1& delta1, T2& /* unused */,
          P2& /* unused */, D2& /* unused */)
   {
-    t1.Gradient(delta1, t1.Gradient());
-    t1.Optimizer().Update();
-  }
-
-  /**
-   * Update the weights using the calulated gradients.
-   *
-   * enable_if (SFINAE) is used to iterate through the network layer.
-   * The general case peels off the first type and recurses, as usual with
-   * variadic function templates.
-   */
-  template<size_t I = 0, size_t Max = std::tuple_size<LayerTypes>::value - 1,
-      typename... Tp>
-  typename std::enable_if<I == Max, void>::type
-  ApplyGradients(std::tuple<Tp...>& /* unused */)
-  {
-    /* Nothing to do here */
-  }
-
-  template<size_t I = 0, size_t Max = std::tuple_size<LayerTypes>::value - 1,
-      typename... Tp>
-  typename std::enable_if<I < Max, void>::type
-  ApplyGradients(std::tuple<Tp...>& t)
-  {
-    Apply(std::get<I>(t), std::get<I>(t).OutputParameter(),
-          std::get<I + 1>(t).Delta());
-
-    ApplyGradients<I + 1, Max, Tp...>(t);
-  }
-
-  template<typename T, typename P, typename D>
-  typename std::enable_if<
-      HasGradientCheck<T, void(T::*)(const D&, P&)>::value, void>::type
-  Apply(T& t, P& /* unused */, D& /* unused */)
-  {
-    t.Optimizer().Optimize();
-    t.Optimizer().Reset();
-  }
-
-  template<typename T, typename P, typename D>
-  typename std::enable_if<
-      !HasGradientCheck<T, void(T::*)(const P&, D&)>::value, void>::type
-  Apply(T& /* unused */, P& /* unused */, D& /* unused */)
-  {
-    /* Nothing to do here */
+    layer.Gradient(delta1, layer.Gradient());
   }
 
   /*
    * Calculate and store the output activation.
    */
   template<typename DataType, typename... Tp>
-  void OutputPrediction(DataType& output, std::tuple<Tp...>& t)
+  void OutputPrediction(DataType& output, std::tuple<Tp...>& network)
   {
     // Calculate and store the output prediction.
-    outputLayer.OutputClass(std::get<sizeof...(Tp) - 1>(t).OutputParameter(),
-        output);
+    outputLayer.OutputClass(std::get<sizeof...(Tp) - 1>(
+        network).OutputParameter(), output);
   }
 
-  //! The layer modules used to build the network.
+  //! Instantiated recurrent neural network.
   LayerTypes network;
 
   //! The outputlayer used to evaluate the network
   OutputLayerType& outputLayer;
 
   //! Performance strategy used to claculate the error.
-  PerformanceFunction performanceFunction;
+  PerformanceFunction performanceFunc;
+
+  //! The current evaluation mode (training or testing).
+  bool deterministic;
+
+  //! Matrix of (trained) parameters.
+  arma::mat parameter;
 
-  //! The current training error of the network.
-  double trainError;
+  //! The matrix of data points (predictors).
+  arma::mat predictors;
+
+  //! The matrix of responses to the input data points.
+  arma::mat responses;
 
   //! Locally stored network input size.
   size_t inputSize;
@@ -753,9 +725,6 @@ class RNN
   //! Locally stored network output size.
   size_t outputSize;
 
-  //! The current evaluation mode (training or testing).
-  bool deterministic;
-
   //! The index of the current sequence number.
   size_t seqNum;
 
@@ -766,26 +735,19 @@ class RNN
   bool seqOutput;
 
   //! The activation storage we are using to perform the feed backward pass.
-  boost::ptr_vector<MatType> activations;
-}; // class RNN
+  boost::ptr_vector<arma::mat> activations;
 
-//! Network traits for the RNN network.
-template <
-  typename LayerTypes,
-  typename OutputLayerType,
-  class PerformanceFunction
->
-class NetworkTraits<
-    RNN<LayerTypes, OutputLayerType, PerformanceFunction> >
-{
- public:
-  static const bool IsFNN = false;
-  static const bool IsRNN = true;
-  static const bool IsCNN = false;
-  static const bool IsSAE = false;
-};
+  //! The number of separable functions (the number of predictor points).
+  size_t numFunctions;
+
+  //! Locally stored backward error.
+  arma::mat error;
+}; // class RNN
 
 } // namespace ann
 } // namespace mlpack
 
+// Include implementation.
+#include "rnn_impl.hpp"
+
 #endif
diff --git a/src/mlpack/methods/ann/ffn_impl.hpp b/src/mlpack/methods/ann/rnn_impl.hpp
similarity index 64%
copy from src/mlpack/methods/ann/ffn_impl.hpp
copy to src/mlpack/methods/ann/rnn_impl.hpp
index bd7436a..0104451 100644
--- a/src/mlpack/methods/ann/ffn_impl.hpp
+++ b/src/mlpack/methods/ann/rnn_impl.hpp
@@ -1,14 +1,14 @@
 /**
- * @file ffn_impl.hpp
+ * @file rnn_impl.hpp
  * @author Marcus Edel
  *
- * Definition of the FFN class, which implements feed forward neural networks.
+ * Definition of the RNN class, which implements recurrent neural networks.
  */
-#ifndef __MLPACK_METHODS_ANN_FFN_IMPL_HPP
-#define __MLPACK_METHODS_ANN_FFN_IMPL_HPP
+#ifndef __MLPACK_METHODS_ANN_RNN_IMPL_HPP
+#define __MLPACK_METHODS_ANN_RNN_IMPL_HPP
 
 // In case it hasn't been included yet.
-#include "ffn.hpp"
+#include "rnn.hpp"
 
 namespace mlpack {
 namespace ann /** Artificial Neural Network. */ {
@@ -23,8 +23,8 @@ template<typename LayerType,
          typename OutputType,
          template<typename> class OptimizerType
 >
-FFN<LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
->::FFN(LayerType &&network,
+RNN<LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
+>::RNN(LayerType &&network,
        OutputType &&outputLayer,
        const arma::mat& predictors,
        const arma::mat& responses,
@@ -36,7 +36,9 @@ FFN<LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
     performanceFunc(std::move(performanceFunction)),
     predictors(predictors),
     responses(responses),
-    numFunctions(predictors.n_cols)
+    numFunctions(predictors.n_cols),
+    inputSize(0),
+    outputSize(0)
 {
   static_assert(std::is_same<typename std::decay<LayerType>::type,
                   LayerTypes>::value,
@@ -50,11 +52,11 @@ FFN<LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
   NetworkWeights(parameter, network);
 
   // Train the model.
-  Timer::Start("ffn_optimization");
+  Timer::Start("rnn_optimization");
   const double out = optimizer.Optimize(parameter);
-  Timer::Stop("ffn_optimization");
+  Timer::Stop("rnn_optimization");
 
-  Log::Info << "FFN::FFN(): final objective of trained model is " << out
+  Log::Info << "RNN::RNN(): final objective of trained model is " << out
       << "." << std::endl;
 }
 
@@ -64,8 +66,8 @@ template<typename LayerTypes,
          typename PerformanceFunction
 >
 template<typename LayerType, typename OutputType>
-FFN<LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
->::FFN(LayerType &&network,
+RNN<LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
+>::RNN(LayerType &&network,
        OutputType &&outputLayer,
        const arma::mat& predictors,
        const arma::mat& responses,
@@ -73,7 +75,9 @@ FFN<LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
        PerformanceFunction performanceFunction) : 
     network(std::forward<LayerType>(network)),
     outputLayer(std::forward<OutputType>(outputLayer)),
-    performanceFunc(std::move(performanceFunction))
+    performanceFunc(std::move(performanceFunction)),
+    inputSize(0),
+    outputSize(0)
 {
   static_assert(std::is_same<typename std::decay<LayerType>::type,
                   LayerTypes>::value,
@@ -95,14 +99,16 @@ template<typename LayerTypes,
          typename PerformanceFunction
 >
 template<typename LayerType, typename OutputType>
-FFN<LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
->::FFN(LayerType &&network,
+RNN<LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
+>::RNN(LayerType &&network,
        OutputType &&outputLayer,
        InitializationRuleType initializeRule,
        PerformanceFunction performanceFunction) : 
     network(std::forward<LayerType>(network)),
     outputLayer(std::forward<OutputType>(outputLayer)),
-    performanceFunc(std::move(performanceFunction))
+    performanceFunc(std::move(performanceFunction)),
+    inputSize(0),
+    outputSize(0)
 {
   static_assert(std::is_same<typename std::decay<LayerType>::type,
                   LayerTypes>::value,
@@ -114,8 +120,6 @@ FFN<LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
 
   initializeRule.Initialize(parameter, NetworkSize(network), 1);
   NetworkWeights(parameter, network);
-
-  Log::Debug << parameter << std::endl;
 }
 
 template<typename LayerTypes,
@@ -124,7 +128,7 @@ template<typename LayerTypes,
          typename PerformanceFunction
 >
 template<template<typename> class OptimizerType>
-void FFN<
+void RNN<
 LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
 >::Train(const arma::mat& predictors, const arma::mat& responses)
 {
@@ -135,11 +139,11 @@ LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
   OptimizerType<decltype(*this)> optimizer(*this);
 
   // Train the model.
-  Timer::Start("ffn_optimization");
+  Timer::Start("rnn_optimization");
   const double out = optimizer.Optimize(parameter);
-  Timer::Stop("ffn_optimization");
+  Timer::Stop("rnn_optimization");
 
-  Log::Info << "FFN::FFN(): final objective of trained model is " << out
+  Log::Info << "RNN::RNN(): final objective of trained model is " << out
       << "." << std::endl;
 }
 
@@ -149,7 +153,7 @@ template<typename LayerTypes,
          typename PerformanceFunction
 >
 template<template<typename> class OptimizerType>
-void FFN<
+void RNN<
 LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
 >::Train(const arma::mat& predictors,
          const arma::mat& responses,
@@ -160,11 +164,11 @@ LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
   this->responses = responses;
 
   // Train the model.
-  Timer::Start("ffn_optimization");
+  Timer::Start("rnn_optimization");
   const double out = optimizer.Optimize(parameter);
-  Timer::Stop("ffn_optimization");
+  Timer::Stop("rnn_optimization");
 
-  Log::Info << "FFN::FFN(): final objective of trained model is " << out
+  Log::Info << "RNN::RNN(): final objective of trained model is " << out
       << "." << std::endl;
 }
 
@@ -176,16 +180,16 @@ template<typename LayerTypes,
 template<
     template<typename> class OptimizerType
 >
-void FFN<
+void RNN<
 LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
 >::Train(OptimizerType<NetworkType>& optimizer)
 {
   // Train the model.
-  Timer::Start("ffn_optimization");
+  Timer::Start("rnn_optimization");
   const double out = optimizer.Optimize(parameter);
-  Timer::Stop("ffn_optimization");
+  Timer::Stop("rnn_optimization");
 
-  Log::Info << "FFN::FFN(): final objective of trained model is " << out
+  Log::Info << "RNN::RNN(): final objective of trained model is " << out
       << "." << std::endl;
 }
 
@@ -194,29 +198,21 @@ template<typename LayerTypes,
          typename InitializationRuleType,
          typename PerformanceFunction
 >
-void FFN<
+void RNN<
 LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
 >::Predict(arma::mat& predictors, arma::mat& responses)
 {
-  deterministic = true;
-
   arma::mat responsesTemp;
-  ResetParameter(network);
-  Forward(arma::mat(predictors.colptr(0), predictors.n_rows, 1, false, true),
-      network);
-  OutputPrediction(responsesTemp, network);
+  SinglePredict(arma::mat(predictors.colptr(0), predictors.n_rows,
+      1, false, true), responsesTemp);
 
   responses = arma::mat(responsesTemp.n_elem, predictors.n_cols);
   responses.col(0) = responsesTemp.col(0);
 
   for (size_t i = 1; i < predictors.n_cols; i++)
   {
-    Forward(arma::mat(predictors.colptr(i), predictors.n_rows, 1, false, true),
-        network);
-
-    responsesTemp = arma::mat(responses.colptr(i), responses.n_rows, 1, false,
-        true);
-    OutputPrediction(responsesTemp, network);
+    SinglePredict(arma::mat(predictors.colptr(i), predictors.n_rows,
+      1, false, true), responsesTemp);
     responses.col(i) = responsesTemp.col(0);
   }
 }
@@ -226,7 +222,7 @@ template<typename LayerTypes,
          typename InitializationRuleType,
          typename PerformanceFunction
 >
-double FFN<
+double RNN<
 LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
 >::Evaluate(const arma::mat& /* unused */,
             const size_t i,
@@ -234,13 +230,44 @@ LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
 {
   this->deterministic = deterministic;
 
+  arma::mat input = arma::mat(predictors.colptr(i), predictors.n_rows,
+      1, false, true);
+  arma::mat target = arma::mat(responses.colptr(i), responses.n_rows,
+      1, false, true);
+
+  // Initialize the activation storage only once.
+  if (activations.empty())
+    InitLayer(input, target, network);
+
+  double networkError = 0;
+  seqLen = input.n_rows / inputSize;
   ResetParameter(network);
 
-  Forward(arma::mat(predictors.colptr(i), predictors.n_rows, 1, false, true),
-      network);
+  error = arma::mat(outputSize, outputSize < target.n_elem ? seqLen : 1);
 
-  return OutputError(arma::mat(responses.colptr(i), responses.n_rows, 1, false,
-      true), error, network);
+  // Iterate through the input sequence and perform the feed forward pass.
+  for (seqNum = 0; seqNum < seqLen; seqNum++)
+  {
+    // Perform the forward pass and save the activations.
+    Forward(input.rows(seqNum * inputSize, (seqNum + 1) * inputSize - 1),
+        network);
+    SaveActivations(network);
+
+    // Retrieve output error of the subsequence.
+    if (seqOutput)
+    {
+      arma::mat seqError = error.unsafe_col(seqNum);
+      arma::mat seqTarget = target.submat(seqNum * outputSize, 0,
+          (seqNum + 1) * outputSize - 1, 0);
+      networkError += OutputError(seqTarget, seqError, network);
+    }
+  }
+
+  // Retrieve output error of the complete sequence.
+  if (!seqOutput)
+    return OutputError(target, error, network);
+
+  return networkError;
 }
 
 template<typename LayerTypes,
@@ -248,16 +275,46 @@ template<typename LayerTypes,
          typename InitializationRuleType,
          typename PerformanceFunction
 >
-void FFN<
+void RNN<
 LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
 >::Gradient(const arma::mat& /* unused */,
-            const size_t /* unused */,
+            const size_t i,
             arma::mat& gradient)
 {
-  NetworkGradients(gradient, network);
+  gradient.zeros();
+  arma::mat currentGradient = arma::mat(gradient.n_rows, gradient.n_cols);
+  NetworkGradients(currentGradient, network);
+
+  const arma::mat input = arma::mat(predictors.colptr(i), predictors.n_rows,
+      1, false, true);
 
-  Backward<>(error, network);
-  UpdateGradients<>(network);
+  // Iterate through the input sequence and perform the feed backward pass.
+  for (seqNum = seqLen - 1; seqNum >= 0; seqNum--)
+  {
+    // Load the network activation for the upcoming backward pass.
+    LoadActivations(input.rows(seqNum * inputSize, (seqNum + 1) *
+        inputSize - 1), network);
+
+    // Perform the backward pass.
+    if (seqOutput)
+    {
+      arma::mat seqError = error.unsafe_col(seqNum);
+      Backward(seqError, network);
+    }
+    else
+    {
+      Backward(error, network);
+    }
+    
+    // Link the parameters and update the gradients.
+    LinkParameter(network);
+    UpdateGradients<>(network);
+
+    // Update the overall gradient.
+    gradient += currentGradient;
+
+    if (seqNum == 0) break;
+  }
 }
 
 template<typename LayerTypes,
@@ -266,7 +323,7 @@ template<typename LayerTypes,
          typename PerformanceFunction
 >
 template<typename Archive>
-void FFN<
+void RNN<
 LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
 >::Serialize(Archive& ar, const unsigned int /* version */)
 {




More information about the mlpack-git mailing list