[mlpack-git] master: Refactor to support convolutional neural networks. (2f479f3)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Jun 2 04:49:46 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/61d7876048f2208cf45d41d71f9d4baa825e2a51...2f479f388ee3d34e4a20535c3662b1921a4c6c06
>---------------------------------------------------------------
commit 2f479f388ee3d34e4a20535c3662b1921a4c6c06
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Tue Jun 2 00:46:06 2015 +0200
Refactor to support convolutional neural networks.
>---------------------------------------------------------------
2f479f388ee3d34e4a20535c3662b1921a4c6c06
src/mlpack/methods/ann/trainer/trainer.hpp | 26 +++++++++++++++++++++-----
1 file changed, 21 insertions(+), 5 deletions(-)
diff --git a/src/mlpack/methods/ann/trainer/trainer.hpp b/src/mlpack/methods/ann/trainer/trainer.hpp
index cdb948f..1f1aa4d 100644
--- a/src/mlpack/methods/ann/trainer/trainer.hpp
+++ b/src/mlpack/methods/ann/trainer/trainer.hpp
@@ -188,21 +188,38 @@ class Trainer
* @param sliceNum Provide a Col object of the specified index.
*/
template<typename eT>
- arma::Col<eT> Element(arma::Mat<eT>& input, const size_t colNum)
+ typename std::enable_if<!NetworkTraits<NetworkType>::IsCNN,
+ arma::Col<eT> >::type
+ Element(arma::Mat<eT>& input, const size_t colNum)
{
return arma::Col<eT>(input.colptr(colNum), input.n_rows, false, true);
}
/*
+ * Create a Mat object which uses memory from an existing matrix object.
+ * (This approach is currently not alias safe)
+ *
+ * @param input The reference data.
+ * @param colNum Index of the column to expose as a single-column Mat object.
+ */
+ template<typename eT>
+ typename std::enable_if<NetworkTraits<NetworkType>::IsCNN,
+ arma::Mat<eT> >::type
+ Element(arma::Mat<eT>& input, const size_t colNum)
+ {
+ return arma::Mat<eT>(input.colptr(colNum), input.n_rows, 1, false, true);
+ }
+
+ /*
* Provide the reference to the matrix representing a single slice.
*
* @param data The reference data.
* @param sliceNum Provide a single slice of the specified index.
*/
template<typename eT>
- const arma::Mat<eT>& Element(arma::Cube<eT>& input, const size_t sliceNum)
+ const arma::Cube<eT> Element(arma::Cube<eT>& input, const size_t sliceNum)
{
- return *(input.mat_ptrs[sliceNum]);
+ return input.slices(sliceNum, sliceNum);
}
/*
@@ -231,8 +248,7 @@ class Trainer
NetworkType& net;
//! The current network error of a single input.
- typename std::conditional<NetworkTraits<NetworkType>::IsFNN ||
- NetworkTraits<NetworkType>::IsCNN,
+ typename std::conditional<NetworkTraits<NetworkType>::IsFNN,
VecType, MatType>::type error;
//! The current epoch if maxEpochs is set.
More information about the mlpack-git
mailing list