[mlpack-git] master: Make sure the gradient isn't empty before propagating the gradient parameter through the network. (0e8d776)
gitdub at mlpack.org
Sat Apr 9 07:36:50 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f4b3464fce6bdc7c61d94f6b22bc71fe61276328...0e8d776e03b8dbe8e605063b388115cb22b1860d
>---------------------------------------------------------------
commit 0e8d776e03b8dbe8e605063b388115cb22b1860d
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Sat Apr 9 13:36:50 2016 +0200
Make sure the gradient isn't empty before propagating the gradient parameter through the network.
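The guard described in the commit message can be sketched outside mlpack as shown below. This is a minimal, hypothetical illustration. `Matrix` is a stand-in for `arma::mat` that keeps only `n_rows`, `n_cols`, `is_empty()` and a sizing `zeros()`. `Gradient()` mirrors only the shape check added to `ffn_impl.hpp`. The placeholder fill takes the place of `Evaluate()`/`NetworkGradients()`, which are not reproduced here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical minimal dense matrix standing in for arma::mat; only the
// members needed for this sketch are provided.
struct Matrix
{
  std::size_t n_rows = 0;
  std::size_t n_cols = 0;
  std::vector<double> data;

  bool is_empty() const { return data.empty(); }

  // Resize to rows x cols and fill every element with zero.
  void zeros(std::size_t rows, std::size_t cols)
  {
    n_rows = rows;
    n_cols = cols;
    data.assign(rows * cols, 0.0);
  }
};

// Hypothetical stand-in for the FFN Gradient() entry point: if the caller
// passed an empty matrix, allocate a zero matrix with the parameter's shape
// before anything writes into it.
void Gradient(const Matrix& parameter, Matrix& gradient)
{
  if (gradient.is_empty())
    gradient.zeros(parameter.n_rows, parameter.n_cols);

  // This sketch assumes a non-empty gradient already has the right size.
  assert(gradient.data.size() == parameter.n_rows * parameter.n_cols);

  // Placeholder for the real gradient computation: write into every slot,
  // which would be out of bounds had the buffer not been sized above.
  for (double& g : gradient.data)
    g = 1.0;
}
```

With the guard, callers such as optimizers can pass a default-constructed matrix and have it sized automatically, instead of having to know the parameter shape up front.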
>---------------------------------------------------------------
0e8d776e03b8dbe8e605063b388115cb22b1860d
src/mlpack/methods/ann/ffn_impl.hpp | 6 ++++++
src/mlpack/methods/ann/rnn_impl.hpp | 10 +++++++++-
2 files changed, 15 insertions(+), 1 deletion(-)
diff --git a/src/mlpack/methods/ann/ffn_impl.hpp b/src/mlpack/methods/ann/ffn_impl.hpp
index ca48c30..f9f3c30 100644
--- a/src/mlpack/methods/ann/ffn_impl.hpp
+++ b/src/mlpack/methods/ann/ffn_impl.hpp
@@ -252,6 +252,12 @@ LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
const size_t i,
arma::mat& gradient)
{
+ if (gradient.is_empty())
+ {
+ gradient = arma::zeros<arma::mat>(parameter.n_rows, parameter.n_cols);
+ }
+
+
Evaluate(parameter, i, false);
NetworkGradients(gradient, network);
diff --git a/src/mlpack/methods/ann/rnn_impl.hpp b/src/mlpack/methods/ann/rnn_impl.hpp
index f5b519a..7c04eb7 100644
--- a/src/mlpack/methods/ann/rnn_impl.hpp
+++ b/src/mlpack/methods/ann/rnn_impl.hpp
@@ -281,9 +281,17 @@ LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
const size_t i,
arma::mat& gradient)
{
+ if (gradient.is_empty())
+ {
+ gradient = arma::zeros<arma::mat>(parameter.n_rows, parameter.n_cols);
+ }
+ else
+ {
+ gradient.zeros();
+ }
+
Evaluate(parameter, i, false);
- gradient.zeros();
arma::mat currentGradient = arma::mat(gradient.n_rows, gradient.n_cols);
NetworkGradients(currentGradient, network);
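The RNN hunk differs from the FFN one. It either allocates a zero gradient or clears the existing one, then computes a separate `currentGradient`. The lines after the hunk are not shown here. Assuming the per-time-step gradients are summed into `gradient`, the `else` branch matters: without it, a reused buffer would keep values from the previous call. Below is a hypothetical, std-only sketch of that pattern. `StepGradient()` and `AccumulateGradient()` are illustrative names, not mlpack functions, and the gradient is flattened into a `std::vector<double>`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical per-step gradient: in a real RNN this would come from
// back-propagation through one time step; here it is a constant so the
// effect of accumulation is easy to check.
std::vector<double> StepGradient(std::size_t size, std::size_t step)
{
  return std::vector<double>(size, static_cast<double>(step + 1));
}

// Zero-or-allocate guard followed by accumulation over `steps` time steps.
// `numParams` plays the role of parameter.n_rows * parameter.n_cols
// (an assumption of this sketch).
void AccumulateGradient(std::size_t numParams, std::size_t steps,
                        std::vector<double>& gradient)
{
  if (gradient.empty())
    gradient.assign(numParams, 0.0);                    // allocate as zeros
  else
    std::fill(gradient.begin(), gradient.end(), 0.0);   // clear stale values

  for (std::size_t t = 0; t < steps; ++t)
  {
    const std::vector<double> current = StepGradient(gradient.size(), t);
    assert(current.size() == gradient.size());
    for (std::size_t j = 0; j < gradient.size(); ++j)
      gradient[j] += current[j];
  }
}
```

With three steps each element ends up as 1 + 2 + 3 = 6. Calling the function again on the same buffer gives 6 again rather than 12, because the existing values are cleared before accumulation starts.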