[mlpack-git] master: Slight refactorization for simplicity. No need to specify additional template parameter. (ad5c77b)

From: gitdub at big.cc.gt.atl.ga.us
Thu Jun 4 04:47:08 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/2f479f388ee3d34e4a20535c3662b1921a4c6c06...7fb32130bd683cf03a853ea2bc6960e80d625955

>---------------------------------------------------------------

commit ad5c77b97aead24ee389ca2cb2a8f1a29a356ebc
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Tue Jun 2 22:14:33 2015 +0200

    Slight refactorization for simplicity. No need to specify additional template parameter.


>---------------------------------------------------------------

ad5c77b97aead24ee389ca2cb2a8f1a29a356ebc
 .../ann/layer/binary_classification_layer.hpp      | 25 +++++++++-------------
 src/mlpack/tests/feedforward_network_test.cpp      | 16 +++++++-------
 src/mlpack/tests/recurrent_network_test.cpp        | 16 +++++++-------
 3 files changed, 26 insertions(+), 31 deletions(-)

diff --git a/src/mlpack/methods/ann/layer/binary_classification_layer.hpp b/src/mlpack/methods/ann/layer/binary_classification_layer.hpp
index ecd6064..04537f4 100644
--- a/src/mlpack/methods/ann/layer/binary_classification_layer.hpp
+++ b/src/mlpack/methods/ann/layer/binary_classification_layer.hpp
@@ -5,8 +5,8 @@
  * Definition of the BinaryClassificationLayer class, which implements a
  * binary class classification layer that can be used as output layer.
  */
-#ifndef __MLPACK_METHOS_ANN_LAYER_BINARY_CLASSIFICATION_LAYER_HPP
-#define __MLPACK_METHOS_ANN_LAYER_BINARY_CLASSIFICATION_LAYER_HPP
+#ifndef __MLPACK_METHODS_ANN_LAYER_BINARY_CLASSIFICATION_LAYER_HPP
+#define __MLPACK_METHODS_ANN_LAYER_BINARY_CLASSIFICATION_LAYER_HPP
 
 #include <mlpack/core.hpp>
 #include <mlpack/methods/ann/layer/layer_traits.hpp>
@@ -17,14 +17,7 @@ namespace ann /** Artificial Neural Network. */ {
 /**
  * An implementation of a binary classification layer that can be used as
  * output layer.
- *
- * @tparam MatType Type of data (arma::mat or arma::sp_mat).
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
  */
-template <
-    typename MatType = arma::mat,
-    typename VecType = arma::colvec
->
 class BinaryClassificationLayer
 {
  public:
@@ -45,9 +38,10 @@ class BinaryClassificationLayer
    * @param error The calculated error with respect to the input activation and
    * the given target.
    */
-  void CalculateError(const VecType& inputActivations,
-                      const VecType& target,
-                      VecType& error)
+  template<typename DataType>
+  void CalculateError(const DataType& inputActivations,
+                      const DataType& target,
+                      DataType& error)
   {
     error = inputActivations - target;
   }
@@ -58,7 +52,8 @@ class BinaryClassificationLayer
    * @param inputActivations Input data used to calculate the output class.
    * @param output Output class of the input activation.
    */
-  void OutputClass(const VecType& inputActivations, VecType& output)
+  template<typename DataType>
+  void OutputClass(const DataType& inputActivations, DataType& output)
   {
     output = inputActivations;
     output.transform( [](double value) { return (value > 0.5 ? 1 : 0); } );
@@ -66,8 +61,8 @@ class BinaryClassificationLayer
 }; // class BinaryClassificationLayer
 
 //! Layer traits for the binary class classification layer.
-template <typename MatType, typename VecType>
-class LayerTraits<BinaryClassificationLayer<MatType, VecType> >
+template <>
+class LayerTraits<BinaryClassificationLayer>
 {
  public:
   static const bool IsBinary = true;
diff --git a/src/mlpack/tests/feedforward_network_test.cpp b/src/mlpack/tests/feedforward_network_test.cpp
index 1fd8265..be58169 100644
--- a/src/mlpack/tests/feedforward_network_test.cpp
+++ b/src/mlpack/tests/feedforward_network_test.cpp
@@ -167,7 +167,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
   BuildVanillaNetwork<RandomInitialization,
                       LogisticFunction,
                       SteepestDescent<>,
-                      BinaryClassificationLayer<>,
+                      BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
       (trainData, trainLabels, testData, testLabels, 4, 500,
           0.1, 60, randInitA);
@@ -187,7 +187,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
   BuildVanillaNetwork<RandomInitialization,
                       LogisticFunction,
                       SteepestDescent<>,
-                      BinaryClassificationLayer<>,
+                      BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
       (dataset, labels, dataset, labels, 100, 100, 0.6, 10, randInitB);
 
@@ -195,7 +195,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
   BuildVanillaNetwork<RandomInitialization,
                     TanhFunction,
                     SteepestDescent<>,
-                    BinaryClassificationLayer<>,
+                    BinaryClassificationLayer,
                     MeanSquaredErrorFunction>
     (dataset, labels, dataset, labels, 10, 200, 0.6, 20, randInitB);
 }
@@ -219,7 +219,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
   BuildVanillaNetwork<RandomInitialization,
                       LogisticFunction,
                       SteepestDescent<>,
-                      BinaryClassificationLayer<>,
+                      BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
       (input, labels, input, labels, 4, 0, 0, 0.01, randInit);
 
@@ -227,7 +227,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
   BuildVanillaNetwork<RandomInitialization,
                       TanhFunction,
                       SteepestDescent<>,
-                      BinaryClassificationLayer<>,
+                      BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
       (input, labels, input, labels, 4, 0, 0, 0.01, randInit);
 
@@ -240,7 +240,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
   BuildVanillaNetwork<RandomInitialization,
                     LogisticFunction,
                     SteepestDescent<>,
-                    BinaryClassificationLayer<>,
+                    BinaryClassificationLayer,
                     MeanSquaredErrorFunction>
     (input, labels, input, labels, 4, 0, 0, 0.01, randInit);
 
@@ -248,7 +248,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
   BuildVanillaNetwork<RandomInitialization,
                       TanhFunction,
                       SteepestDescent<>,
-                      BinaryClassificationLayer<>,
+                      BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
       (input, labels, input, labels, 4, 0, 0, 0.01, randInit);
 }
@@ -371,7 +371,7 @@ BOOST_AUTO_TEST_CASE(NetworkDecreasingErrorTest)
   BuildNetworkOptimzer<RandomInitialization,
                        LogisticFunction,
                        SteepestDescent<>,
-                       BinaryClassificationLayer<>,
+                       BinaryClassificationLayer,
                        MeanSquaredErrorFunction>
       (dataset, labels, dataset, labels, 100, 50, randInitB);
 }
diff --git a/src/mlpack/tests/recurrent_network_test.cpp b/src/mlpack/tests/recurrent_network_test.cpp
index 566a821..cebc5aa 100644
--- a/src/mlpack/tests/recurrent_network_test.cpp
+++ b/src/mlpack/tests/recurrent_network_test.cpp
@@ -110,7 +110,7 @@ BOOST_AUTO_TEST_CASE(SequenceClassificationTest)
   NeuronLayer<LogisticFunction> hiddenLayer0(4);
   NeuronLayer<LogisticFunction> recurrentLayer0(hiddenLayer0.InputSize());
   NeuronLayer<LogisticFunction> hiddenLayer1(2);
-  BinaryClassificationLayer<> outputLayer;
+  BinaryClassificationLayer outputLayer;
 
   SteepestDescent< > conOptimizer0(inputLayer.InputSize(),
       hiddenLayer0.InputSize(), 1, 0);
@@ -327,7 +327,7 @@ BOOST_AUTO_TEST_CASE(FeedForwardRecurrentNetworkTest)
   CompareVanillaNetworks<RandomInitialization,
                       LogisticFunction,
                       SteepestDescent<>,
-                      BinaryClassificationLayer<>,
+                      BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
       (input, labels, input, labels, 10, 10, randInit);
 
@@ -335,7 +335,7 @@ BOOST_AUTO_TEST_CASE(FeedForwardRecurrentNetworkTest)
   CompareVanillaNetworks<RandomInitialization,
                       IdentityFunction,
                       SteepestDescent<>,
-                      BinaryClassificationLayer<>,
+                      BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
       (input, labels, input, labels, 1, 1, randInit);
 
@@ -343,7 +343,7 @@ BOOST_AUTO_TEST_CASE(FeedForwardRecurrentNetworkTest)
   CompareVanillaNetworks<RandomInitialization,
                     RectifierFunction,
                     SteepestDescent<>,
-                    BinaryClassificationLayer<>,
+                    BinaryClassificationLayer,
                     MeanSquaredErrorFunction>
     (input, labels, input, labels, 10, 10, randInit);
 
@@ -351,7 +351,7 @@ BOOST_AUTO_TEST_CASE(FeedForwardRecurrentNetworkTest)
   CompareVanillaNetworks<RandomInitialization,
                     SoftsignFunction,
                     SteepestDescent<>,
-                    BinaryClassificationLayer<>,
+                    BinaryClassificationLayer,
                     MeanSquaredErrorFunction>
     (input, labels, input, labels, 10, 10, randInit);
 
@@ -359,7 +359,7 @@ BOOST_AUTO_TEST_CASE(FeedForwardRecurrentNetworkTest)
   CompareVanillaNetworks<RandomInitialization,
                     TanhFunction,
                     SteepestDescent<>,
-                    BinaryClassificationLayer<>,
+                    BinaryClassificationLayer,
                     MeanSquaredErrorFunction>
     (input, labels, input, labels, 10, 10, randInit);
 }
@@ -576,7 +576,7 @@ void ReberGrammarTestNetwork(HiddenLayerType& hiddenLayer0,
   NeuronLayer<LogisticFunction> inputLayer(7);
   NeuronLayer<IdentityFunction> recurrentLayer0(hiddenLayer0.OutputSize());
   NeuronLayer<LogisticFunction> hiddenLayer1(7);
-  BinaryClassificationLayer<> outputLayer;
+  BinaryClassificationLayer outputLayer;
 
   SteepestDescent< > conOptimizer0(inputLayer.OutputSize(),
       hiddenLayer0.InputSize(), 0.1);
@@ -811,7 +811,7 @@ void DistractedSequenceRecallTestNetwork(HiddenLayerType& hiddenLayer0)
   NeuronLayer<LogisticFunction> inputLayer(10);
   NeuronLayer<IdentityFunction> recurrentLayer0(hiddenLayer0.OutputSize());
   NeuronLayer<LogisticFunction> hiddenLayer1(3);
-  BinaryClassificationLayer<> outputLayer;
+  BinaryClassificationLayer outputLayer;
 
   SteepestDescent< > conOptimizer0(inputLayer.OutputSize(),
       hiddenLayer0.InputSize(), 0.1);



More information about the mlpack-git mailing list