[mlpack-git] master: Reset the recurrent delta at the beginning of a new sequence. (a2b7c9b)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Nov 13 12:45:55 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4e83dc9cc4dcdc315d2cceee32b23ebab114c2...7388de71d5398103ee3a0b32b4026902a40a67b3

>---------------------------------------------------------------

commit a2b7c9b45de3fa29293724fd3a39a7d7e50a28cb
Author: marcus <marcus.edel at fu-berlin.de>
Date:   Sun Nov 8 20:20:39 2015 +0100

    Reset the recurrent delta at the beginning of a new sequence.


>---------------------------------------------------------------

a2b7c9b45de3fa29293724fd3a39a7d7e50a28cb
 src/mlpack/methods/ann/rnn.hpp | 30 +++++++++++++++++++++++++++---
 1 file changed, 27 insertions(+), 3 deletions(-)

diff --git a/src/mlpack/methods/ann/rnn.hpp b/src/mlpack/methods/ann/rnn.hpp
index 0c58097..7ffb72b 100644
--- a/src/mlpack/methods/ann/rnn.hpp
+++ b/src/mlpack/methods/ann/rnn.hpp
@@ -100,7 +100,6 @@ class RNN
       LinkParameter(network);
       UpdateGradients<>(network);
 
-
       if (seqNum == 0) break;
     }
   }
@@ -229,6 +228,9 @@ class RNN
   {
     ResetDeterministic(std::get<I>(t));
     ResetSeqLen(std::get<I>(t));
+    ResetRecurrent(std::get<I>(t), std::get<I>(t).InputParameter());
+    std::get<I>(t).Delta().zeros();
+
     ResetParameter<I + 1, Tp...>(t);
   }
 
@@ -267,6 +269,26 @@ class RNN
   ResetSeqLen(T& /* unused */) { /* Nothing to do here */ }
 
   /**
+   * Distinguish between recurrent layer and non-recurrent layer when resetting
+   * the recurrent parameter.
+   */
+  template<typename T, typename P>
+  typename std::enable_if<
+      HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
+  ResetRecurrent(T& t, P& /* unused */)
+  {
+    t.RecurrentParameter().zeros();
+  }
+
+  template<typename T, typename P>
+  typename std::enable_if<
+      !HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
+  ResetRecurrent(T& /* unused */, P& /* unused */)
+  {
+    /* Nothing to do here */
+  }
+
+  /**
    * Initialize the network by setting the input size and output size.
    *
    * enable_if (SFINAE) is used to iterate through the network. The general
@@ -549,8 +571,10 @@ class RNN
 
   template<size_t I = 1, typename DataType, typename... Tp>
   typename std::enable_if<I == (sizeof...(Tp)), void>::type
-  BackwardTail(const DataType& /* unused */,
-               std::tuple<Tp...>& /* unused */) {  /* Nothing to do here */ }
+  BackwardTail(const DataType& /* unused */, std::tuple<Tp...>& /* unused */)
+  {
+    /* Nothing to do here */
+  }
 
   template<size_t I = 1, typename DataType, typename... Tp>
   typename std::enable_if<I < (sizeof...(Tp)), void>::type



More information about the mlpack-git mailing list