[mlpack-git] master: Add the ability to define the training parameters. (62d8f6d)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:13:43 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 62d8f6db76d197e1dc60d64181151265d987520c
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Tue Jan 20 11:30:00 2015 +0100

    Add the ability to define the training parameters.


>---------------------------------------------------------------

62d8f6db76d197e1dc60d64181151265d987520c
 src/mlpack/methods/ann/trainer/trainer.hpp | 30 +++++++++++++++++++++++-------
 1 file changed, 23 insertions(+), 7 deletions(-)

diff --git a/src/mlpack/methods/ann/trainer/trainer.hpp b/src/mlpack/methods/ann/trainer/trainer.hpp
index 6579efb..f5b42a3 100644
--- a/src/mlpack/methods/ann/trainer/trainer.hpp
+++ b/src/mlpack/methods/ann/trainer/trainer.hpp
@@ -43,9 +43,10 @@ class Trainer
      * each epoch (Default 1).
      *
      * @param net The network that should be trained.
-     * @param maxEpochs The number of maximal trained iterations.
+     * @param maxEpochs The maximum number of training epochs (0 means no
+     * limit).
      * @param batchSize The batch size used to train the network.
-     * @param convergenceThreshold Train the network until it converges against
+     * @param tolerance Train the network until the validation error is at or below
      * the specified threshold.
      * @param shuffle If true, the order of the training set is shuffled;
      * otherwise, each data is visited in linear order.
@@ -53,12 +54,12 @@ class Trainer
     Trainer(NetworkType& net,
             const size_t maxEpochs = 0,
             const size_t batchSize = 1,
-            const double convergenceThreshold  = 0.0001,
+            const double tolerance = 0.0001,
             const bool shuffle = true) :
         net(net),
         maxEpochs(maxEpochs),
         batchSize(batchSize),
-        convergenceThreshold(convergenceThreshold),
+        tolerance(tolerance),
         shuffle(shuffle)
     {
       // Nothing to do here.
@@ -94,7 +95,7 @@ class Trainer
         Train(trainingData, trainingLabels);
         Evaluate(validationData, validationLabels);
 
-        if (validationError <= convergenceThreshold)
+        if (validationError <= tolerance)
           break;
 
         if (maxEpochs > 0 && ++epoch > maxEpochs)
@@ -113,6 +114,21 @@ class Trainer
     //! Modify whether or not the individual inputs are shuffled.
     bool& Shuffle() { return shuffle; }
 
+    //! Get the batch size (NOTE(review): named StepSize but returns batchSize).
+    size_t StepSize() const { return batchSize; }
+    //! Modify the batch size (NOTE(review): consider renaming to BatchSize()).
+    size_t& StepSize() { return batchSize; }
+
+    //! Get the maximum number of training epochs (0 indicates no limit).
+    size_t MaxEpochs() const { return maxEpochs; }
+    //! Modify the maximum number of training epochs (0 indicates no limit).
+    size_t& MaxEpochs() { return maxEpochs; }
+
+    //! Get the tolerance; training stops once validation error <= tolerance.
+    double Tolerance() const { return tolerance; }
+    //! Modify the tolerance; training stops once validation error <= tolerance.
+    double& Tolerance() { return tolerance; }
+
   private:
     /**
      * Train the network on the given dataset.
@@ -188,8 +204,8 @@ class Trainer
     //! The overall validation error.
     double validationError;
 
-    //! The threshold used as convergence.
-    double convergenceThreshold;
+    //! Training stops once the validation error is at or below this value.
+    double tolerance;
 
     //! Controls whether or not the individual inputs are shuffled when
     //! iterating.



More information about the mlpack-git mailing list