[mlpack-git] master: Adjust the rnn class and LSTM layer to handle sequences of different lengths. (757c92a)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sat Mar 7 08:10:13 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/3e4a3c8ba42e113e0cdebd73bbfa1f6dea9d7010...757c92a1596ef28f5bc924fbec031fb24b98c781
>---------------------------------------------------------------
commit 757c92a1596ef28f5bc924fbec031fb24b98c781
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Sat Mar 7 14:10:06 2015 +0100
Adjust the rnn class and LSTM layer to handle sequences of different lengths.
>---------------------------------------------------------------
757c92a1596ef28f5bc924fbec031fb24b98c781
src/mlpack/core/util/arma_traits.hpp | 0
src/mlpack/methods/ann/layer/bias_layer.hpp | 5 +-
.../ann/layer/binary_classification_layer.hpp | 6 +--
src/mlpack/methods/ann/layer/layer_traits.hpp | 5 ++
src/mlpack/methods/ann/layer/lstm_layer.hpp | 56 +++++++++++++++++++++-
src/mlpack/methods/ann/rnn.hpp | 34 ++++++++++++-
6 files changed, 96 insertions(+), 10 deletions(-)
diff --git a/src/mlpack/methods/ann/layer/bias_layer.hpp b/src/mlpack/methods/ann/layer/bias_layer.hpp
index dde95dd..d982b65 100644
--- a/src/mlpack/methods/ann/layer/bias_layer.hpp
+++ b/src/mlpack/methods/ann/layer/bias_layer.hpp
@@ -110,8 +110,8 @@ class BiasLayer
}; // class BiasLayer
//! Layer traits for the bias layer.
-template<>
-class LayerTraits<BiasLayer<> >
+template<typename ActivationFunction, typename MatType, typename VecType>
+class LayerTraits<BiasLayer<ActivationFunction, MatType, VecType> >
{
public:
/**
@@ -120,6 +120,7 @@ class LayerTraits<BiasLayer<> >
static const bool IsBinary = false;
static const bool IsOutputLayer = false;
static const bool IsBiasLayer = true;
+ static const bool IsLSTMLayer = false;
};
}; // namespace ann
diff --git a/src/mlpack/methods/ann/layer/binary_classification_layer.hpp b/src/mlpack/methods/ann/layer/binary_classification_layer.hpp
index 109895d..ecd6064 100644
--- a/src/mlpack/methods/ann/layer/binary_classification_layer.hpp
+++ b/src/mlpack/methods/ann/layer/binary_classification_layer.hpp
@@ -66,16 +66,14 @@ class BinaryClassificationLayer
}; // class BinaryClassificationLayer
//! Layer traits for the binary class classification layer.
-template <
- typename MatType,
- typename VecType
->
+template <typename MatType, typename VecType>
class LayerTraits<BinaryClassificationLayer<MatType, VecType> >
{
public:
static const bool IsBinary = true;
static const bool IsOutputLayer = true;
static const bool IsBiasLayer = false;
+ static const bool IsLSTMLayer = false;
};
}; // namespace ann
diff --git a/src/mlpack/methods/ann/layer/layer_traits.hpp b/src/mlpack/methods/ann/layer/layer_traits.hpp
index b414b05..52ee1af 100644
--- a/src/mlpack/methods/ann/layer/layer_traits.hpp
+++ b/src/mlpack/methods/ann/layer/layer_traits.hpp
@@ -36,6 +36,11 @@ class LayerTraits
* This is true if the layer is a bias layer.
*/
static const bool IsBiasLayer = false;
+
+ /**
+ * This is true if the layer is an LSTM layer.
+ */
+ static const bool IsLSTMLayer = false;
};
}; // namespace ann
diff --git a/src/mlpack/methods/ann/layer/lstm_layer.hpp b/src/mlpack/methods/ann/layer/lstm_layer.hpp
index 02d5c42..bc8dfaf 100644
--- a/src/mlpack/methods/ann/layer/lstm_layer.hpp
+++ b/src/mlpack/methods/ann/layer/lstm_layer.hpp
@@ -55,8 +55,8 @@ class LSTMLayer
* peephole connection matrix.
*/
LSTMLayer(const size_t layerSize,
- const size_t seqLen,
- const bool peepholes = true,
+ const size_t seqLen = 1,
+ const bool peepholes = false,
WeightInitRule weightInitRule = WeightInitRule()) :
inputActivations(arma::zeros<VecType>(layerSize * 4)),
layerSize(layerSize),
@@ -120,6 +120,22 @@ class LSTMLayer
*/
void FeedForward(const VecType& inputActivation, VecType& outputActivation)
{
+ if (inGate.n_cols < seqLen)
+ {
+ inGate = arma::zeros<MatType>(layerSize, seqLen);
+ inGateAct = arma::zeros<MatType>(layerSize, seqLen);
+ inGateError = arma::zeros<MatType>(layerSize, seqLen);
+ outGate = arma::zeros<MatType>(layerSize, seqLen);
+ outGateAct = arma::zeros<MatType>(layerSize, seqLen);
+ outGateError = arma::zeros<MatType>(layerSize, seqLen);
+ forgetGate = arma::zeros<MatType>(layerSize, seqLen);
+ forgetGateAct = arma::zeros<MatType>(layerSize, seqLen);
+ forgetGateError = arma::zeros<MatType>(layerSize, seqLen);
+ state = arma::zeros<MatType>(layerSize, seqLen);
+ stateError = arma::zeros<MatType>(layerSize, seqLen);
+ cellAct = arma::zeros<MatType>(layerSize, seqLen);
+ }
+
// Split up the inputactivation into the 3 parts (inGate, forgetGate,
// outGate).
inGate.col(offset) = inputActivation.subvec(0, layerSize - 1);
@@ -296,6 +312,12 @@ class LSTMLayer
VecType& Delta() const { return delta; }
// //! Modify the delta.
VecType& Delta() { return delta; }
+
+ //! Get the sequence length.
+ size_t SeqLen() const { return seqLen; }
+ //! Modify the sequence length.
+ size_t& SeqLen() { return seqLen; }
+
private:
//! Locally-stored input activation object.
VecType inputActivations;
@@ -388,6 +410,36 @@ class LSTMLayer
std::auto_ptr<OptimizerType> outGatePeepholeOptimizer;
}; // class LSTMLayer
+//! Layer traits for the LSTM layer.
+template<
+ class GateActivationFunction,
+ class StateActivationFunction,
+ class OutputActivationFunction,
+ class WeightInitRule,
+ typename OptimizerType,
+ typename MatType,
+ typename VecType
+>
+class LayerTraits<
+ LSTMLayer<GateActivationFunction,
+ StateActivationFunction,
+ OutputActivationFunction,
+ WeightInitRule,
+ OptimizerType,
+ MatType,
+ VecType>
+>
+{
+ public:
+ /**
+ * If true, then the layer is binary.
+ */
+ static const bool IsBinary = false;
+ static const bool IsOutputLayer = false;
+ static const bool IsBiasLayer = false;
+ static const bool IsLSTMLayer = true;
+};
+
}; // namespace ann
}; // namespace mlpack
diff --git a/src/mlpack/methods/ann/rnn.hpp b/src/mlpack/methods/ann/rnn.hpp
index a8dcc6d..7d4c16f 100644
--- a/src/mlpack/methods/ann/rnn.hpp
+++ b/src/mlpack/methods/ann/rnn.hpp
@@ -284,6 +284,9 @@ class RNN
typename std::enable_if<I < sizeof...(Tp), void>::type
Reset(std::tuple<Tp...>& t)
{
+ Parameter<I, typename std::remove_reference<
+ decltype(std::get<I>(t).InputLayer())>::type, Tp...>(t);
+
std::get<I>(t).OutputLayer().InputActivation().zeros(
std::get<I>(t).OutputLayer().InputSize());
@@ -301,6 +304,31 @@ class RNN
}
/**
+ * Update the sequence length for a specific layer.
+ *
+ * enable_if (SFINAE) is used to determine if classes passed contains the
+ * SeqLen function.
+ */
+ template<size_t I, typename LayerType, typename... Tp>
+ typename std::enable_if<
+ LayerTraits<LayerType>::IsLSTMLayer == false, void>::type
+ Parameter(std::tuple<Tp...>& /* unused */) { }
+
+ /**
+ * Update the sequence length for a specific layer.
+ *
+ * enable_if (SFINAE) is used to determine if classes passed contains the
+ * SeqLen function.
+ */
+ template<size_t I, typename LayerType, typename... Tp>
+ typename std::enable_if<
+ LayerTraits<LayerType>::IsLSTMLayer == true, void>::type
+ Parameter(std::tuple<Tp...>& t)
+ {
+ std::get<I>(t).InputLayer().SeqLen() = seqLen;
+ }
+
+ /**
* Run a single iteration of the feed forward algorithm, using the given
* input and target vector, updating the resulting error into the error
* vector.
@@ -448,7 +476,9 @@ class RNN
// Update the recurrent delta.
if (ConnectionTraits<typename std::remove_reference<decltype(
- std::get<I>(t))>::type>::IsSelfConnection)
+ std::get<I>(t))>::type>::IsSelfConnection ||
+ ConnectionTraits<typename std::remove_reference<decltype(
+ std::get<I>(t))>::type>::IsFullselfConnection)
{
std::get<I>(t).FeedBackward(delta[deltaNum]);
delta[deltaNum++] = std::get<I>(t).Delta();
@@ -464,7 +494,7 @@ class RNN
{
// Sum up the stored delta for recurrent connections.
if (recurrentLayer[layer])
- std::get<I>(t).Delta() += delta[deltaNum];
+ std::get<I>(t).Delta() += delta[deltaNum].subvec(0, std::get<I>(t).InputLayer().OutputSize() - 1);
// Perform the backward pass.
std::get<I>(t).InputLayer().FeedBackward(
More information about the mlpack-git
mailing list