[mlpack-git] master: Make the layer independent regarding the datatype. (90163a1)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jun 24 13:50:23 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6e98f6d5e61ac0ca861f0a7c3ec966076eccc50e...7de290f191972dd41856b647249e2d24d2bf029d

>---------------------------------------------------------------

commit 90163a19f36e548d9098b67f053982a11b7e2496
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Tue Jun 23 15:05:15 2015 +0200

    Make the layer independent regarding the datatype.


>---------------------------------------------------------------

90163a19f36e548d9098b67f053982a11b7e2496
 src/mlpack/methods/ann/layer/one_hot_layer.hpp | 28 ++++++++++----------------
 1 file changed, 11 insertions(+), 17 deletions(-)

diff --git a/src/mlpack/methods/ann/layer/one_hot_layer.hpp b/src/mlpack/methods/ann/layer/one_hot_layer.hpp
index f3dd76f..93181f8 100644
--- a/src/mlpack/methods/ann/layer/one_hot_layer.hpp
+++ b/src/mlpack/methods/ann/layer/one_hot_layer.hpp
@@ -17,14 +17,7 @@ namespace ann /** Artificial Neural Network. */ {
 /**
  * An implementation of a one hot classification layer that can be used as
  * output layer.
- *
- * @tparam MatType Type of data (arma::mat or arma::sp_mat).
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
  */
-template <
-    typename MatType = arma::mat,
-    typename VecType = arma::colvec
->
 class OneHotLayer
 {
  public:
@@ -45,9 +38,10 @@ class OneHotLayer
    * @param error The calculated error with respect to the input activation and
    * the given target.
    */
-  void calculateError(const VecType& inputActivations,
-                      const VecType& target,
-                      VecType& error)
+  template<typename DataType>
+  void CalculateError(const DataType& inputActivations,
+                      const DataType& target,
+                      DataType& error)
   {
     error = inputActivations - target;
   }
@@ -58,9 +52,12 @@ class OneHotLayer
    * @param inputActivations Input data used to calculate the output class.
    * @param output Output class of the input activation.
    */
-  void outputClass(const VecType& inputActivations, VecType& output)
+  template<typename DataType>
+  void OutputClass(const DataType& inputActivations, DataType& output)
   {
-    output = arma::zeros<VecType>(inputActivations.n_elem);
+    output = inputActivations;
+    output.zeros();
+
     arma::uword maxIndex;
     inputActivations.max(maxIndex);
     output(maxIndex) = 1;
@@ -68,11 +65,8 @@ class OneHotLayer
 }; // class OneHotLayer
 
 //! Layer traits for the one-hot class classification layer.
-template <
-    typename MatType,
-    typename VecType
->
-class LayerTraits<OneHotLayer<MatType, VecType> >
+template <>
+class LayerTraits<OneHotLayer>
 {
  public:
   static const bool IsBinary = true;



More information about the mlpack-git mailing list