[mlpack-git] master: Add support for 3rd order tensors as input to train neural networks. (f349aa7)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Mar 10 17:55:41 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/5e2f4f2e6b531af5791e185755d7517bd3e80a62...f349aa79f97f893b2d98f5c795bff3751c9be71a

>---------------------------------------------------------------

commit f349aa79f97f893b2d98f5c795bff3751c9be71a
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Tue Mar 10 22:55:34 2015 +0100

    Add support for 3rd order tensors as input to train neural networks.


>---------------------------------------------------------------

f349aa79f97f893b2d98f5c795bff3751c9be71a
 src/mlpack/methods/ann/trainer/trainer.hpp | 82 +++++++++++++++++++++++++-----
 1 file changed, 68 insertions(+), 14 deletions(-)

diff --git a/src/mlpack/methods/ann/trainer/trainer.hpp b/src/mlpack/methods/ann/trainer/trainer.hpp
index 01c1301..d2724e6 100644
--- a/src/mlpack/methods/ann/trainer/trainer.hpp
+++ b/src/mlpack/methods/ann/trainer/trainer.hpp
@@ -74,15 +74,13 @@ class Trainer
      * @param validationData Data used to evaluate the network.
      * @tparam validationLabels Labels used to evaluate the network.
      */
-    void Train(MatType& trainingData,
-               MatType& trainingLabels,
-               MatType& validationData,
-               MatType& validationLabels)
+    template<typename InputType, typename OutputType>
+    void Train(InputType& trainingData,
+               OutputType& trainingLabels,
+               InputType& validationData,
+               OutputType& validationLabels)
     {
-      // This generates [0 1 2 3 ... (trainingData.n_cols - 1)]. The sequence
-      // will be used to iterate through the training data.
-      index = arma::linspace<arma::Col<size_t> >(0, trainingData.n_cols - 1,
-          trainingData.n_cols);
+      Index(trainingData);
       epoch = 0;
 
       while(true)
@@ -134,15 +132,16 @@ class Trainer
      * @param data Data used to train the network.
      * @param target Labels used to train the network.
      */
-    void Train(MatType& data, MatType& target)
+    template<typename InputType, typename OutputType>
+    void Train(InputType& data, OutputType& target)
     {
       // Reset the training error.
       trainingError = 0;
 
       for (size_t i = 0; i < data.n_cols; i++)
       {
-        net.FeedForward(data.unsafe_col(index(i)),
-            target.unsafe_col(index(i)), error);
+        net.FeedForward(Element(data, index(i)),
+            Element(target, index(i)), error);
 
         trainingError += net.Error();
         net.FeedBackward(error);
@@ -163,20 +162,75 @@ class Trainer
      * @param data Data used to train the network.
      * @param target Labels used to train the network.
      */
-    void Evaluate(MatType& data, MatType& target)
+    template<typename InputType, typename OutputType>
+    void Evaluate(InputType& data, OutputType& target)
     {
       // Reset the validation error.
       validationError = 0;
 
       for (size_t i = 0; i < data.n_cols; i++)
       {
-         validationError += net.Evaluate(data.unsafe_col(i),
-            target.unsafe_col(i), error);
+         validationError += net.Evaluate(Element(data, i),
+            Element(target, i), error);
       }
 
       validationError /= data.n_cols;
     }
 
+    /**
+     * Fill the index member with one entry per column of the given matrix.
+     *
+     * @param input Matrix whose n_cols determines the length of the sequence.
+     */
+    template<typename eT>
+    void Index(const arma::Mat<eT>& input)
+    {
+      // Stores [0 1 2 ... (input.n_cols - 1)] in index; Train() reads samples
+      // through index(i). NOTE(review): n_cols - 1 wraps if n_cols is 0.
+      index = arma::linspace<arma::Col<size_t> >(0, input.n_cols - 1,
+          input.n_cols);
+    }
+
+    /**
+     * Fill the index member with one entry per slice of the given cube.
+     * NOTE(review): the Train()/Evaluate() loops are bounded by data.n_cols,
+     * not n_slices; confirm that is intended for cube input.
+     * @param input Cube whose n_slices determines the length of the sequence.
+     */
+    template<typename eT>
+    void Index(const arma::Cube<eT>& input)
+    {
+      // Stores [0 1 2 ... (input.n_slices - 1)] in index.
+      index = arma::linspace<arma::Col<size_t> >(0, input.n_slices - 1,
+          input.n_slices);
+    }
+
+    /**
+     * Create a Col object which uses memory from an existing matrix object
+     * (copy_aux_mem = false, strict = true); the constructor copies no data.
+     * This approach is currently not alias safe.
+     * @param input The reference data.
+     * @param colNum Index of the column to wrap as a Col object.
+     */
+    template<typename eT>
+    arma::Col<eT> Element(arma::Mat<eT>& input, const size_t colNum)
+    {
+      return arma::Col<eT>(input.colptr(colNum), input.n_rows, false, true);
+    }
+
+    /**
+     * Provide the reference to the matrix representing a single slice. Reads
+     * Cube::mat_ptrs directly (NOTE(review): consider input.slice(sliceNum)).
+     * @param input The reference data.
+     * @param sliceNum Index of the slice to return.
+     */
+    template<typename eT>
+    const arma::Mat<eT>& Element(arma::Cube<eT>& input,
+        const size_t sliceNum)
+    {
+      return *(input.mat_ptrs[sliceNum]);
+    }
+
     //! The network which should be trained and evaluated.
     NetworkType& net;
 



More information about the mlpack-git mailing list