[mlpack-git] master: Slight refactoring for simplicity. No need to specify additional template parameters. (ad5c77b)
gitdub at big.cc.gt.atl.ga.us
Thu Jun 4 04:47:08 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/2f479f388ee3d34e4a20535c3662b1921a4c6c06...7fb32130bd683cf03a853ea2bc6960e80d625955
>---------------------------------------------------------------
commit ad5c77b97aead24ee389ca2cb2a8f1a29a356ebc
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Tue Jun 2 22:14:33 2015 +0200
Slight refactoring for simplicity. No need to specify additional template parameters.
>---------------------------------------------------------------
ad5c77b97aead24ee389ca2cb2a8f1a29a356ebc
.../ann/layer/binary_classification_layer.hpp | 25 +++++++++-------------
src/mlpack/tests/feedforward_network_test.cpp | 16 +++++++-------
src/mlpack/tests/recurrent_network_test.cpp | 16 +++++++-------
3 files changed, 26 insertions(+), 31 deletions(-)
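For readers skimming the diff below: the class previously took MatType and VecType template parameters; after this commit the data type is deduced per call via member function templates. Here is a minimal self-contained sketch of the refactored class, reconstructed from the hunks below, with an illustrative main() that is not part of the commit:

#include <armadillo>

// Reconstruction of the refactored layer, trimmed to the two member
// functions shown in the diff. DataType is deduced at each call site,
// so the class itself carries no template parameters anymore.
class BinaryClassificationLayer
{
 public:
  // error = activation - target, elementwise.
  template<typename DataType>
  void CalculateError(const DataType& inputActivations,
                      const DataType& target,
                      DataType& error)
  {
    error = inputActivations - target;
  }

  // Threshold each activation at 0.5 to get a 0/1 class label.
  template<typename DataType>
  void OutputClass(const DataType& inputActivations, DataType& output)
  {
    output = inputActivations;
    output.transform( [](double value) { return (value > 0.5 ? 1 : 0); } );
  }
};

int main()
{
  BinaryClassificationLayer outputLayer;  // note: no trailing <> needed

  arma::colvec activations("0.9 0.2 0.7");
  arma::colvec target("1 0 1"), error, classes;

  outputLayer.CalculateError(activations, target, error);  // DataType deduced as arma::colvec
  outputLayer.OutputClass(activations, classes);

  error.print("error");
  classes.print("classes");
}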
diff --git a/src/mlpack/methods/ann/layer/binary_classification_layer.hpp b/src/mlpack/methods/ann/layer/binary_classification_layer.hpp
index ecd6064..04537f4 100644
--- a/src/mlpack/methods/ann/layer/binary_classification_layer.hpp
+++ b/src/mlpack/methods/ann/layer/binary_classification_layer.hpp
@@ -5,8 +5,8 @@
* Definition of the BinaryClassificationLayer class, which implements a
* binary class classification layer that can be used as output layer.
*/
-#ifndef __MLPACK_METHOS_ANN_LAYER_BINARY_CLASSIFICATION_LAYER_HPP
-#define __MLPACK_METHOS_ANN_LAYER_BINARY_CLASSIFICATION_LAYER_HPP
+#ifndef __MLPACK_METHODS_ANN_LAYER_BINARY_CLASSIFICATION_LAYER_HPP
+#define __MLPACK_METHODS_ANN_LAYER_BINARY_CLASSIFICATION_LAYER_HPP
#include <mlpack/core.hpp>
#include <mlpack/methods/ann/layer/layer_traits.hpp>
@@ -17,14 +17,7 @@ namespace ann /** Artificial Neural Network. */ {
/**
* An implementation of a binary classification layer that can be used as
* output layer.
- *
- * @tparam MatType Type of data (arma::mat or arma::sp_mat).
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
*/
-template <
- typename MatType = arma::mat,
- typename VecType = arma::colvec
->
class BinaryClassificationLayer
{
public:
@@ -45,9 +38,10 @@ class BinaryClassificationLayer
* @param error The calculated error with respect to the input activation and
* the given target.
*/
- void CalculateError(const VecType& inputActivations,
- const VecType& target,
- VecType& error)
+ template<typename DataType>
+ void CalculateError(const DataType& inputActivations,
+ const DataType& target,
+ DataType& error)
{
error = inputActivations - target;
}
@@ -58,7 +52,8 @@ class BinaryClassificationLayer
* @param inputActivations Input data used to calculate the output class.
* @param output Output class of the input activation.
*/
- void OutputClass(const VecType& inputActivations, VecType& output)
+ template<typename DataType>
+ void OutputClass(const DataType& inputActivations, DataType& output)
{
output = inputActivations;
output.transform( [](double value) { return (value > 0.5 ? 1 : 0); } );
@@ -66,8 +61,8 @@ class BinaryClassificationLayer
}; // class BinaryClassificationLayer
//! Layer traits for the binary class classification layer.
-template <typename MatType, typename VecType>
-class LayerTraits<BinaryClassificationLayer<MatType, VecType> >
+template <>
+class LayerTraits<BinaryClassificationLayer>
{
public:
static const bool IsBinary = true;
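(Aside, not part of the commit: since the partial specialization over <MatType, VecType> becomes a full one, trait queries drop the template arguments as well. A compile-time check would now read like the following sketch, assuming the LayerTraits primary template from layer_traits.hpp is in scope:)

static_assert(LayerTraits<BinaryClassificationLayer>::IsBinary,
              "the binary classification layer should report IsBinary");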
diff --git a/src/mlpack/tests/feedforward_network_test.cpp b/src/mlpack/tests/feedforward_network_test.cpp
index 1fd8265..be58169 100644
--- a/src/mlpack/tests/feedforward_network_test.cpp
+++ b/src/mlpack/tests/feedforward_network_test.cpp
@@ -167,7 +167,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
BuildVanillaNetwork<RandomInitialization,
LogisticFunction,
SteepestDescent<>,
- BinaryClassificationLayer<>,
+ BinaryClassificationLayer,
MeanSquaredErrorFunction>
(trainData, trainLabels, testData, testLabels, 4, 500,
0.1, 60, randInitA);
@@ -187,7 +187,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
BuildVanillaNetwork<RandomInitialization,
LogisticFunction,
SteepestDescent<>,
- BinaryClassificationLayer<>,
+ BinaryClassificationLayer,
MeanSquaredErrorFunction>
(dataset, labels, dataset, labels, 100, 100, 0.6, 10, randInitB);
@@ -195,7 +195,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
BuildVanillaNetwork<RandomInitialization,
TanhFunction,
SteepestDescent<>,
- BinaryClassificationLayer<>,
+ BinaryClassificationLayer,
MeanSquaredErrorFunction>
(dataset, labels, dataset, labels, 10, 200, 0.6, 20, randInitB);
}
@@ -219,7 +219,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
BuildVanillaNetwork<RandomInitialization,
LogisticFunction,
SteepestDescent<>,
- BinaryClassificationLayer<>,
+ BinaryClassificationLayer,
MeanSquaredErrorFunction>
(input, labels, input, labels, 4, 0, 0, 0.01, randInit);
@@ -227,7 +227,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
BuildVanillaNetwork<RandomInitialization,
TanhFunction,
SteepestDescent<>,
- BinaryClassificationLayer<>,
+ BinaryClassificationLayer,
MeanSquaredErrorFunction>
(input, labels, input, labels, 4, 0, 0, 0.01, randInit);
@@ -240,7 +240,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
BuildVanillaNetwork<RandomInitialization,
LogisticFunction,
SteepestDescent<>,
- BinaryClassificationLayer<>,
+ BinaryClassificationLayer,
MeanSquaredErrorFunction>
(input, labels, input, labels, 4, 0, 0, 0.01, randInit);
@@ -248,7 +248,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
BuildVanillaNetwork<RandomInitialization,
TanhFunction,
SteepestDescent<>,
- BinaryClassificationLayer<>,
+ BinaryClassificationLayer,
MeanSquaredErrorFunction>
(input, labels, input, labels, 4, 0, 0, 0.01, randInit);
}
@@ -371,7 +371,7 @@ BOOST_AUTO_TEST_CASE(NetworkDecreasingErrorTest)
BuildNetworkOptimzer<RandomInitialization,
LogisticFunction,
SteepestDescent<>,
- BinaryClassificationLayer<>,
+ BinaryClassificationLayer,
MeanSquaredErrorFunction>
(dataset, labels, dataset, labels, 100, 50, randInitB);
}
diff --git a/src/mlpack/tests/recurrent_network_test.cpp b/src/mlpack/tests/recurrent_network_test.cpp
index 566a821..cebc5aa 100644
--- a/src/mlpack/tests/recurrent_network_test.cpp
+++ b/src/mlpack/tests/recurrent_network_test.cpp
@@ -110,7 +110,7 @@ BOOST_AUTO_TEST_CASE(SequenceClassificationTest)
NeuronLayer<LogisticFunction> hiddenLayer0(4);
NeuronLayer<LogisticFunction> recurrentLayer0(hiddenLayer0.InputSize());
NeuronLayer<LogisticFunction> hiddenLayer1(2);
- BinaryClassificationLayer<> outputLayer;
+ BinaryClassificationLayer outputLayer;
SteepestDescent< > conOptimizer0(inputLayer.InputSize(),
hiddenLayer0.InputSize(), 1, 0);
@@ -327,7 +327,7 @@ BOOST_AUTO_TEST_CASE(FeedForwardRecurrentNetworkTest)
CompareVanillaNetworks<RandomInitialization,
LogisticFunction,
SteepestDescent<>,
- BinaryClassificationLayer<>,
+ BinaryClassificationLayer,
MeanSquaredErrorFunction>
(input, labels, input, labels, 10, 10, randInit);
@@ -335,7 +335,7 @@ BOOST_AUTO_TEST_CASE(FeedForwardRecurrentNetworkTest)
CompareVanillaNetworks<RandomInitialization,
IdentityFunction,
SteepestDescent<>,
- BinaryClassificationLayer<>,
+ BinaryClassificationLayer,
MeanSquaredErrorFunction>
(input, labels, input, labels, 1, 1, randInit);
@@ -343,7 +343,7 @@ BOOST_AUTO_TEST_CASE(FeedForwardRecurrentNetworkTest)
CompareVanillaNetworks<RandomInitialization,
RectifierFunction,
SteepestDescent<>,
- BinaryClassificationLayer<>,
+ BinaryClassificationLayer,
MeanSquaredErrorFunction>
(input, labels, input, labels, 10, 10, randInit);
@@ -351,7 +351,7 @@ BOOST_AUTO_TEST_CASE(FeedForwardRecurrentNetworkTest)
CompareVanillaNetworks<RandomInitialization,
SoftsignFunction,
SteepestDescent<>,
- BinaryClassificationLayer<>,
+ BinaryClassificationLayer,
MeanSquaredErrorFunction>
(input, labels, input, labels, 10, 10, randInit);
@@ -359,7 +359,7 @@ BOOST_AUTO_TEST_CASE(FeedForwardRecurrentNetworkTest)
CompareVanillaNetworks<RandomInitialization,
TanhFunction,
SteepestDescent<>,
- BinaryClassificationLayer<>,
+ BinaryClassificationLayer,
MeanSquaredErrorFunction>
(input, labels, input, labels, 10, 10, randInit);
}
@@ -576,7 +576,7 @@ void ReberGrammarTestNetwork(HiddenLayerType& hiddenLayer0,
NeuronLayer<LogisticFunction> inputLayer(7);
NeuronLayer<IdentityFunction> recurrentLayer0(hiddenLayer0.OutputSize());
NeuronLayer<LogisticFunction> hiddenLayer1(7);
- BinaryClassificationLayer<> outputLayer;
+ BinaryClassificationLayer outputLayer;
SteepestDescent< > conOptimizer0(inputLayer.OutputSize(),
hiddenLayer0.InputSize(), 0.1);
@@ -811,7 +811,7 @@ void DistractedSequenceRecallTestNetwork(HiddenLayerType& hiddenLayer0)
NeuronLayer<LogisticFunction> inputLayer(10);
NeuronLayer<IdentityFunction> recurrentLayer0(hiddenLayer0.OutputSize());
NeuronLayer<LogisticFunction> hiddenLayer1(3);
- BinaryClassificationLayer<> outputLayer;
+ BinaryClassificationLayer outputLayer;
SteepestDescent< > conOptimizer0(inputLayer.OutputSize(),
hiddenLayer0.InputSize(), 0.1);
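The test changes above are mechanical for the same reason: only the declaration sites change, because DataType is now deduced per member-function call. A hedged usage sketch against the real header, showing one instance serving both vector and matrix calls (the variable names are made up):

#include <mlpack/methods/ann/layer/binary_classification_layer.hpp>

using namespace mlpack::ann;

int main()
{
  BinaryClassificationLayer outputLayer;  // previously: BinaryClassificationLayer<>

  // One object now serves calls with different Armadillo types.
  arma::colvec vecOut;
  outputLayer.OutputClass(arma::colvec("0.1 0.9"), vecOut);             // DataType = arma::colvec

  arma::mat matOut;
  outputLayer.OutputClass(arma::randu<arma::mat>(2, 5), matOut);        // DataType = arma::mat

  vecOut.print("vector classes");
  matOut.print("matrix classes");
}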
More information about the mlpack-git mailing list