[mlpack-git] master: Set the sequence length if necessary. (605697f)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Nov 13 12:45:46 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4e83dc9cc4dcdc315d2cceee32b23ebab114c2...7388de71d5398103ee3a0b32b4026902a40a67b3

>---------------------------------------------------------------

commit 605697fbb7ec637c720774414d05f60e959ea7b7
Author: marcus <marcus.edel at fu-berlin.de>
Date:   Sat Nov 7 15:57:18 2015 +0100

    Set the sequence length if necessary.


>---------------------------------------------------------------

605697fbb7ec637c720774414d05f60e959ea7b7
 src/mlpack/methods/ann/rnn.hpp | 76 ++++++++++++++++++++++++------------------
 1 file changed, 44 insertions(+), 32 deletions(-)

diff --git a/src/mlpack/methods/ann/rnn.hpp b/src/mlpack/methods/ann/rnn.hpp
index 166f0f3..0c58097 100644
--- a/src/mlpack/methods/ann/rnn.hpp
+++ b/src/mlpack/methods/ann/rnn.hpp
@@ -79,19 +79,16 @@ class RNN
   void FeedBackward(const InputType& input, const ErrorType& error)
   {
     // Iterate through the input sequence and perform the feed backward pass.
-    for (seqNum = 0; seqNum < input.n_elem; seqNum += inputSize)
+    for (seqNum = seqLen - 1; seqNum >= 0; seqNum--)
     {
       // Load the network activation for the upcoming backward pass.
-      if (seqNum > 0)
-      {
-        LoadActivations(input.rows(input.n_elem - seqNum - 1,
-            (input.n_elem - seqNum - 1 + inputSize) - 1), network);
-      }
+        LoadActivations(input.rows(seqNum * inputSize, (seqNum + 1) *
+            inputSize - 1), network);
 
       // Perform the backward pass.
       if (seqOutput)
       {
-        ErrorType seqError = error.unsafe_col(error.n_cols - seqNum - 1);
+        ErrorType seqError = error.unsafe_col(seqNum);
         Backward(seqError, network);
       }
       else
@@ -102,6 +99,9 @@ class RNN
       // Link the parameters and update the gradients.
       LinkParameter(network);
       UpdateGradients<>(network);
+
+
+      if (seqNum == 0) break;
     }
   }
 
@@ -127,15 +127,15 @@ class RNN
   void Predict(const DataType& input, DataType& output)
   {
     deterministic = true;
-    ResetParameter(network);
-
     seqLen = input.n_rows / inputSize;
+    ResetParameter(network);
 
     // Iterate through the input sequence and perform the feed forward pass.
-    for (seqNum = 0; seqNum < input.n_elem; seqNum += inputSize)
+    for (seqNum = 0; seqNum < seqLen; seqNum++)
     {
       // Perform the forward pass and save the activations.
-      Forward(input.rows(seqNum, (seqNum + inputSize) - 1), network);
+      Forward(input.rows(seqNum * inputSize, (seqNum + 1) * inputSize - 1),
+          network);
       SaveActivations(network);
 
       // Retrieve output of the subsequence.
@@ -170,29 +170,27 @@ class RNN
       InitLayer(input, target, network);
 
     double networkError = 0;
+    seqLen = input.n_rows / inputSize;
     deterministic = false;
     ResetParameter(network);
 
-    seqLen = input.n_rows / inputSize;   
     error = ErrorType(outputSize, outputSize < target.n_elem ? seqLen : 1);
 
     // Iterate through the input sequence and perform the feed forward pass.
-    for (seqNum = 0, seqOutputNum = 0; seqNum < input.n_elem;
-        seqNum += inputSize)
+    for (seqNum = 0; seqNum < seqLen; seqNum++)
     {
       // Perform the forward pass and save the activations.
-      Forward(input.rows(seqNum, (seqNum + inputSize) - 1), network);
+      Forward(input.rows(seqNum * inputSize, (seqNum + 1) * inputSize - 1),
+          network);
       SaveActivations(network);
 
       // Retrieve output error of the subsequence.
       if (seqOutput)
       {
         arma::mat seqError = error.unsafe_col(seqNum);
-        arma::mat seqTarget = target.submat(seqOutputNum, 0,
-            seqOutputNum + outputSize - 1, 0);
-
+        arma::mat seqTarget = target.submat(seqNum * outputSize, 0,
+            (seqNum + 1) * outputSize - 1, 0);
         networkError += OutputError(seqTarget, seqError, network);
-        seqOutputNum += outputSize;
       }
     }
 
@@ -230,6 +228,7 @@ class RNN
   ResetParameter(std::tuple<Tp...>& t)
   {
     ResetDeterministic(std::get<I>(t));
+    ResetSeqLen(std::get<I>(t));
     ResetParameter<I + 1, Tp...>(t);
   }
 
@@ -251,6 +250,23 @@ class RNN
   ResetDeterministic(T& /* unused */) { /* Nothing to do here */ }
 
   /**
+   * Reset the layer sequence length by setting the current seqLen parameter
+   * for all layers that implement the SeqLen function.
+   */
+  template<typename T>
+  typename std::enable_if<
+      HasSeqLenCheck<T, size_t&(T::*)(void)>::value, void>::type
+  ResetSeqLen(T& t)
+  {
+    t.SeqLen() = seqLen;
+  }
+
+  template<typename T>
+  typename std::enable_if<
+      !HasSeqLenCheck<T, size_t&(T::*)(void)>::value, void>::type
+  ResetSeqLen(T& /* unused */) { /* Nothing to do here */ }
+
+  /**
    * Initialize the network by setting the input size and output size.
    *
    * enable_if (SFINAE) is used to iterate through the network. The general
@@ -322,12 +338,6 @@ class RNN
   typename std::enable_if<I < sizeof...(Tp), void>::type
   SaveActivations(std::tuple<Tp...>& t)
   {
-    if (activations.size() == I)
-    {
-      activations.push_back(new MatType(
-          std::get<I>(t).OutputParameter().n_rows, seqLen));
-    }
-
     Save(I, std::get<I>(t), std::get<I>(t).InputParameter());
     SaveActivations<I + 1, Tp...>(t);
   }
@@ -341,6 +351,9 @@ class RNN
       HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
   Save(const size_t layerNumber, T& t, P& /* unused */)
   {
+    if (activations.size() == layerNumber)
+      activations.push_back(new MatType(t.RecurrentParameter().n_rows, seqLen));
+
     activations[layerNumber].unsafe_col(seqNum) = t.RecurrentParameter();
   }
 
@@ -349,6 +362,9 @@ class RNN
       !HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
   Save(const size_t layerNumber, T& t, P& /* unused */)
   {
+    if (activations.size() == layerNumber)
+      activations.push_back(new MatType(t.OutputParameter().n_rows, seqLen));
+
     activations[layerNumber].unsafe_col(seqNum) = t.OutputParameter();
   }
 
@@ -383,8 +399,7 @@ class RNN
       HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
   Load(const size_t layerNumber, T& t, P& /* unused */)
   {
-    t.RecurrentParameter() = activations[layerNumber].unsafe_col(
-        seqLen - 1 - seqNum);
+    t.RecurrentParameter() = activations[layerNumber].unsafe_col(seqNum);
   }
 
   template<typename T, typename P>
@@ -392,8 +407,7 @@ class RNN
       !HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
   Load(const size_t layerNumber, T& t, P& /* unused */)
   {
-    t.OutputParameter() = activations[layerNumber].unsafe_col(
-        seqLen - 1 - seqNum);
+    t.OutputParameter() = activations[layerNumber].unsafe_col(seqNum);
   }
 
   /**
@@ -491,6 +505,7 @@ class RNN
       !HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
   UpdateRecurrent(T& /* unused */, P& /* unused */, D& /* unused */)
   {
+    /* Nothing to do here */
   }
 
   /*
@@ -715,9 +730,6 @@ class RNN
   //! The index of the current sequence number.
   size_t seqNum;
 
-  //! The index of the current sequence target number.
-  size_t seqOutputNum;
-
   //! Locally stored number of samples in one input sequence.
   size_t seqLen;
 



More information about the mlpack-git mailing list