[mlpack-git] master: Reset the recurrent delta at the beginning of a new sequence. (a2b7c9b)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Nov 13 12:45:55 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4e83dc9cc4dcdc315d2cceee32b23ebab114c2...7388de71d5398103ee3a0b32b4026902a40a67b3
>---------------------------------------------------------------
commit a2b7c9b45de3fa29293724fd3a39a7d7e50a28cb
Author: marcus <marcus.edel at fu-berlin.de>
Date: Sun Nov 8 20:20:39 2015 +0100
Reset the recurrent delta at the beginning of a new sequence.
>---------------------------------------------------------------
a2b7c9b45de3fa29293724fd3a39a7d7e50a28cb
src/mlpack/methods/ann/rnn.hpp | 30 +++++++++++++++++++++++++++---
1 file changed, 27 insertions(+), 3 deletions(-)
diff --git a/src/mlpack/methods/ann/rnn.hpp b/src/mlpack/methods/ann/rnn.hpp
index 0c58097..7ffb72b 100644
--- a/src/mlpack/methods/ann/rnn.hpp
+++ b/src/mlpack/methods/ann/rnn.hpp
@@ -100,7 +100,6 @@ class RNN
LinkParameter(network);
UpdateGradients<>(network);
-
if (seqNum == 0) break;
}
}
@@ -229,6 +228,9 @@ class RNN
{
ResetDeterministic(std::get<I>(t));
ResetSeqLen(std::get<I>(t));
+ ResetRecurrent(std::get<I>(t), std::get<I>(t).InputParameter());
+ std::get<I>(t).Delta().zeros();
+
ResetParameter<I + 1, Tp...>(t);
}
@@ -267,6 +269,26 @@ class RNN
ResetSeqLen(T& /* unused */) { /* Nothing to do here */ }
/**
+ * Distinguish between recurrent layer and non-recurrent layer when resetting
+ * the recurrent parameter.
+ */
+ template<typename T, typename P>
+ typename std::enable_if<
+ HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
+ ResetRecurrent(T& t, P& /* unused */)
+ {
+ t.RecurrentParameter().zeros();
+ }
+
+ template<typename T, typename P>
+ typename std::enable_if<
+ !HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
+ ResetRecurrent(T& /* unused */, P& /* unused */)
+ {
+ /* Nothing to do here */
+ }
+
+ /**
* Initialize the network by setting the input size and output size.
*
* enable_if (SFINAE) is used to iterate through the network. The general
@@ -549,8 +571,10 @@ class RNN
template<size_t I = 1, typename DataType, typename... Tp>
typename std::enable_if<I == (sizeof...(Tp)), void>::type
- BackwardTail(const DataType& /* unused */,
- std::tuple<Tp...>& /* unused */) { /* Nothing to do here */ }
+ BackwardTail(const DataType& /* unused */, std::tuple<Tp...>& /* unused */)
+ {
+ /* Nothing to do here */
+ }
template<size_t I = 1, typename DataType, typename... Tp>
typename std::enable_if<I < (sizeof...(Tp)), void>::type
More information about the mlpack-git mailing list