[mlpack-git] master: Add the ability to define the training parameters. (62d8f6d)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:13:43 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit 62d8f6db76d197e1dc60d64181151265d987520c
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Tue Jan 20 11:30:00 2015 +0100
Add the ability to define the training parameters.
>---------------------------------------------------------------
62d8f6db76d197e1dc60d64181151265d987520c
src/mlpack/methods/ann/trainer/trainer.hpp | 30 +++++++++++++++++++++++-------
1 file changed, 23 insertions(+), 7 deletions(-)
diff --git a/src/mlpack/methods/ann/trainer/trainer.hpp b/src/mlpack/methods/ann/trainer/trainer.hpp
index 6579efb..f5b42a3 100644
--- a/src/mlpack/methods/ann/trainer/trainer.hpp
+++ b/src/mlpack/methods/ann/trainer/trainer.hpp
@@ -43,9 +43,10 @@ class Trainer
* each epoch (Default 1).
*
* @param net The network that should be trained.
- * @param maxEpochs The number of maximal trained iterations.
+ * @param maxEpochs The number of maximal trained iterations (0 means no
+ * limit).
* @param batchSize The batch size used to train the network.
- * @param convergenceThreshold Train the network until it converges against
+ * @param tolerance Train the network until it converges against
* the specified threshold.
* @param shuffle If true, the order of the training set is shuffled;
* otherwise, each data is visited in linear order.
@@ -53,12 +54,12 @@ class Trainer
Trainer(NetworkType& net,
const size_t maxEpochs = 0,
const size_t batchSize = 1,
- const double convergenceThreshold = 0.0001,
+ const double tolerance = 0.0001,
const bool shuffle = true) :
net(net),
maxEpochs(maxEpochs),
batchSize(batchSize),
- convergenceThreshold(convergenceThreshold),
+ tolerance(tolerance),
shuffle(shuffle)
{
// Nothing to do here.
@@ -94,7 +95,7 @@ class Trainer
Train(trainingData, trainingLabels);
Evaluate(validationData, validationLabels);
- if (validationError <= convergenceThreshold)
+ if (validationError <= tolerance)
break;
if (maxEpochs > 0 && ++epoch > maxEpochs)
@@ -113,6 +114,21 @@ class Trainer
//! Modify whether or not the individual inputs are shuffled.
bool& Shuffle() { return shuffle; }
+ //! Get the batch size.
+ size_t StepSize() const { return batchSize; }
+ //! Modify the batch size.
+ size_t& StepSize() { return batchSize; }
+
+ //! Get the maximum number of iterations (0 indicates no limit).
+ size_t MaxEpochs() const { return maxEpochs; }
+ //! Modify the maximum number of iterations (0 indicates no limit).
+ size_t& MaxEpochs() { return maxEpochs; }
+
+ //! Get the tolerance for termination.
+ double Tolerance() const { return tolerance; }
+ //! Modify the tolerance for termination.
+ double& Tolerance() { return tolerance; }
+
private:
/**
* Train the network on the given dataset.
@@ -188,8 +204,8 @@ class Trainer
//! The overall validation error.
double validationError;
- //! The threshold used as convergence.
- double convergenceThreshold;
+ //! The tolerance for termination.
+ double tolerance;
//! Controls whether or not the individual inputs are shuffled when
//! iterating.
More information about the mlpack-git mailing list