[mlpack-git] master: Refactor to support 3rd-order tensors. (60a4d92)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jun 24 13:50:21 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6e98f6d5e61ac0ca861f0a7c3ec966076eccc50e...7de290f191972dd41856b647249e2d24d2bf029d

>---------------------------------------------------------------

commit 60a4d92e014789ebd760d79047e9bdf1e2d854d7
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Mon Jun 22 17:09:11 2015 +0200

    Refactor to support 3rd-order tensors.


>---------------------------------------------------------------

60a4d92e014789ebd760d79047e9bdf1e2d854d7
 .../activation_functions/rectifier_function.hpp    | 27 ++++++++++++++++------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/src/mlpack/methods/ann/activation_functions/rectifier_function.hpp b/src/mlpack/methods/ann/activation_functions/rectifier_function.hpp
index 3eaa2de..ba14a14 100644
--- a/src/mlpack/methods/ann/activation_functions/rectifier_function.hpp
+++ b/src/mlpack/methods/ann/activation_functions/rectifier_function.hpp
@@ -52,16 +52,29 @@ class RectifierFunction
   }
 
   /**
-   * Computes the rectifier function.
+   * Computes the rectifier function using a dense matrix as input.
+   *
+   * @param x Input data.
+   * @param y The resulting output activation.
+   */
+  template<typename eT>
+  static void fn(const arma::Mat<eT>& x, arma::Mat<eT>& y)
+  {
+    y = arma::max(arma::zeros<arma::Mat<eT> >(x.n_rows, x.n_cols), x);
+  }
+
+  /**
+   * Computes the rectifier function using a 3rd-order tensor as input.
    *
    * @param x Input data.
    * @param y The resulting output activation.
    */
-  template<typename InputVecType, typename OutputVecType>
-  static void fn(const InputVecType& x, OutputVecType& y)
+  template<typename eT>
+  static void fn(const arma::Cube<eT>& x, arma::Cube<eT>& y)
   {
     y = x;
-    y = arma::max(arma::zeros<OutputVecType>(x.n_elem), x);
+    for (size_t s = 0; s < x.n_slices; s++)
+      fn(x.slice(s), y.slice(s));
   }
 
   /**
@@ -72,7 +85,7 @@ class RectifierFunction
    */
   static double deriv(const double y)
   {
-    return y > 0 ? 1 : 0;
+    return y > 0;
   }
 
   /**
@@ -81,8 +94,8 @@ class RectifierFunction
    * @param y Input activations.
    * @param x The resulting derivatives.
    */
-  template<typename InputVecType, typename OutputVecType>
-  static void deriv(const InputVecType& y, OutputVecType& x)
+  template<typename InputType, typename OutputType>
+  static void deriv(const InputType& y, OutputType& x)
   {
     x = y;
     x.transform( [](double y) { return deriv(y); } );



More information about the mlpack-git mailing list