[mlpack-git] master: Basically, you don't need to shuffle the input in batch mode; however, at least there is the option now. (e2a8093)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:13:41 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit e2a80939af7964219d52b64a5f630ab280148837
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Tue Jan 20 10:38:39 2015 +0100
Basically, you don't need to shuffle the input in batch mode; however, at least there is the option now.
>---------------------------------------------------------------
e2a80939af7964219d52b64a5f630ab280148837
src/mlpack/methods/ann/trainer/trainer.hpp | 25 ++++++++++++++++++++-----
1 file changed, 20 insertions(+), 5 deletions(-)
diff --git a/src/mlpack/methods/ann/trainer/trainer.hpp b/src/mlpack/methods/ann/trainer/trainer.hpp
index 6cee560..6579efb 100644
--- a/src/mlpack/methods/ann/trainer/trainer.hpp
+++ b/src/mlpack/methods/ann/trainer/trainer.hpp
@@ -47,15 +47,19 @@ class Trainer
* @param batchSize The batch size used to train the network.
* @param convergenceThreshold Train the network until it converges against
* the specified threshold.
+ * @param shuffle If true, the order of the training set is shuffled;
+ * otherwise, each data point is visited in linear order.
*/
Trainer(NetworkType& net,
const size_t maxEpochs = 0,
const size_t batchSize = 1,
- const double convergenceThreshold = 0.0001) :
+ const double convergenceThreshold = 0.0001,
+ const bool shuffle = true) :
net(net),
maxEpochs(maxEpochs),
batchSize(batchSize),
- convergenceThreshold(convergenceThreshold)
+ convergenceThreshold(convergenceThreshold),
+ shuffle(shuffle)
{
// Nothing to do here.
}
@@ -82,8 +86,10 @@ class Trainer
while(true)
{
- // Randomly shuffle the index sequence.
- index = arma::shuffle(index);
+
+ // Randomly shuffle the index sequence if shuffling is enabled.
+ if (shuffle)
+ index = arma::shuffle(index);
Train(trainingData, trainingLabels);
Evaluate(validationData, validationLabels);
@@ -102,6 +108,11 @@ class Trainer
//! Get the validation error.
double ValidationError() const { return validationError; }
+ //! Get whether or not the individual inputs are shuffled.
+ bool Shuffle() const { return shuffle; }
+ //! Modify whether or not the individual inputs are shuffled.
+ bool& Shuffle() { return shuffle; }
+
private:
/**
* Train the network on the given dataset.
@@ -168,7 +179,7 @@ class Trainer
//! The size until a update is performed.
size_t batchSize;
- //! Index sequence used to train the network.
+ //! The (optionally shuffled) index sequence used to train the network.
arma::Col<size_t> index;
//! The overall traing error.
@@ -179,6 +190,10 @@ class Trainer
//! The threshold used as convergence.
double convergenceThreshold;
+
+ //! Controls whether or not the individual inputs are shuffled when
+ //! iterating.
+ bool shuffle;
}; // class Trainer
}; // namespace ann
More information about the mlpack-git
mailing list