[mlpack-git] master: Make sure the gradient isn't empty before propagating the gradient parameter through the network. (0e8d776)

gitdub at mlpack.org gitdub at mlpack.org
Sat Apr 9 07:36:50 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f4b3464fce6bdc7c61d94f6b22bc71fe61276328...0e8d776e03b8dbe8e605063b388115cb22b1860d

>---------------------------------------------------------------

commit 0e8d776e03b8dbe8e605063b388115cb22b1860d
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Sat Apr 9 13:36:50 2016 +0200

    Make sure the gradient isn't empty before propagating the gradient parameter through the network.


>---------------------------------------------------------------

0e8d776e03b8dbe8e605063b388115cb22b1860d
 src/mlpack/methods/ann/ffn_impl.hpp |  6 ++++++
 src/mlpack/methods/ann/rnn_impl.hpp | 10 +++++++++-
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/src/mlpack/methods/ann/ffn_impl.hpp b/src/mlpack/methods/ann/ffn_impl.hpp
index ca48c30..f9f3c30 100644
--- a/src/mlpack/methods/ann/ffn_impl.hpp
+++ b/src/mlpack/methods/ann/ffn_impl.hpp
@@ -252,6 +252,12 @@ LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
             const size_t i,
             arma::mat& gradient)
 {
+  if (gradient.is_empty())
+  {
+    gradient = arma::zeros<arma::mat>(parameter.n_rows, parameter.n_cols);
+  }
+
+
   Evaluate(parameter, i, false);
 
   NetworkGradients(gradient, network);
diff --git a/src/mlpack/methods/ann/rnn_impl.hpp b/src/mlpack/methods/ann/rnn_impl.hpp
index f5b519a..7c04eb7 100644
--- a/src/mlpack/methods/ann/rnn_impl.hpp
+++ b/src/mlpack/methods/ann/rnn_impl.hpp
@@ -281,9 +281,17 @@ LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
             const size_t i,
             arma::mat& gradient)
 {
+  if (gradient.is_empty())
+  {
+    gradient = arma::zeros<arma::mat>(parameter.n_rows, parameter.n_cols);
+  }
+  else
+  {
+    gradient.zeros();
+  }
+
   Evaluate(parameter, i, false);
 
-  gradient.zeros();
   arma::mat currentGradient = arma::mat(gradient.n_rows, gradient.n_cols);
   NetworkGradients(currentGradient, network);
 




More information about the mlpack-git mailing list