[mlpack-git] master: Set the sequence length if necessary. (605697f)
gitdub at big.cc.gt.atl.ga.us
Fri Nov 13 12:45:46 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4e83dc9cc4dcdc315d2cceee32b23ebab114c2...7388de71d5398103ee3a0b32b4026902a40a67b3
>---------------------------------------------------------------
commit 605697fbb7ec637c720774414d05f60e959ea7b7
Author: marcus <marcus.edel at fu-berlin.de>
Date: Sat Nov 7 15:57:18 2015 +0100
Set the sequence length if necessary.
>---------------------------------------------------------------
605697fbb7ec637c720774414d05f60e959ea7b7
src/mlpack/methods/ann/rnn.hpp | 76 ++++++++++++++++++++++++------------------
1 file changed, 44 insertions(+), 32 deletions(-)
diff --git a/src/mlpack/methods/ann/rnn.hpp b/src/mlpack/methods/ann/rnn.hpp
index 166f0f3..0c58097 100644
--- a/src/mlpack/methods/ann/rnn.hpp
+++ b/src/mlpack/methods/ann/rnn.hpp
@@ -79,19 +79,16 @@ class RNN
void FeedBackward(const InputType& input, const ErrorType& error)
{
// Iterate through the input sequence and perform the feed backward pass.
- for (seqNum = 0; seqNum < input.n_elem; seqNum += inputSize)
+ for (seqNum = seqLen - 1; seqNum >= 0; seqNum--)
{
// Load the network activation for the upcoming backward pass.
- if (seqNum > 0)
- {
- LoadActivations(input.rows(input.n_elem - seqNum - 1,
- (input.n_elem - seqNum - 1 + inputSize) - 1), network);
- }
+ LoadActivations(input.rows(seqNum * inputSize, (seqNum + 1) *
+ inputSize - 1), network);
// Perform the backward pass.
if (seqOutput)
{
- ErrorType seqError = error.unsafe_col(error.n_cols - seqNum - 1);
+ ErrorType seqError = error.unsafe_col(seqNum);
Backward(seqError, network);
}
else
@@ -102,6 +99,9 @@ class RNN
// Link the parameters and update the gradients.
LinkParameter(network);
UpdateGradients<>(network);
+
+
+ if (seqNum == 0) break;
}
}
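Note that seqNum is a size_t, so the condition "seqNum >= 0" above is always true (most compilers warn about it); the "if (seqNum == 0) break;" added at the bottom of the loop is what actually terminates the reverse pass. A minimal standalone sketch of the same idiom, with illustrative names that are not mlpack code:

    #include <cstddef>
    #include <cstdio>

    int main()
    {
      const std::size_t seqLen = 4;

      // With an unsigned index, "i >= 0" can never be false, so the loop
      // body must break explicitly after handling element 0.
      for (std::size_t i = seqLen - 1; i >= 0; --i)
      {
        std::printf("backward step %zu\n", i);

        if (i == 0)
          break;
      }

      return 0;
    }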
@@ -127,15 +127,15 @@ class RNN
void Predict(const DataType& input, DataType& output)
{
deterministic = true;
- ResetParameter(network);
-
seqLen = input.n_rows / inputSize;
+ ResetParameter(network);
// Iterate through the input sequence and perform the feed forward pass.
- for (seqNum = 0; seqNum < input.n_elem; seqNum += inputSize)
+ for (seqNum = 0; seqNum < seqLen; seqNum++)
{
// Perform the forward pass and save the activations.
- Forward(input.rows(seqNum, (seqNum + inputSize) - 1), network);
+ Forward(input.rows(seqNum * inputSize, (seqNum + 1) * inputSize - 1),
+ network);
SaveActivations(network);
// Retrieve output of the subsequence.
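The loop now counts whole time steps instead of striding element-by-element, so step seqNum maps to the row block [seqNum * inputSize, (seqNum + 1) * inputSize - 1]. A small self-contained illustration of that arithmetic (sizes and data are made up, not taken from mlpack):

    #include <armadillo>

    int main()
    {
      const arma::uword inputSize = 3;
      arma::vec input(12, arma::fill::randu);  // 12 rows / 3 = 4 time steps.

      const arma::uword seqLen = input.n_rows / inputSize;
      for (arma::uword seqNum = 0; seqNum < seqLen; ++seqNum)
      {
        // Step 2, for example, selects rows 6..8.
        arma::vec step = input.rows(seqNum * inputSize,
            (seqNum + 1) * inputSize - 1);
        step.print("step:");
      }

      return 0;
    }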
@@ -170,29 +170,27 @@ class RNN
InitLayer(input, target, network);
double networkError = 0;
+ seqLen = input.n_rows / inputSize;
deterministic = false;
ResetParameter(network);
- seqLen = input.n_rows / inputSize;
error = ErrorType(outputSize, outputSize < target.n_elem ? seqLen : 1);
// Iterate through the input sequence and perform the feed forward pass.
- for (seqNum = 0, seqOutputNum = 0; seqNum < input.n_elem;
- seqNum += inputSize)
+ for (seqNum = 0; seqNum < seqLen; seqNum++)
{
// Perform the forward pass and save the activations.
- Forward(input.rows(seqNum, (seqNum + inputSize) - 1), network);
+ Forward(input.rows(seqNum * inputSize, (seqNum + 1) * inputSize - 1),
+ network);
SaveActivations(network);
// Retrieve output error of the subsequence.
if (seqOutput)
{
arma::mat seqError = error.unsafe_col(seqNum);
- arma::mat seqTarget = target.submat(seqOutputNum, 0,
- seqOutputNum + outputSize - 1, 0);
-
+ arma::mat seqTarget = target.submat(seqNum * outputSize, 0,
+ (seqNum + 1) * outputSize - 1, 0);
networkError += OutputError(seqTarget, seqError, network);
- seqOutputNum += outputSize;
}
}
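As in Predict(), seqLen is now computed before ResetParameter() runs; that ordering is the point of this commit, because ResetParameter() triggers the new ResetSeqLen() pass (below), which copies the value into every layer exposing a SeqLen() accessor. A hypothetical minimal layer showing the accessor shape that check looks for; the class is illustrative, not an actual mlpack layer:

    #include <cstddef>

    class SequenceAwareLayer
    {
     public:
      SequenceAwareLayer() : seqLen(0) { }

      //! Modifiable reference assigned by RNN::ResetSeqLen(); the signature
      //! must match size_t&(T::*)(void) for the trait to detect it.
      std::size_t& SeqLen() { return seqLen; }

     private:
      //! Number of time steps in the current input sequence.
      std::size_t seqLen;
    };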
@@ -230,6 +228,7 @@ class RNN
ResetParameter(std::tuple<Tp...>& t)
{
ResetDeterministic(std::get<I>(t));
+ ResetSeqLen(std::get<I>(t));
ResetParameter<I + 1, Tp...>(t);
}
@@ -251,6 +250,23 @@ class RNN
ResetDeterministic(T& /* unused */) { /* Nothing to do here */ }
/**
+ * Reset the layer sequence length by setting the current seqLen parameter
+ * for all layers that implement the SeqLen function.
+ */
+ template<typename T>
+ typename std::enable_if<
+ HasSeqLenCheck<T, size_t&(T::*)(void)>::value, void>::type
+ ResetSeqLen(T& t)
+ {
+ t.SeqLen() = seqLen;
+ }
+
+ template<typename T>
+ typename std::enable_if<
+ !HasSeqLenCheck<T, size_t&(T::*)(void)>::value, void>::type
+ ResetSeqLen(T& /* unused */) { /* Nothing to do here */ }
+
+ /**
* Initialize the network by setting the input size and output size.
*
* enable_if (SFINAE) is used to iterate through the network. The general
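The two ResetSeqLen() overloads above dispatch on HasSeqLenCheck, a compile-time test for a SeqLen() member with the right signature; mlpack generates such traits with its HAS_MEM_FUNC macro, so the following standalone sketch of the classic detection pattern is illustrative rather than the actual definition:

    template<typename T, typename Signature>
    struct HasSeqLenCheck
    {
      typedef char Yes[1];
      typedef char No[2];

      // Instantiable only if &U::SeqLen exists with exactly type Signature.
      template<typename U, U> struct TypeCheck;

      template<typename U> static Yes& Check(TypeCheck<Signature, &U::SeqLen>*);
      template<typename U> static No& Check(...);

      static const bool value = sizeof(Check<T>(0)) == sizeof(Yes);
    };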
@@ -322,12 +338,6 @@ class RNN
typename std::enable_if<I < sizeof...(Tp), void>::type
SaveActivations(std::tuple<Tp...>& t)
{
- if (activations.size() == I)
- {
- activations.push_back(new MatType(
- std::get<I>(t).OutputParameter().n_rows, seqLen));
- }
-
Save(I, std::get<I>(t), std::get<I>(t).InputParameter());
SaveActivations<I + 1, Tp...>(t);
}
@@ -341,6 +351,9 @@ class RNN
HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
Save(const size_t layerNumber, T& t, P& /* unused */)
{
+ if (activations.size() == layerNumber)
+ activations.push_back(new MatType(t.RecurrentParameter().n_rows, seqLen));
+
activations[layerNumber].unsafe_col(seqNum) = t.RecurrentParameter();
}
@@ -349,6 +362,9 @@ class RNN
!HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
Save(const size_t layerNumber, T& t, P& /* unused */)
{
+ if (activations.size() == layerNumber)
+ activations.push_back(new MatType(t.OutputParameter().n_rows, seqLen));
+
activations[layerNumber].unsafe_col(seqNum) = t.OutputParameter();
}
@@ -383,8 +399,7 @@ class RNN
HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
Load(const size_t layerNumber, T& t, P& /* unused */)
{
- t.RecurrentParameter() = activations[layerNumber].unsafe_col(
- seqLen - 1 - seqNum);
+ t.RecurrentParameter() = activations[layerNumber].unsafe_col(seqNum);
}
template<typename T, typename P>
@@ -392,8 +407,7 @@ class RNN
!HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
Load(const size_t layerNumber, T& t, P& /* unused */)
{
- t.OutputParameter() = activations[layerNumber].unsafe_col(
- seqLen - 1 - seqNum);
+ t.OutputParameter() = activations[layerNumber].unsafe_col(seqNum);
}
/**
@@ -491,6 +505,7 @@ class RNN
!HasRecurrentParameterCheck<T, P&(T::*)()>::value, void>::type
UpdateRecurrent(T& /* unused */, P& /* unused */, D& /* unused */)
{
+ /* Nothing to do here */
}
/*
@@ -715,9 +730,6 @@ class RNN
//! The index of the current sequence number.
size_t seqNum;
- //! The index of the current sequence target number.
- size_t seqOutputNum;
-
//! Locally stored number of samples in one input sequence.
size_t seqLen;