[mlpack-git] master: Use the correct number of epochs. (b6d4a7b)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Feb 27 15:51:52 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/594fd9f61d1280152c758559b4fc60bf0c827cca...45f682337b1daa4c82797f950e16a605fe4971bd

>---------------------------------------------------------------

commit b6d4a7bfaa1f15a10959fad18c99149ec2ce98f6
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Fri Feb 27 21:10:39 2015 +0100

    Use the correct number of epochs.


>---------------------------------------------------------------

b6d4a7bfaa1f15a10959fad18c99149ec2ce98f6
 src/mlpack/methods/ann/trainer/trainer.hpp | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/src/mlpack/methods/ann/trainer/trainer.hpp b/src/mlpack/methods/ann/trainer/trainer.hpp
index f5b42a3..01c1301 100644
--- a/src/mlpack/methods/ann/trainer/trainer.hpp
+++ b/src/mlpack/methods/ann/trainer/trainer.hpp
@@ -87,8 +87,6 @@ class Trainer
 
       while(true)
       {
-
-        // Randomly shuffle the index sequence if not in batch mode.
         if (shuffle)
           index = arma::shuffle(index);
 
@@ -98,7 +96,7 @@ class Trainer
         if (validationError <= tolerance)
           break;
 
-        if (maxEpochs > 0 && ++epoch > maxEpochs)
+        if (maxEpochs > 0 && ++epoch >= maxEpochs)
           break;
       }
     }
@@ -145,8 +143,8 @@ class Trainer
       {
         net.FeedForward(data.unsafe_col(index(i)),
             target.unsafe_col(index(i)), error);
-        trainingError += net.Error();
 
+        trainingError += net.Error();
         net.FeedBackward(error);
 
         if (((i + 1) % batchSize) == 0)
@@ -172,8 +170,8 @@ class Trainer
 
       for (size_t i = 0; i < data.n_cols; i++)
       {
-        net.FeedForward(data.unsafe_col(i), target.unsafe_col(i), error);
-        validationError += net.Error();
+         validationError += net.Evaluate(data.unsafe_col(i),
+            target.unsafe_col(i), error);
       }
 
       validationError /= data.n_cols;



More information about the mlpack-git mailing list