[mlpack-git] master: Refactor LSTM layer for new network API. (d77572e)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Nov 13 12:45:57 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4e83dc9cc4dcdc315d2cceee32b23ebab114c2...7388de71d5398103ee3a0b32b4026902a40a67b3
>---------------------------------------------------------------
commit d77572e5845a3643cc54b6c1a94efad18e144d86
Author: marcus <marcus.edel at fu-berlin.de>
Date: Mon Nov 9 17:42:15 2015 +0100
Refactor LSTM layer for new network API.
>---------------------------------------------------------------
d77572e5845a3643cc54b6c1a94efad18e144d86
src/mlpack/methods/ann/layer/lstm_layer.hpp | 449 ++++++++++++----------------
1 file changed, 194 insertions(+), 255 deletions(-)
diff --git a/src/mlpack/methods/ann/layer/lstm_layer.hpp b/src/mlpack/methods/ann/layer/lstm_layer.hpp
index b452ddc..327976b 100644
--- a/src/mlpack/methods/ann/layer/lstm_layer.hpp
+++ b/src/mlpack/methods/ann/layer/lstm_layer.hpp
@@ -10,10 +10,8 @@
#include <mlpack/core.hpp>
#include <mlpack/methods/ann/layer/layer_traits.hpp>
-#include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
-#include <mlpack/methods/ann/activation_functions/tanh_function.hpp>
-#include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
-#include <mlpack/methods/ann/optimizer/irpropp.hpp>
+#include <mlpack/methods/ann/init_rules/random_init.hpp>
+#include <mlpack/methods/ann/optimizer/rmsprop.hpp>
namespace mlpack {
namespace ann /** Artificial Neural Network. */ {
@@ -25,21 +23,27 @@ namespace ann /** Artificial Neural Network. */ {
* for the gates and cells and also of the type of the function used to
* initialize and update the peephole weights.
*
+ * @tparam OptimizerType Type of the optimizer used to update the weights.
* @tparam GateActivationFunction Activation function used for the gates.
* @tparam StateActivationFunction Activation function used for the state.
* @tparam OutputActivationFunction Activation function used for the output.
* @tparam WeightInitRule Rule used to initialize the weight matrix.
- * @tparam MatType Type of data (arma::mat or arma::sp_mat).
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
+ * @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
+ * arma::sp_mat or arma::cube).
+ * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
+ * arma::sp_mat or arma::cube).
+ * @tparam PeepholeDataType Type of the peephole data (weights, derivatives and
+ * gradients).
*/
template <
+ template<typename, typename> class OptimizerType = mlpack::ann::RMSPROP,
class GateActivationFunction = LogisticFunction,
class StateActivationFunction = TanhFunction,
class OutputActivationFunction = TanhFunction,
- class WeightInitRule = NguyenWidrowInitialization,
- typename OptimizerType = iRPROPp<>,
- typename MatType = arma::mat,
- typename VecType = arma::colvec
+ class WeightInitRule = RandomInitialization,
+ typename InputDataType = arma::mat,
+ typename OutputDataType = arma::mat,
+ typename PeepholeDataType = arma::cube
>
class LSTMLayer
{
@@ -47,117 +51,96 @@ class LSTMLayer
/**
* Create the LSTMLayer object using the specified parameters.
*
- * @param layerSize The number of memory cells.
- * @param layerSize The length of the input sequence.
+ * @param outSize The number of output units.
* @param peepholes The flag used to indicate if peephole connections should
- * be used (Default: true).
- * @param WeightInitRule The weight initialize rule used to initialize the
- * peephole connection matrix.
+ * be used (Default: false).
+ * @param weightInitRule The weight initialization rule used to initialize the
+ * peephole weights (only used when peepholes is true).
*/
- LSTMLayer(const size_t layerSize,
- const size_t seqLen = 1,
+ LSTMLayer(const size_t outSize,
const bool peepholes = false,
WeightInitRule weightInitRule = WeightInitRule()) :
- inputActivations(arma::zeros<VecType>(layerSize * 4)),
- layerSize(layerSize),
- seqLen(seqLen),
- inGate(arma::zeros<MatType>(layerSize, seqLen)),
- inGateAct(arma::zeros<MatType>(layerSize, seqLen)),
- inGateError(arma::zeros<MatType>(layerSize, seqLen)),
- outGate(arma::zeros<MatType>(layerSize, seqLen)),
- outGateAct(arma::zeros<MatType>(layerSize, seqLen)),
- outGateError(arma::zeros<MatType>(layerSize, seqLen)),
- forgetGate(arma::zeros<MatType>(layerSize, seqLen)),
- forgetGateAct(arma::zeros<MatType>(layerSize, seqLen)),
- forgetGateError(arma::zeros<MatType>(layerSize, seqLen)),
- state(arma::zeros<MatType>(layerSize, seqLen)),
- stateError(arma::zeros<MatType>(layerSize, seqLen)),
- cellAct(arma::zeros<MatType>(layerSize, seqLen)),
+ outSize(outSize),
+ peepholes(peepholes),
+ seqLen(1),
offset(0),
- peepholes(peepholes)
+ optimizer(new OptimizerType<LSTMLayer<OptimizerType,
+ GateActivationFunction,
+ StateActivationFunction,
+ OutputActivationFunction,
+ WeightInitRule,
+ InputDataType,
+ OutputDataType,
+ PeepholeDataType>,
+ PeepholeDataType>(*this)),
+ ownsOptimizer(true)
{
if (peepholes)
{
- weightInitRule.Initialize(inGatePeepholeWeights, layerSize, 1);
- inGatePeepholeDerivatives = arma::zeros<VecType>(layerSize);
- inGatePeepholeOptimizer = std::unique_ptr<OptimizerType>(
- new OptimizerType(1, layerSize));
-
- weightInitRule.Initialize(forgetGatePeepholeWeights, layerSize, 1);
- forgetGatePeepholeDerivatives = arma::zeros<VecType>(layerSize);
- forgetGatePeepholeOptimizer = std::unique_ptr<OptimizerType>(
- new OptimizerType(1, layerSize));
-
- weightInitRule.Initialize(outGatePeepholeWeights, layerSize, 1);
- outGatePeepholeDerivatives = arma::zeros<VecType>(layerSize);
- outGatePeepholeOptimizer = std::unique_ptr<OptimizerType>(
- new OptimizerType(1, layerSize));
+ weightInitRule.Initialize(peepholeWeights, outSize, 1, 3);
+ peepholeDerivatives = PeepholeDataType(outSize, 1, 3);
+ peepholeGradient = PeepholeDataType(outSize, 1, 3);
}
}
+ /**
+ * Delete the LSTMLayer object and, if it owns it, its optimizer.
+ */
~LSTMLayer()
{
- OptimizerType* inGatePeepholePtr = inGatePeepholeOptimizer.release();
- delete inGatePeepholePtr;
-
- OptimizerType* forgetGatePeepholePtr = forgetGatePeepholeOptimizer.release();
- delete forgetGatePeepholePtr;
-
- OptimizerType* outGatePeepholePtr = outGatePeepholeOptimizer.release();
- delete outGatePeepholePtr;
+ if (ownsOptimizer)
+ delete optimizer;
}
/**
* Ordinary feed forward pass of a neural network, evaluating the function
* f(x) by propagating the activity forward through f.
*
- * @param inputActivation Input data used for evaluating the specified
- * activity function.
- * @param outputActivation Datatype to store the resulting output activation.
+ * @param input Input data used for evaluating the specified function.
+ * @param output Resulting output activation.
*/
- void FeedForward(const VecType& inputActivation, VecType& outputActivation)
+ template<typename eT>
+ void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output)
{
if (inGate.n_cols < seqLen)
{
- inGate = arma::zeros<MatType>(layerSize, seqLen);
- inGateAct = arma::zeros<MatType>(layerSize, seqLen);
- inGateError = arma::zeros<MatType>(layerSize, seqLen);
- outGate = arma::zeros<MatType>(layerSize, seqLen);
- outGateAct = arma::zeros<MatType>(layerSize, seqLen);
- outGateError = arma::zeros<MatType>(layerSize, seqLen);
- forgetGate = arma::zeros<MatType>(layerSize, seqLen);
- forgetGateAct = arma::zeros<MatType>(layerSize, seqLen);
- forgetGateError = arma::zeros<MatType>(layerSize, seqLen);
- state = arma::zeros<MatType>(layerSize, seqLen);
- stateError = arma::zeros<MatType>(layerSize, seqLen);
- cellAct = arma::zeros<MatType>(layerSize, seqLen);
+ inGate = arma::zeros<InputDataType>(outSize, seqLen);
+ inGateAct = arma::zeros<InputDataType>(outSize, seqLen);
+ inGateError = arma::zeros<InputDataType>(outSize, seqLen);
+ outGate = arma::zeros<InputDataType>(outSize, seqLen);
+ outGateAct = arma::zeros<InputDataType>(outSize, seqLen);
+ outGateError = arma::zeros<InputDataType>(outSize, seqLen);
+ forgetGate = arma::zeros<InputDataType>(outSize, seqLen);
+ forgetGateAct = arma::zeros<InputDataType>(outSize, seqLen);
+ forgetGateError = arma::zeros<InputDataType>(outSize, seqLen);
+ state = arma::zeros<InputDataType>(outSize, seqLen);
+ stateError = arma::zeros<InputDataType>(outSize, seqLen);
+ cellAct = arma::zeros<InputDataType>(outSize, seqLen);
}
// Split up the inputactivation into the 3 parts (inGate, forgetGate,
// outGate).
- inGate.col(offset) = inputActivation.subvec(0, layerSize - 1);
- forgetGate.col(offset) = inputActivation.subvec(
- layerSize, (layerSize * 2) - 1);
- outGate.col(offset) = inputActivation.subvec(
- layerSize * 3, (layerSize * 4) - 1);
+ inGate.col(offset) = input.submat(0, 0, outSize - 1, 0);
+ forgetGate.col(offset) = input.submat(outSize, 0, (outSize * 2) - 1, 0);
+ outGate.col(offset) = input.submat(outSize * 3, 0, (outSize * 4) - 1, 0);
if (peepholes && offset > 0)
{
- inGate.col(offset) += inGatePeepholeWeights % state.col(offset - 1);
- forgetGate.col(offset) += forgetGatePeepholeWeights %
+ inGate.col(offset) += peepholeWeights.slice(0) % state.col(offset - 1);
+ forgetGate.col(offset) += peepholeWeights.slice(1) %
state.col(offset - 1);
}
- VecType inGateActivation = inGateAct.unsafe_col(offset);
+ arma::Col<eT> inGateActivation = inGateAct.unsafe_col(offset);
GateActivationFunction::fn(inGate.unsafe_col(offset), inGateActivation);
- VecType forgetGateActivation = forgetGateAct.unsafe_col(offset);
+ arma::Col<eT> forgetGateActivation = forgetGateAct.unsafe_col(offset);
GateActivationFunction::fn(forgetGate.unsafe_col(offset),
forgetGateActivation);
- VecType cellActivation = cellAct.unsafe_col(offset);
- StateActivationFunction::fn(inputActivation.subvec(layerSize * 2,
- (layerSize * 3) - 1), cellActivation);
+ arma::Col<eT> cellActivation = cellAct.unsafe_col(offset);
+ StateActivationFunction::fn(input.submat(outSize * 2, 0,
+ (outSize * 3) - 1, 0), cellActivation);
state.col(offset) = inGateAct.col(offset) % cellActivation;
@@ -165,13 +148,13 @@ class LSTMLayer
state.col(offset) += forgetGateAct.col(offset) % state.col(offset - 1);
if (peepholes)
- outGate.col(offset) += outGatePeepholeWeights % state.col(offset);
+ outGate.col(offset) += peepholeWeights.slice(2) % state.col(offset);
- VecType outGateActivation = outGateAct.unsafe_col(offset);
+ arma::Col<eT> outGateActivation = outGateAct.unsafe_col(offset);
GateActivationFunction::fn(outGate.unsafe_col(offset), outGateActivation);
- OutputActivationFunction::fn(state.unsafe_col(offset), outputActivation);
- outputActivation = outGateAct.col(offset) % outputActivation;
+ OutputActivationFunction::fn(state.unsafe_col(offset), output);
+ output = outGateAct.col(offset) % output;
offset = (offset + 1) % seqLen;
}
@@ -181,30 +164,30 @@ class LSTMLayer
* f(x) by propagating x backwards trough f. Using the results from the feed
* forward pass.
*
- * @param inputActivation Input data used for calculating the function f(x).
- * @param error The backpropagated error.
- * @param delta The calculating delta using the partial derivative of the
- * error with respect to a weight.
+ * @param input The propagated input activation.
+ * @param gy The backpropagated error.
+ * @param g The calculated gradient.
*/
- void FeedBackward(const VecType& /* unused */,
- const VecType& error,
- VecType& delta)
+ template<typename InputType, typename eT>
+ void Backward(const InputType& /* unused */,
+ const arma::Mat<eT>& gy,
+ arma::Mat<eT>& g)
{
size_t queryOffset = seqLen - offset - 1;
- VecType outGateDerivative;
+ arma::Col<eT> outGateDerivative;
GateActivationFunction::deriv(outGateAct.unsafe_col(queryOffset),
outGateDerivative);
- VecType stateActivation;
+ arma::Col<eT> stateActivation;
StateActivationFunction::fn(state.unsafe_col(queryOffset), stateActivation);
- outGateError.col(queryOffset) = outGateDerivative % error % stateActivation;
+ outGateError.col(queryOffset) = outGateDerivative % gy % stateActivation;
- VecType stateDerivative;
+ arma::Col<eT> stateDerivative;
StateActivationFunction::deriv(stateActivation, stateDerivative);
- stateError.col(queryOffset) = error % outGateAct.col(queryOffset) %
+ stateError.col(queryOffset) = gy % outGateAct.col(queryOffset) %
stateDerivative;
if (queryOffset < (seqLen - 1))
@@ -215,27 +198,27 @@ class LSTMLayer
if (peepholes)
{
stateError.col(queryOffset) += inGateError.col(queryOffset + 1) %
- inGatePeepholeWeights;
+ peepholeWeights.slice(0);
stateError.col(queryOffset) += forgetGateError.col(queryOffset + 1) %
- forgetGatePeepholeWeights;
+ peepholeWeights.slice(1);
}
}
if (peepholes)
{
stateError.col(queryOffset) += outGateError.col(queryOffset) %
- outGatePeepholeWeights;
+ peepholeWeights.slice(2);
}
- VecType cellDerivative;
+ arma::Col<eT> cellDerivative;
StateActivationFunction::deriv(cellAct.col(queryOffset), cellDerivative);
- VecType cellError = inGateAct.col(queryOffset) % cellDerivative %
+ arma::Col<eT> cellError = inGateAct.col(queryOffset) % cellDerivative %
stateError.col(queryOffset);
if (queryOffset > 0)
{
- VecType forgetGateDerivative;
+ arma::Col<eT> forgetGateDerivative;
GateActivationFunction::deriv(forgetGateAct.col(queryOffset),
forgetGateDerivative);
@@ -243,7 +226,7 @@ class LSTMLayer
stateError.col(queryOffset) % state.col(queryOffset - 1);
}
- VecType inGateDerivative;
+ arma::Col<eT> inGateDerivative;
GateActivationFunction::deriv(inGateAct.col(queryOffset), inGateDerivative);
inGateError.col(queryOffset) = inGateDerivative %
@@ -251,232 +234,188 @@ class LSTMLayer
if (peepholes)
{
- outGateDerivative += outGateError.col(queryOffset) %
+ peepholeDerivatives.slice(2) += outGateError.col(queryOffset) %
state.col(queryOffset);
+
if (queryOffset > 0)
{
- inGatePeepholeDerivatives += inGateError.col(queryOffset) %
+ peepholeDerivatives.slice(0) += inGateError.col(queryOffset) %
state.col(queryOffset - 1);
- forgetGatePeepholeDerivatives += forgetGateError.col(queryOffset) %
+ peepholeDerivatives.slice(1) += forgetGateError.col(queryOffset) %
state.col(queryOffset - 1);
}
}
- delta = arma::zeros<VecType>(layerSize * 4);
- delta.subvec(0, layerSize - 1) = inGateError.col(queryOffset);
- delta.subvec(layerSize, (layerSize * 2) - 1) =
+ g = arma::zeros<arma::Mat<eT> >(outSize * 4, 1);
+ g.submat(0, 0, outSize - 1, 0) = inGateError.col(queryOffset);
+ g.submat(outSize, 0, (outSize * 2) - 1, 0) =
forgetGateError.col(queryOffset);
- delta.subvec(layerSize * 2, (layerSize * 3) - 1) = cellError;
- delta.subvec(layerSize * 3, (layerSize * 4) - 1) =
+ g.submat(outSize * 2, 0, (outSize * 3) - 1, 0) = cellError;
+ g.submat(outSize * 3, 0, (outSize * 4) - 1, 0) =
outGateError.col(queryOffset);
offset = (offset + 1) % seqLen;
if (peepholes && offset == 0)
{
- inGatePeepholeGradient = (inGatePeepholeWeights.t() *
- (inGateError.col(queryOffset) % inGatePeepholeDerivatives)) *
- inGate.col(queryOffset).t();
-
- forgetGatePeepholeGradient = (forgetGatePeepholeWeights.t() *
- (forgetGateError.col(queryOffset) % forgetGatePeepholeDerivatives)) *
- forgetGate.col(queryOffset).t();
-
- outGatePeepholeGradient = (outGatePeepholeWeights.t() *
- (outGateError.col(queryOffset) % outGatePeepholeDerivatives)) *
- outGate.col(queryOffset).t();
-
- inGatePeepholeOptimizer->UpdateWeights(inGatePeepholeWeights,
- inGatePeepholeGradient.t(), 0);
-
- forgetGatePeepholeOptimizer->UpdateWeights(forgetGatePeepholeWeights,
- forgetGatePeepholeGradient.t(), 0);
-
- outGatePeepholeOptimizer->UpdateWeights(outGatePeepholeWeights,
- outGatePeepholeGradient.t(), 0);
-
- inGatePeepholeDerivatives.zeros();
- forgetGatePeepholeDerivatives.zeros();
- outGatePeepholeDerivatives.zeros();
+ peepholeGradient.slice(0) = arma::trans((peepholeWeights.slice(0).t() *
+ (inGateError.col(queryOffset) % peepholeDerivatives.slice(0))) *
+ inGate.col(queryOffset).t());
+
+ peepholeGradient.slice(1) = arma::trans((peepholeWeights.slice(1).t() *
+ (forgetGateError.col(queryOffset) % peepholeDerivatives.slice(1))) *
+ forgetGate.col(queryOffset).t());
+
+ peepholeGradient.slice(2) = arma::trans((peepholeWeights.slice(2).t() *
+ (outGateError.col(queryOffset) % peepholeDerivatives.slice(2))) *
+ outGate.col(queryOffset).t());
+
+ optimizer->Update();
+ optimizer->Optimize();
+ optimizer->Reset();
+ peepholeDerivatives.zeros();
}
}
- //! Get the input activations.
- const VecType& InputActivation() const { return inputActivations; }
- //! Modify the input activations.
- VecType& InputActivation() { return inputActivations; }
-
- //! Get input size.
- size_t InputSize() const { return layerSize * 4; }
+ //! Get the peephole weights.
+ PeepholeDataType& Weights() const { return peepholeWeights; }
+ //! Modify the peephole weights.
+ PeepholeDataType& Weights() { return peepholeWeights; }
- //! Get output size.
- size_t OutputSize() const { return layerSize; }
- //! Modify the output size.
- size_t& OutputSize() { return layerSize; }
+ //! Get the input parameter.
+ InputDataType& InputParameter() const {return inputParameter; }
+ //! Modify the input parameter.
+ InputDataType& InputParameter() { return inputParameter; }
- //! Get the number of output maps.
- size_t OutputMaps() const { return 1; }
+ //! Get the output parameter.
+ OutputDataType& OutputParameter() const {return outputParameter; }
+ //! Modify the output parameter.
+ OutputDataType& OutputParameter() { return outputParameter; }
- //! Get the number of layer slices.
- size_t LayerSlices() const { return 1; }
-
- //! Get the number of layer rows.
- size_t LayerRows() const { return layerSize; }
-
- //! Get the number of layer columns.
- size_t LayerCols() const { return 1; }
-
- //! Get the detla.
- VecType& Delta() const { return delta; }
+ //! Get the delta.
+ OutputDataType& Delta() const {return delta; }
//! Modify the delta.
- VecType& Delta() { return delta; }
+ OutputDataType& Delta() { return delta; }
+
+ //! Get the peephole gradient.
+ PeepholeDataType& Gradient() const {return peepholeGradient; }
+ //! Modify the peephole gradient.
+ PeepholeDataType& Gradient() { return peepholeGradient; }
//! Get the sequence length.
size_t SeqLen() const { return seqLen; }
//! Modify the sequence length.
size_t& SeqLen() { return seqLen; }
- //! Get the InGate peephole weights..
- MatType& InGatePeepholeWeights() const { return inGatePeepholeWeights; }
- //! Modify the InGate peephole weights..
- MatType& InGatePeepholeWeights() { return inGatePeepholeWeights; }
+ private:
+ //! Locally-stored number of output units.
+ const size_t outSize;
- //! Get the InGate peephole weights..
- MatType& ForgetGatePeepholeWeights() const {
- return forgetGatePeepholeWeights; }
- //! Modify the InGate peephole weights..
- MatType& ForgetGatePeepholeWeights() { return forgetGatePeepholeWeights; }
+ //! Locally-stored peephole indication flag.
+ const bool peepholes;
- //! Get the InGate peephole weights..
- MatType& OutGatePeepholeWeights() const { return outGatePeepholeWeights; }
- //! Modify the InGate peephole weights..
- MatType& OutGatePeepholeWeights() { return outGatePeepholeWeights; }
+ //! Locally-stored length of the input sequence.
+ size_t seqLen;
- //! The the value of the deterministic parameter.
- bool Deterministic() const {return deterministic; }
- //! Modify the value of the deterministic parameter.
- bool& Deterministic() {return deterministic; }
+ //! Locally-stored sequence offset.
+ size_t offset;
- private:
- //! Locally-stored input activation object.
- VecType inputActivations;
+ //! Locally-stored pointer to the optimizer object.
+ OptimizerType<LSTMLayer<OptimizerType,
+ GateActivationFunction,
+ StateActivationFunction,
+ OutputActivationFunction,
+ WeightInitRule,
+ InputDataType,
+ OutputDataType,
+ PeepholeDataType>, PeepholeDataType>* optimizer;
+
+ //! Parameter that indicates if the class owns an optimizer object.
+ bool ownsOptimizer;
//! Locally-stored delta object.
- VecType delta;
+ OutputDataType delta;
- //! Locally-stored number of memory cells.
- size_t layerSize;
+ //! Locally-stored gradient object.
+ OutputDataType gradient;
- //! Locally-stored length of the the input sequence.
- size_t seqLen;
+ //! Locally-stored input parameter object.
+ InputDataType inputParameter;
+
+ //! Locally-stored output parameter object.
+ OutputDataType outputParameter;
//! Locally-stored ingate object.
- MatType inGate;
+ InputDataType inGate;
//! Locally-stored ingate activation object.
- MatType inGateAct;
+ InputDataType inGateAct;
//! Locally-stored ingate error object.
- MatType inGateError;
+ InputDataType inGateError;
//! Locally-stored outgate object.
- MatType outGate;
+ InputDataType outGate;
//! Locally-stored outgate activation object.
- MatType outGateAct;
+ InputDataType outGateAct;
//! Locally-stored outgate error object.
- MatType outGateError;
+ InputDataType outGateError;
//! Locally-stored forget object.
- MatType forgetGate;
+ InputDataType forgetGate;
//! Locally-stored forget activation object.
- MatType forgetGateAct;
+ InputDataType forgetGateAct;
//! Locally-stored forget error object.
- MatType forgetGateError;
+ InputDataType forgetGateError;
//! Locally-stored state object.
- MatType state;
+ InputDataType state;
//! Locally-stored state erro object.
- MatType stateError;
+ InputDataType stateError;
//! Locally-stored cell activation object.
- MatType cellAct;
-
- //! Locally-stored sequence offset.
- size_t offset;
-
- //! Locally-stored peephole indication flag.
- const bool peepholes;
+ InputDataType cellAct;
- //! Locally-stored peephole ingate weights.
- MatType inGatePeepholeWeights;
+ //! Locally-stored peephole weight object.
+ PeepholeDataType peepholeWeights;
- //! Locally-stored peephole ingate derivatives.
- VecType inGatePeepholeDerivatives;
+ //! Locally-stored derivatives object.
+ PeepholeDataType peepholeDerivatives;
- //! Locally-stored peephole ingate gradients.
- MatType inGatePeepholeGradient;
-
- //! Locally-stored ingate peephole optimzer object.
- std::unique_ptr<OptimizerType> inGatePeepholeOptimizer;
-
- //! Locally-stored peephole forget weights.
- MatType forgetGatePeepholeWeights;
-
- //! Locally-stored peephole forget derivatives.
- VecType forgetGatePeepholeDerivatives;
-
- //! Locally-stored peephole forget gradients.
- MatType forgetGatePeepholeGradient;
-
- //! Locally-stored forget peephole optimzer object.
- std::unique_ptr<OptimizerType> forgetGatePeepholeOptimizer;
-
- //! Locally-stored peephole outgate weights.
- MatType outGatePeepholeWeights;
-
- //! Locally-stored peephole outgate derivatives.
- VecType outGatePeepholeDerivatives;
-
- //! Locally-stored peephole outgate gradients.
- MatType outGatePeepholeGradient;
-
- //! Locally-stored outgate peephole optimzer object.
- std::unique_ptr<OptimizerType> outGatePeepholeOptimizer;
-
- //! Locally-stored deterministic parameter.
- bool deterministic;
+ //! Locally-stored peephole gradient object.
+ PeepholeDataType peepholeGradient;
}; // class LSTMLayer
-//! Layer traits for the bias layer.
+//! Layer traits for the LSTM layer (lists all eight LSTMLayer parameters).
template<
+ template<typename, typename> class OptimizerType,
class GateActivationFunction,
class StateActivationFunction,
class OutputActivationFunction,
class WeightInitRule,
- typename OptimizerType,
- typename MatType,
- typename VecType
->
-class LayerTraits<
- LSTMLayer<GateActivationFunction,
- StateActivationFunction,
- OutputActivationFunction,
- WeightInitRule,
- OptimizerType,
- MatType,
- VecType>
+ typename InputDataType,
+ typename OutputDataType,
+ typename PeepholeDataType
>
+class LayerTraits<LSTMLayer<OptimizerType,
+ GateActivationFunction,
+ StateActivationFunction,
+ OutputActivationFunction,
+ WeightInitRule,
+ InputDataType, OutputDataType,
+ PeepholeDataType> >
{
public:
static const bool IsBinary = false;
static const bool IsOutputLayer = false;
static const bool IsBiasLayer = false;
static const bool IsLSTMLayer = true;
+ static const bool IsConnection = false;
};
}; // namespace ann
More information about the mlpack-git
mailing list