[mlpack-git] master: Add Trainer class which serves as container to train the networks. (3739596)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sat Jan 10 07:31:25 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/059a9b6e7da15e25151daae43e4a41d235f8c84c...37395966b16172f2ac2c7dbeba5ec13e2e37659d
>---------------------------------------------------------------
commit 37395966b16172f2ac2c7dbeba5ec13e2e37659d
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Sat Jan 10 13:31:19 2015 +0100
Add Trainer class which serves as container to train the networks.
>---------------------------------------------------------------
37395966b16172f2ac2c7dbeba5ec13e2e37659d
src/mlpack/methods/ann/ffnn.hpp | 4 +-
src/mlpack/methods/ann/trainer/trainer.hpp | 187 +++++++++++++++++++++++++++++
2 files changed, 189 insertions(+), 2 deletions(-)
diff --git a/src/mlpack/methods/ann/ffnn.hpp b/src/mlpack/methods/ann/ffnn.hpp
index 9611c85..56c35df 100644
--- a/src/mlpack/methods/ann/ffnn.hpp
+++ b/src/mlpack/methods/ann/ffnn.hpp
@@ -25,7 +25,7 @@ namespace ann /** Artificial Neural Network. */ {
* be used to construct the network.
* @tparam OutputLayerType The outputlayer type used to evaluate the network.
* @tparam PerformanceFunction Performance strategy used to claculate the error.
- * @tparam MaType of gradients. (arma::mat or arma::sp_mat).
+ * @tparam MatType Type of the gradients (arma::mat or arma::sp_mat).
*/
template <
typename ConnectionTypes,
@@ -54,7 +54,7 @@ class FFNN
* input and target vector, updating the resulting error into the error
* vector.
*
- * @param input Input data used to evaluat the network.
+ * @param input Input data used to evaluate the network.
* @param target Target data used to calculate the network error.
* @param error The calulated error of the output layer.
* @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
diff --git a/src/mlpack/methods/ann/trainer/trainer.hpp b/src/mlpack/methods/ann/trainer/trainer.hpp
new file mode 100644
index 0000000..6cee560
--- /dev/null
+++ b/src/mlpack/methods/ann/trainer/trainer.hpp
@@ -0,0 +1,187 @@
+/**
+ * @file trainer.hpp
+ * @author Marcus Edel
+ *
+ * Definition and implementation of a trainer that trains the parameters of a
+ * neural network according to a supervised dataset.
+ */
+#ifndef __MLPACK_METHODS_ANN_TRAINER_TRAINER_HPP
+#define __MLPACK_METHODS_ANN_TRAINER_TRAINER_HPP
+
+#include <mlpack/core.hpp>
+
+#include <mlpack/methods/ann/network_traits.hpp>
+#include <mlpack/methods/ann/layer/layer_traits.hpp>
+#include <mlpack/methods/ann/layer/neuron_layer.hpp>
+
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+
+/**
+ * Trainer that trains the parameters of a neural network according to a
+ * supervised dataset.
+ *
+ * @tparam NetworkType The type of network which should be trained and
+ *     evaluated.
+ * @tparam MatType Type of the data and label matrices (arma::mat or
+ *     arma::sp_mat).
+ * @tparam VecType Type of the per-sample error when the network is a
+ *     feed-forward network (NetworkTraits<NetworkType>::IsFNN); otherwise
+ *     MatType is used for the error (arma::colvec, arma::mat or arma::sp_mat).
+ */
+template<
+  typename NetworkType,
+  typename MatType = arma::mat,
+  typename VecType = arma::colvec
+>
+class Trainer
+{
+ public:
+  /**
+   * Construct the Trainer object, which will be used to train a neural
+   * network according to a supervised dataset by backpropagating the errors.
+   *
+   * During training the gradients are applied (net.ApplyGradients()) after
+   * every batchSize samples, and once more at the end of an epoch if the
+   * number of training samples is not a multiple of batchSize. A batchSize of
+   * 0 is treated as 1, i.e. the gradients are applied after every sample.
+   *
+   * @param net The network that should be trained.
+   * @param maxEpochs The maximum number of training epochs (0 means no limit:
+   *     train until the convergence threshold is reached).
+   * @param batchSize The number of samples between two gradient updates.
+   * @param convergenceThreshold Stop training once the mean validation error
+   *     is less than or equal to this threshold.
+   */
+  Trainer(NetworkType& net,
+          const size_t maxEpochs = 0,
+          const size_t batchSize = 1,
+          const double convergenceThreshold = 0.0001) :
+      net(net),
+      epoch(0),
+      maxEpochs(maxEpochs),
+      // A batch size of zero would cause a modulo-by-zero during training.
+      batchSize(batchSize == 0 ? 1 : batchSize),
+      trainingError(0),
+      validationError(0),
+      convergenceThreshold(convergenceThreshold)
+  {
+    // Nothing to do here.
+  }
+
+  /**
+   * Train the network on the given datasets until the network converges. If
+   * maxEpochs is greater than zero, at most maxEpochs epochs are trained.
+   * Every epoch visits the training samples in a freshly shuffled order and
+   * then evaluates the network on the whole validation set.
+   *
+   * If either dataset is empty, a warning is printed and nothing is trained:
+   * the mean errors would be undefined (0 / 0) and the convergence check
+   * could never succeed.
+   *
+   * @param trainingData Data used to train the network.
+   * @param trainingLabels Labels used to train the network.
+   * @param validationData Data used to evaluate the network.
+   * @param validationLabels Labels used to evaluate the network.
+   */
+  void Train(MatType& trainingData,
+             MatType& trainingLabels,
+             MatType& validationData,
+             MatType& validationLabels)
+  {
+    epoch = 0;
+
+    if (trainingData.n_cols == 0 || validationData.n_cols == 0)
+    {
+      Log::Warn << "Trainer::Train(): empty training or validation set; "
+          << "nothing to train." << std::endl;
+      return;
+    }
+
+    // This generates [0 1 2 3 ... (trainingData.n_cols - 1)]. The sequence
+    // will be used to iterate through the training data.
+    index = arma::linspace<arma::Col<size_t> >(0, trainingData.n_cols - 1,
+        trainingData.n_cols);
+
+    while (true)
+    {
+      // Randomly shuffle the index sequence.
+      index = arma::shuffle(index);
+
+      Train(trainingData, trainingLabels);
+      Evaluate(validationData, validationLabels);
+
+      if (validationError <= convergenceThreshold)
+        break;
+
+      // Stop as soon as maxEpochs epochs have been completed.
+      if (maxEpochs > 0 && ++epoch >= maxEpochs)
+        break;
+    }
+  }
+
+  //! Get the mean training error of the last trained epoch.
+  double TrainingError() const { return trainingError; }
+
+  //! Get the mean validation error of the last trained epoch.
+  double ValidationError() const { return validationError; }
+
+ private:
+  /**
+   * Run a single training epoch. Every sample is fed forward in the order
+   * given by the (shuffled) index sequence, its error is backpropagated with
+   * net.FeedBackward(), and net.ApplyGradients() is called every batchSize
+   * samples. Afterwards trainingError holds the mean of net.Error() over all
+   * samples. Assumes data.n_cols > 0 and that index has data.n_cols entries.
+   *
+   * @param data Data used to train the network.
+   * @param target Labels used to train the network.
+   */
+  void Train(MatType& data, MatType& target)
+  {
+    // Reset the training error.
+    trainingError = 0;
+
+    for (size_t i = 0; i < data.n_cols; i++)
+    {
+      net.FeedForward(data.unsafe_col(index(i)),
+          target.unsafe_col(index(i)), error);
+      trainingError += net.Error();
+
+      net.FeedBackward(error);
+
+      if (((i + 1) % batchSize) == 0)
+        net.ApplyGradients();
+    }
+
+    // Apply the gradients of a trailing, incomplete batch.
+    if ((data.n_cols % batchSize) != 0)
+      net.ApplyGradients();
+
+    trainingError /= data.n_cols;
+  }
+
+  /**
+   * Evaluate the network on the given dataset. Only net.FeedForward() is
+   * called (no FeedBackward() / ApplyGradients()). Afterwards validationError
+   * holds the mean of net.Error() over all samples. Assumes data.n_cols > 0.
+   *
+   * @param data Data used to evaluate the network.
+   * @param target Labels used to evaluate the network.
+   */
+  void Evaluate(MatType& data, MatType& target)
+  {
+    // Reset the validation error.
+    validationError = 0;
+
+    for (size_t i = 0; i < data.n_cols; i++)
+    {
+      net.FeedForward(data.unsafe_col(i), target.unsafe_col(i), error);
+      validationError += net.Error();
+    }
+
+    validationError /= data.n_cols;
+  }
+
+  //! The network which should be trained and evaluated.
+  NetworkType& net;
+
+  //! The current network error of a single input.
+  typename std::conditional<NetworkTraits<NetworkType>::IsFNN,
+      VecType, MatType>::type error;
+
+  //! The number of completed epochs (only counted if maxEpochs is set).
+  size_t epoch;
+
+  //! The maximum number of epochs (0 means no limit).
+  size_t maxEpochs;
+
+  //! The number of samples between two gradient updates (always >= 1).
+  size_t batchSize;
+
+  //! Index sequence used to train the network.
+  arma::Col<size_t> index;
+
+  //! The mean training error of the last trained epoch.
+  double trainingError;
+
+  //! The mean validation error of the last trained epoch.
+  double validationError;
+
+  //! Training stops once the validation error reaches this threshold.
+  double convergenceThreshold;
+}; // class Trainer
+
+}; // namespace ann
+}; // namespace mlpack
+
+#endif
More information about the mlpack-git
mailing list