[mlpack-git] master: Refactor to support convolutional neural networks. (2f479f3)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Jun 2 04:49:46 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/61d7876048f2208cf45d41d71f9d4baa825e2a51...2f479f388ee3d34e4a20535c3662b1921a4c6c06

>---------------------------------------------------------------

commit 2f479f388ee3d34e4a20535c3662b1921a4c6c06
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Tue Jun 2 00:46:06 2015 +0200

    Refactor to support convolutional neural networks.


>---------------------------------------------------------------

2f479f388ee3d34e4a20535c3662b1921a4c6c06
 src/mlpack/methods/ann/trainer/trainer.hpp | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/src/mlpack/methods/ann/trainer/trainer.hpp b/src/mlpack/methods/ann/trainer/trainer.hpp
index cdb948f..1f1aa4d 100644
--- a/src/mlpack/methods/ann/trainer/trainer.hpp
+++ b/src/mlpack/methods/ann/trainer/trainer.hpp
@@ -188,21 +188,38 @@ class Trainer
      * @param sliceNum Provide a Col object of the specified index.
      */
     template<typename eT>
-    arma::Col<eT> Element(arma::Mat<eT>& input, const size_t colNum)
+    typename std::enable_if<!NetworkTraits<NetworkType>::IsCNN,
+        arma::Col<eT> >::type
+    Element(arma::Mat<eT>& input, const size_t colNum)
     {
       return arma::Col<eT>(input.colptr(colNum), input.n_rows, false, true);
     }
 
     /*
+     * Create a Mat object which uses memory from an existing matrix object.
+     * (This approach is currently not alias safe)
+     *
+     * @param input The reference data.
+     * @param colNum Index of the column to wrap as a single-column Mat object.
+     */
+    template<typename eT>
+    typename std::enable_if<NetworkTraits<NetworkType>::IsCNN,
+        arma::Mat<eT> >::type
+    Element(arma::Mat<eT>& input, const size_t colNum)
+    {
+      return arma::Mat<eT>(input.colptr(colNum), input.n_rows, 1, false, true);
+    }
+
+    /*
      * Provide the reference to the matrix representing a single slice.
      *
      * @param data The reference data.
      * @param sliceNum Provide a single slice of the specified index.
      */
     template<typename eT>
-    const arma::Mat<eT>& Element(arma::Cube<eT>& input, const size_t sliceNum)
+    const arma::Cube<eT> Element(arma::Cube<eT>& input, const size_t sliceNum)
     {
-      return *(input.mat_ptrs[sliceNum]);
+      return input.slices(sliceNum, sliceNum);
     }
 
     /*
@@ -231,8 +248,7 @@ class Trainer
     NetworkType& net;
 
     //! The current network error of a single input.
-    typename std::conditional<NetworkTraits<NetworkType>::IsFNN ||
-                              NetworkTraits<NetworkType>::IsCNN,
+    typename std::conditional<NetworkTraits<NetworkType>::IsFNN,
         VecType, MatType>::type error;
 
     //! The current epoch if maxEpochs is set.



More information about the mlpack-git mailing list