[mlpack-git] master: Use the correct number of epochs. (b6d4a7b)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Feb 27 15:51:52 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/594fd9f61d1280152c758559b4fc60bf0c827cca...45f682337b1daa4c82797f950e16a605fe4971bd
>---------------------------------------------------------------
commit b6d4a7bfaa1f15a10959fad18c99149ec2ce98f6
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Fri Feb 27 21:10:39 2015 +0100
Use the correct number of epochs.
>---------------------------------------------------------------
b6d4a7bfaa1f15a10959fad18c99149ec2ce98f6
src/mlpack/methods/ann/trainer/trainer.hpp | 10 ++++------
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/src/mlpack/methods/ann/trainer/trainer.hpp b/src/mlpack/methods/ann/trainer/trainer.hpp
index f5b42a3..01c1301 100644
--- a/src/mlpack/methods/ann/trainer/trainer.hpp
+++ b/src/mlpack/methods/ann/trainer/trainer.hpp
@@ -87,8 +87,6 @@ class Trainer
while(true)
{
-
- // Randomly shuffle the index sequence if not in batch mode.
if (shuffle)
index = arma::shuffle(index);
@@ -98,7 +96,7 @@ class Trainer
if (validationError <= tolerance)
break;
- if (maxEpochs > 0 && ++epoch > maxEpochs)
+ if (maxEpochs > 0 && ++epoch >= maxEpochs)
break;
}
}
@@ -145,8 +143,8 @@ class Trainer
{
net.FeedForward(data.unsafe_col(index(i)),
target.unsafe_col(index(i)), error);
- trainingError += net.Error();
+ trainingError += net.Error();
net.FeedBackward(error);
if (((i + 1) % batchSize) == 0)
@@ -172,8 +170,8 @@ class Trainer
for (size_t i = 0; i < data.n_cols; i++)
{
- net.FeedForward(data.unsafe_col(i), target.unsafe_col(i), error);
- validationError += net.Error();
+ validationError += net.Evaluate(data.unsafe_col(i),
+ target.unsafe_col(i), error);
}
validationError /= data.n_cols;
More information about the mlpack-git mailing list