[mlpack-git] master: Set the current evaluation mode. (d712369)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sun Jul 5 08:53:34 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/d9e984e1c608679171ad52e8522916703c7b331f...267bf1f0ace881bea4a38bf1156cc9f503930f09

>---------------------------------------------------------------

commit d712369dc82432f3ffc20443cb3d3a8efbc0b934
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Sun Jul 5 14:44:24 2015 +0200

    Set the current evaluation mode.


>---------------------------------------------------------------

d712369dc82432f3ffc20443cb3d3a8efbc0b934
 src/mlpack/methods/ann/cnn.hpp | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/src/mlpack/methods/ann/cnn.hpp b/src/mlpack/methods/ann/cnn.hpp
index bcd08a8..b5700d1 100644
--- a/src/mlpack/methods/ann/cnn.hpp
+++ b/src/mlpack/methods/ann/cnn.hpp
@@ -65,6 +65,7 @@ class CNN
                      const OutputType& target,
                      ErrorType& error)
     {
+      deterministic = false;
       seqNum++;
       trainError += Evaluate(input, target, error);
     }
@@ -106,6 +107,7 @@ class CNN
     template <typename InputType, typename OutputType>
     void Predict(const InputType& input, OutputType& output)
     {
+      deterministic = true;
       ResetActivations(network);
 
       std::get<0>(std::get<0>(network)).InputLayer().InputActivation() = input;
@@ -128,6 +130,7 @@ class CNN
                     const OutputType& target,
                     ErrorType& error)
     {
+      deterministic = false;
       ResetActivations(network);
 
       std::get<0>(std::get<0>(network)).InputLayer().InputActivation() = input;
@@ -174,6 +177,7 @@ class CNN
     typename std::enable_if<I < sizeof...(Tp), void>::type
     Reset(std::tuple<Tp...>& t)
     {
+      std::get<I>(t).OutputLayer().Deterministic() = deterministic;
       std::get<I>(t).OutputLayer().InputActivation().zeros();
       std::get<I>(t).Delta().zeros();
       Reset<I + 1, Tp...>(t);
@@ -455,6 +459,9 @@ class CNN
 
     //! The number of the current input sequence.
     size_t seqNum;
+
+    //! The current evaluation mode (training or testing).
+    bool deterministic;
 }; // class CNN
 
 



More information about the mlpack-git mailing list