[mlpack-git] master: Basically, you don't need to shuffle the input in batch mode; however, at least there is the option now. (e2a8093)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:13:41 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit e2a80939af7964219d52b64a5f630ab280148837
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Tue Jan 20 10:38:39 2015 +0100

    Basically, you don't need to shuffle the input in batch mode; however, at least there is the option now.


>---------------------------------------------------------------

e2a80939af7964219d52b64a5f630ab280148837
 src/mlpack/methods/ann/trainer/trainer.hpp | 25 ++++++++++++++++++++-----
 1 file changed, 20 insertions(+), 5 deletions(-)

diff --git a/src/mlpack/methods/ann/trainer/trainer.hpp b/src/mlpack/methods/ann/trainer/trainer.hpp
index 6cee560..6579efb 100644
--- a/src/mlpack/methods/ann/trainer/trainer.hpp
+++ b/src/mlpack/methods/ann/trainer/trainer.hpp
@@ -47,15 +47,19 @@ class Trainer
      * @param batchSize The batch size used to train the network.
      * @param convergenceThreshold Train the network until it converges against
      * the specified threshold.
+     * @param shuffle If true, the order of the training set is shuffled;
+     * otherwise, each data point is visited in linear order.
      */
     Trainer(NetworkType& net,
             const size_t maxEpochs = 0,
             const size_t batchSize = 1,
-            const double convergenceThreshold  = 0.0001) :
+            const double convergenceThreshold  = 0.0001,
+            const bool shuffle = true) :
         net(net),
         maxEpochs(maxEpochs),
         batchSize(batchSize),
-        convergenceThreshold(convergenceThreshold)
+        convergenceThreshold(convergenceThreshold),
+        shuffle(shuffle)
     {
       // Nothing to do here.
     }
@@ -82,8 +86,10 @@ class Trainer
 
       while(true)
       {
-        // Randomly shuffle the index sequence.
-        index = arma::shuffle(index);
+
+        // Randomly shuffle the index sequence if shuffling is enabled.
+        if (shuffle)
+          index = arma::shuffle(index);
 
         Train(trainingData, trainingLabels);
         Evaluate(validationData, validationLabels);
@@ -102,6 +108,11 @@ class Trainer
     //! Get the validation error.
     double ValidationError() const { return validationError; }
 
+    //! Get whether or not the individual inputs are shuffled.
+    bool Shuffle() const { return shuffle; }
+    //! Modify whether or not the individual inputs are shuffled.
+    bool& Shuffle() { return shuffle; }
+
   private:
     /**
      * Train the network on the given dataset.
@@ -168,7 +179,7 @@ class Trainer
     //! The size until a update is performed.
     size_t batchSize;
 
-    //! Index sequence used to train the network.
+    //! The (optionally shuffled) index sequence used to train the network.
     arma::Col<size_t> index;
 
     //! The overall traing error.
@@ -179,6 +190,10 @@ class Trainer
 
     //! The threshold used as convergence.
     double convergenceThreshold;
+
+    //! Controls whether or not the individual inputs are shuffled when
+    //! iterating.
+    bool shuffle;
 }; // class Trainer
 
 }; // namespace ann



More information about the mlpack-git mailing list