[mlpack-git] master: Refactor to support 3rd-order tensors. (60a4d92)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jun 24 13:50:21 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/6e98f6d5e61ac0ca861f0a7c3ec966076eccc50e...7de290f191972dd41856b647249e2d24d2bf029d
>---------------------------------------------------------------
commit 60a4d92e014789ebd760d79047e9bdf1e2d854d7
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Mon Jun 22 17:09:11 2015 +0200
Refactor to support 3rd-order tensors.
>---------------------------------------------------------------
60a4d92e014789ebd760d79047e9bdf1e2d854d7
.../activation_functions/rectifier_function.hpp | 27 ++++++++++++++++------
1 file changed, 20 insertions(+), 7 deletions(-)
diff --git a/src/mlpack/methods/ann/activation_functions/rectifier_function.hpp b/src/mlpack/methods/ann/activation_functions/rectifier_function.hpp
index 3eaa2de..ba14a14 100644
--- a/src/mlpack/methods/ann/activation_functions/rectifier_function.hpp
+++ b/src/mlpack/methods/ann/activation_functions/rectifier_function.hpp
@@ -52,16 +52,29 @@ class RectifierFunction
}
/**
- * Computes the rectifier function.
+ * Computes the rectifier function using a dense matrix as input.
+ *
+ * @param x Input data.
+ * @param y The resulting output activation.
+ */
+ template<typename eT>
+ static void fn(const arma::Mat<eT>& x, arma::Mat<eT>& y)
+ {
+ y = arma::max(arma::zeros<arma::Mat<eT> >(x.n_rows, x.n_cols), x);
+ }
+
+ /**
+ * Computes the rectifier function using a 3rd-order tensor as input.
*
* @param x Input data.
* @param y The resulting output activation.
*/
- template<typename InputVecType, typename OutputVecType>
- static void fn(const InputVecType& x, OutputVecType& y)
+ template<typename eT>
+ static void fn(const arma::Cube<eT>& x, arma::Cube<eT>& y)
{
y = x;
- y = arma::max(arma::zeros<OutputVecType>(x.n_elem), x);
+ for (size_t s = 0; s < x.n_slices; s++)
+ fn(x.slice(s), y.slice(s));
}
/**
@@ -72,7 +85,7 @@ class RectifierFunction
*/
static double deriv(const double y)
{
- return y > 0 ? 1 : 0;
+ return y > 0;
}
/**
@@ -81,8 +94,8 @@ class RectifierFunction
* @param y Input activations.
* @param x The resulting derivatives.
*/
- template<typename InputVecType, typename OutputVecType>
- static void deriv(const InputVecType& y, OutputVecType& x)
+ template<typename InputType, typename OutputType>
+ static void deriv(const InputType& y, OutputType& x)
{
x = y;
x.transform( [](double y) { return deriv(y); } );
More information about the mlpack-git mailing list