[mlpack-git] master: Handle the identity connection. (8b2ca72)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sun Jul 12 09:24:00 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0...8b2ca720828224607c70d2b539c43aecf8f4ec32
>---------------------------------------------------------------
commit 8b2ca720828224607c70d2b539c43aecf8f4ec32
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Sat Jul 11 20:20:48 2015 +0200
Handle the identity connection.
>---------------------------------------------------------------
8b2ca720828224607c70d2b539c43aecf8f4ec32
src/mlpack/methods/ann/ffnn.hpp | 22 +++++++++++++++++++---
1 file changed, 19 insertions(+), 3 deletions(-)
diff --git a/src/mlpack/methods/ann/ffnn.hpp b/src/mlpack/methods/ann/ffnn.hpp
index 20a7dea..907280c 100644
--- a/src/mlpack/methods/ann/ffnn.hpp
+++ b/src/mlpack/methods/ann/ffnn.hpp
@@ -64,6 +64,7 @@ class FFNN
const VecType& target,
VecType& error)
{
+ deterministic = false;
seqNum++;
trainError += Evaluate(input, target, error);
}
@@ -106,6 +107,7 @@ class FFNN
template <typename VecType>
void Predict(const VecType& input, VecType& output)
{
+ deterministic = true;
ResetActivations(network);
std::get<0>(std::get<0>(network)).InputLayer().InputActivation() = input;
@@ -126,6 +128,7 @@ class FFNN
template <typename VecType>
double Evaluate(const VecType& input, const VecType& target, VecType& error)
{
+ deterministic = false;
ResetActivations(network);
std::get<0>(std::get<0>(network)).InputLayer().InputActivation() = input;
@@ -172,6 +175,7 @@ class FFNN
typename std::enable_if<I < sizeof...(Tp), void>::type
Reset(std::tuple<Tp...>& t)
{
+ std::get<I>(t).OutputLayer().Deterministic() = deterministic;
std::get<I>(t).OutputLayer().InputActivation().zeros();
Reset<I + 1, Tp...>(t);
}
@@ -360,7 +364,12 @@ class FFNN
typename std::enable_if<I < sizeof...(Tp), void>::type
Gradients(std::tuple<Tp...>& t)
{
- std::get<I>(t).Optimzer().Update();
+ if (!ConnectionTraits<typename std::remove_reference<decltype(
+ std::get<I>(t))>::type>::IsIdentityConnection)
+ {
+ std::get<I>(t).Optimzer().Update();
+ }
+
Gradients<I + 1, Tp...>(t);
}
@@ -400,8 +409,12 @@ class FFNN
typename std::enable_if<I < sizeof...(Tp), void>::type
Apply(std::tuple<Tp...>& t)
{
- std::get<I>(t).Optimzer().Optimize();
- std::get<I>(t).Optimzer().Reset();
+ if (!ConnectionTraits<typename std::remove_reference<decltype(
+ std::get<I>(t))>::type>::IsIdentityConnection)
+ {
+ std::get<I>(t).Optimzer().Optimize();
+ std::get<I>(t).Optimzer().Reset();
+ }
Apply<I + 1, Tp...>(t);
}
@@ -417,6 +430,9 @@ class FFNN
//! The number of the current input sequence.
size_t seqNum;
+
+ //! The current evaluation mode (training or testing).
+ bool deterministic;
}; // class FFNN
More information about the mlpack-git
mailing list