[mlpack-git] master: Handle the identity connection. (8b2ca72)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sun Jul 12 09:24:00 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0...8b2ca720828224607c70d2b539c43aecf8f4ec32

>---------------------------------------------------------------

commit 8b2ca720828224607c70d2b539c43aecf8f4ec32
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Sat Jul 11 20:20:48 2015 +0200

    Handle the identity connection.


>---------------------------------------------------------------

8b2ca720828224607c70d2b539c43aecf8f4ec32
 src/mlpack/methods/ann/ffnn.hpp | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/src/mlpack/methods/ann/ffnn.hpp b/src/mlpack/methods/ann/ffnn.hpp
index 20a7dea..907280c 100644
--- a/src/mlpack/methods/ann/ffnn.hpp
+++ b/src/mlpack/methods/ann/ffnn.hpp
@@ -64,6 +64,7 @@ class FFNN
                      const VecType& target,
                      VecType& error)
     {
+      deterministic = false;
       seqNum++;
       trainError += Evaluate(input, target, error);
     }
@@ -106,6 +107,7 @@ class FFNN
     template <typename VecType>
     void Predict(const VecType& input, VecType& output)
     {
+      deterministic = true;
       ResetActivations(network);
 
       std::get<0>(std::get<0>(network)).InputLayer().InputActivation() = input;
@@ -126,6 +128,7 @@ class FFNN
     template <typename VecType>
     double Evaluate(const VecType& input, const VecType& target, VecType& error)
     {
+      deterministic = false;
       ResetActivations(network);
 
       std::get<0>(std::get<0>(network)).InputLayer().InputActivation() = input;
@@ -172,6 +175,7 @@ class FFNN
     typename std::enable_if<I < sizeof...(Tp), void>::type
     Reset(std::tuple<Tp...>& t)
     {
+      std::get<I>(t).OutputLayer().Deterministic() = deterministic;
       std::get<I>(t).OutputLayer().InputActivation().zeros();
       Reset<I + 1, Tp...>(t);
     }
@@ -360,7 +364,12 @@ class FFNN
     typename std::enable_if<I < sizeof...(Tp), void>::type
     Gradients(std::tuple<Tp...>& t)
     {
-      std::get<I>(t).Optimzer().Update();
+      if (!ConnectionTraits<typename std::remove_reference<decltype(
+          std::get<I>(t))>::type>::IsIdentityConnection)
+      {
+        std::get<I>(t).Optimzer().Update();
+      }
+
       Gradients<I + 1, Tp...>(t);
     }
 
@@ -400,8 +409,12 @@ class FFNN
     typename std::enable_if<I < sizeof...(Tp), void>::type
     Apply(std::tuple<Tp...>& t)
     {
-      std::get<I>(t).Optimzer().Optimize();
-      std::get<I>(t).Optimzer().Reset();
+      if (!ConnectionTraits<typename std::remove_reference<decltype(
+          std::get<I>(t))>::type>::IsIdentityConnection)
+      {
+        std::get<I>(t).Optimzer().Optimize();
+        std::get<I>(t).Optimzer().Reset();
+      }
 
       Apply<I + 1, Tp...>(t);
     }
@@ -417,6 +430,9 @@ class FFNN
 
     //! The number of the current input sequence.
     size_t seqNum;
+
+    //! The current evaluation mode (training or testing).
+    bool deterministic;
 }; // class FFNN
 
 



More information about the mlpack-git mailing list