[mlpack-git] master: Make the layer independent regarding the datatype. (90163a1)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jun 24 13:50:23 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/6e98f6d5e61ac0ca861f0a7c3ec966076eccc50e...7de290f191972dd41856b647249e2d24d2bf029d
>---------------------------------------------------------------
commit 90163a19f36e548d9098b67f053982a11b7e2496
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Tue Jun 23 15:05:15 2015 +0200
Make the layer independent regarding the datatype.
>---------------------------------------------------------------
90163a19f36e548d9098b67f053982a11b7e2496
src/mlpack/methods/ann/layer/one_hot_layer.hpp | 28 ++++++++++----------------
1 file changed, 11 insertions(+), 17 deletions(-)
diff --git a/src/mlpack/methods/ann/layer/one_hot_layer.hpp b/src/mlpack/methods/ann/layer/one_hot_layer.hpp
index f3dd76f..93181f8 100644
--- a/src/mlpack/methods/ann/layer/one_hot_layer.hpp
+++ b/src/mlpack/methods/ann/layer/one_hot_layer.hpp
@@ -17,14 +17,7 @@ namespace ann /** Artificial Neural Network. */ {
/**
* An implementation of a one hot classification layer that can be used as
* output layer.
- *
- * @tparam MatType Type of data (arma::mat or arma::sp_mat).
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
*/
-template <
- typename MatType = arma::mat,
- typename VecType = arma::colvec
->
class OneHotLayer
{
public:
@@ -45,9 +38,10 @@ class OneHotLayer
* @param error The calculated error with respect to the input activation and
* the given target.
*/
- void calculateError(const VecType& inputActivations,
- const VecType& target,
- VecType& error)
+ template<typename DataType>
+ void CalculateError(const DataType& inputActivations,
+ const DataType& target,
+ DataType& error)
{
error = inputActivations - target;
}
@@ -58,9 +52,12 @@ class OneHotLayer
* @param inputActivations Input data used to calculate the output class.
* @param output Output class of the input activation.
*/
- void outputClass(const VecType& inputActivations, VecType& output)
+ template<typename DataType>
+ void OutputClass(const DataType& inputActivations, DataType& output)
{
- output = arma::zeros<VecType>(inputActivations.n_elem);
+ output = inputActivations;
+ output.zeros();
+
arma::uword maxIndex;
inputActivations.max(maxIndex);
output(maxIndex) = 1;
@@ -68,11 +65,8 @@ class OneHotLayer
}; // class OneHotLayer
//! Layer traits for the one-hot class classification layer.
-template <
- typename MatType,
- typename VecType
->
-class LayerTraits<OneHotLayer<MatType, VecType> >
+template <>
+class LayerTraits<OneHotLayer>
{
public:
static const bool IsBinary = true;
More information about the mlpack-git mailing list