[mlpack-git] master: Set the current evaluation mode. (d712369)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sun Jul 5 08:53:34 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/d9e984e1c608679171ad52e8522916703c7b331f...267bf1f0ace881bea4a38bf1156cc9f503930f09
>---------------------------------------------------------------
commit d712369dc82432f3ffc20443cb3d3a8efbc0b934
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Sun Jul 5 14:44:24 2015 +0200
Set the current evaluation mode.
>---------------------------------------------------------------
d712369dc82432f3ffc20443cb3d3a8efbc0b934
src/mlpack/methods/ann/cnn.hpp | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/src/mlpack/methods/ann/cnn.hpp b/src/mlpack/methods/ann/cnn.hpp
index bcd08a8..b5700d1 100644
--- a/src/mlpack/methods/ann/cnn.hpp
+++ b/src/mlpack/methods/ann/cnn.hpp
@@ -65,6 +65,7 @@ class CNN
const OutputType& target,
ErrorType& error)
{
+ deterministic = false;
seqNum++;
trainError += Evaluate(input, target, error);
}
@@ -106,6 +107,7 @@ class CNN
template <typename InputType, typename OutputType>
void Predict(const InputType& input, OutputType& output)
{
+ deterministic = true;
ResetActivations(network);
std::get<0>(std::get<0>(network)).InputLayer().InputActivation() = input;
@@ -128,6 +130,7 @@ class CNN
const OutputType& target,
ErrorType& error)
{
+ deterministic = false;
ResetActivations(network);
std::get<0>(std::get<0>(network)).InputLayer().InputActivation() = input;
@@ -174,6 +177,7 @@ class CNN
typename std::enable_if<I < sizeof...(Tp), void>::type
Reset(std::tuple<Tp...>& t)
{
+ std::get<I>(t).OutputLayer().Deterministic() = deterministic;
std::get<I>(t).OutputLayer().InputActivation().zeros();
std::get<I>(t).Delta().zeros();
Reset<I + 1, Tp...>(t);
@@ -455,6 +459,9 @@ class CNN
//! The number of the current input sequence.
size_t seqNum;
+
+ //! The current evaluation mode (training or testing).
+ bool deterministic;
}; // class CNN
More information about the mlpack-git mailing list