[mlpack-git] master: Add support for 3rd order tensors as input to train neural networks. (f349aa7)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Mar 10 17:55:41 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/5e2f4f2e6b531af5791e185755d7517bd3e80a62...f349aa79f97f893b2d98f5c795bff3751c9be71a
>---------------------------------------------------------------
commit f349aa79f97f893b2d98f5c795bff3751c9be71a
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Tue Mar 10 22:55:34 2015 +0100
Add support for 3rd order tensors as input to train neural networks.
>---------------------------------------------------------------
f349aa79f97f893b2d98f5c795bff3751c9be71a
src/mlpack/methods/ann/trainer/trainer.hpp | 82 +++++++++++++++++++++++++-----
1 file changed, 68 insertions(+), 14 deletions(-)
diff --git a/src/mlpack/methods/ann/trainer/trainer.hpp b/src/mlpack/methods/ann/trainer/trainer.hpp
index 01c1301..d2724e6 100644
--- a/src/mlpack/methods/ann/trainer/trainer.hpp
+++ b/src/mlpack/methods/ann/trainer/trainer.hpp
@@ -74,15 +74,13 @@ class Trainer
* @param validationData Data used to evaluate the network.
* @tparam validationLabels Labels used to evaluate the network.
*/
- void Train(MatType& trainingData,
- MatType& trainingLabels,
- MatType& validationData,
- MatType& validationLabels)
+ template<typename InputType, typename OutputType>
+ void Train(InputType& trainingData,
+ OutputType& trainingLabels,
+ InputType& validationData,
+ OutputType& validationLabels)
{
- // This generates [0 1 2 3 ... (trainingData.n_cols - 1)]. The sequence
- // will be used to iterate through the training data.
- index = arma::linspace<arma::Col<size_t> >(0, trainingData.n_cols - 1,
- trainingData.n_cols);
+ Index(trainingData);
epoch = 0;
while(true)
@@ -134,15 +132,16 @@ class Trainer
* @param data Data used to train the network.
* @param target Labels used to train the network.
*/
- void Train(MatType& data, MatType& target)
+ template<typename InputType, typename OutputType>
+ void Train(InputType& data, OutputType& target)
{
// Reset the training error.
trainingError = 0;
for (size_t i = 0; i < data.n_cols; i++)
{
- net.FeedForward(data.unsafe_col(index(i)),
- target.unsafe_col(index(i)), error);
+ net.FeedForward(Element(data, index(i)),
+ Element(target, index(i)), error);
trainingError += net.Error();
net.FeedBackward(error);
@@ -163,20 +162,75 @@ class Trainer
* @param data Data used to train the network.
* @param target Labels used to train the network.
*/
- void Evaluate(MatType& data, MatType& target)
+ template<typename InputType, typename OutputType>
+ void Evaluate(InputType& data, OutputType& target)
{
// Reset the validation error.
validationError = 0;
+ // NOTE(review): for arma::Cube input the samples are indexed by slice
+ // (the Cube overloads of Index() and Element() use n_slices / slices),
+ // but this loop and the average below count data.n_cols, i.e. the number
+ // of columns in each slice. Confirm whether n_slices was intended here.
for (size_t i = 0; i < data.n_cols; i++)
{
- validationError += net.Evaluate(data.unsafe_col(i),
- target.unsafe_col(i), error);
+ validationError += net.Evaluate(Element(data, i),
+ Element(target, i), error);
}
validationError /= data.n_cols;
}
+ /**
+ * Fill the index member with the sequence [0 1 2 ... (input.n_cols - 1)],
+ * one entry per column of the reference matrix. The sequence is used to
+ * iterate through the training data.
+ *
+ * @param input The reference data.
+ */
+ template<typename eT>
+ void Index(const arma::Mat<eT>& input)
+ {
+ const arma::uword numSamples = input.n_cols;
+ index = arma::linspace<arma::Col<size_t> >(0, numSamples - 1, numSamples);
+ }
+
+ /**
+ * Fill the index member with the sequence [0 1 2 ... (input.n_slices - 1)],
+ * one entry per slice of the reference cube. The sequence is used to
+ * iterate through the training data.
+ *
+ * @param input The reference data.
+ */
+ template<typename eT>
+ void Index(const arma::Cube<eT>& input)
+ {
+ const arma::uword numSamples = input.n_slices;
+ index = arma::linspace<arma::Col<size_t> >(0, numSamples - 1, numSamples);
+ }
+
+ /**
+ * Create a Col object which uses memory from an existing matrix object.
+ * No data is copied; the returned column shares the matrix memory.
+ * (This approach is currently not alias safe)
+ *
+ * @param input The reference data.
+ * @param colNum Index of the column to wrap.
+ * @return A column vector backed by column colNum of input.
+ */
+ template<typename eT>
+ arma::Col<eT> Element(arma::Mat<eT>& input, const size_t colNum)
+ {
+ // unsafe_col() builds the same non-owning, fixed-size Col as the manual
+ // constructor call, and adds Armadillo's debug bounds check on colNum.
+ return input.unsafe_col(colNum);
+ }
+
+ /**
+ * Provide a reference to the matrix representing a single slice.
+ *
+ * @param input The reference data.
+ * @param sliceNum Index of the slice to return.
+ * @return The matrix holding slice sliceNum of input.
+ */
+ template<typename eT>
+ const arma::Mat<eT>& Element(arma::Cube<eT>& input,
+ const size_t sliceNum)
+ {
+ // Use the documented Cube::slice() accessor instead of the internal
+ // mat_ptrs array; slice() is also bounds-checked in debug builds.
+ return input.slice(sliceNum);
+ }
+
//! The network which should be trained and evaluated.
NetworkType& net;
More information about the mlpack-git
mailing list