[mlpack-git] master: Use bias connection and decrease number of epochs. (37c89a5)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Jun 16 14:50:36 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/9264f7544f7c4d93ff735f00f35b0f5287abf59d...7df836c2f5a2287cda82801ca20f4b4b410cf4e1

>---------------------------------------------------------------

commit 37c89a53363908c7570dcbef0962de67525d3c3c
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Fri Jun 12 22:55:06 2015 +0200

    Use bias connection and decrease number of epochs.


>---------------------------------------------------------------

37c89a53363908c7570dcbef0962de67525d3c3c
 src/mlpack/tests/convolutional_network_test.cpp | 28 +++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/src/mlpack/tests/convolutional_network_test.cpp b/src/mlpack/tests/convolutional_network_test.cpp
index b6b721e..da135a1 100644
--- a/src/mlpack/tests/convolutional_network_test.cpp
+++ b/src/mlpack/tests/convolutional_network_test.cpp
@@ -9,10 +9,12 @@
 #include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
 
 #include <mlpack/methods/ann/connections/full_connection.hpp>
+#include <mlpack/methods/ann/connections/bias_connection.hpp>
 #include <mlpack/methods/ann/connections/conv_connection.hpp>
 #include <mlpack/methods/ann/connections/pooling_connection.hpp>
 
 #include <mlpack/methods/ann/layer/neuron_layer.hpp>
+#include <mlpack/methods/ann/layer/bias_layer.hpp>
 #include <mlpack/methods/ann/layer/binary_classification_layer.hpp>
 
 #include <mlpack/methods/ann/cnn.hpp>
@@ -75,7 +77,6 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
    * +---+        +---+        +---+        +---+        +---+    +---+
    */
 
-
   NeuronLayer<LogisticFunction, arma::cube> inputLayer(28, 28, 1);
 
   ConvLayer<LogisticFunction> convLayer0(24, 24, inputLayer.LayerSlices(), 6);
@@ -83,6 +84,11 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
                  decltype(convLayer0)>
       con1(inputLayer, convLayer0, 5);
 
+  BiasLayer<> biasLayer0(6);
+  BiasConnection<decltype(biasLayer0),
+                 decltype(convLayer0)>
+  con1Bias(biasLayer0, convLayer0);
+
   con1.Weights().slice(0) = arma::mat(
       "-0.0307   -0.1510   -0.0299    0.0631    0.1114;"
       "0.0816   -0.1162    0.0686   -0.0306    0.1734;"
@@ -136,6 +142,11 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
                  decltype(convLayer1)>
       con3(poolingLayer0, convLayer1, 5);
 
+  BiasLayer<> biasLayer3(12);
+  BiasConnection<decltype(biasLayer3),
+                 decltype(convLayer1)>
+  con3Bias(biasLayer3, convLayer1);
+
   PoolingLayer<> poolingLayer1(4, 4, inputLayer.LayerSlices(), 12);
   PoolingConnection<decltype(convLayer1),
                     decltype(poolingLayer1)>
@@ -147,13 +158,19 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
                  decltype(outputLayer)>
     con5(poolingLayer1, outputLayer);
 
+  BiasLayer<> biasLayer1(1);
+  FullConnection<decltype(biasLayer1),
+                 decltype(outputLayer)>
+    con5Bias(biasLayer1, outputLayer);
+
+
   BinaryClassificationLayer finalOutputLayer;
 
-  auto module0 = std::tie(con1);
+  auto module0 = std::tie(con1, con1Bias);
   auto module1 = std::tie(con2);
-  auto module2 = std::tie(con3);
+  auto module2 = std::tie(con3, con3Bias);
   auto module3 = std::tie(con4);
-  auto module4 = std::tie(con5);
+  auto module4 = std::tie(con5, con5Bias);
   auto modules = std::tie(module0, module1, module2, module3, module4);
 
   CNN<decltype(modules), decltype(finalOutputLayer),
@@ -161,7 +178,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
 
   Trainer<decltype(net)> trainer(net, 1);
 
-  for (size_t j = 0; j < 300; ++j)
+  for (size_t j = 0; j < 40; ++j)
   {
     arma::Col<size_t> index = arma::linspace<arma::Col<size_t> >(200,
         299, 300);
@@ -189,7 +206,6 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
     arma::mat prediction;
     net.Predict(input, prediction);
 
-
     bool b = arma::all(arma::abs(
         arma::vectorise(prediction) - arma::vectorise(labels)) < 0.1);
 



More information about the mlpack-git mailing list