[mlpack-git] master: Use the rectifier function for the whole test and decrease the overall error tolerance. (e58937c)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sun Jul 5 08:53:32 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/d9e984e1c608679171ad52e8522916703c7b331f...267bf1f0ace881bea4a38bf1156cc9f503930f09
>---------------------------------------------------------------
commit e58937c16940ba93828e873583dd89f2b4e7a242
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Sat Jul 4 14:09:08 2015 +0200
Use the rectifier function for the whole test and decrease the overall error tolerance.
>---------------------------------------------------------------
e58937c16940ba93828e873583dd89f2b4e7a242
src/mlpack/tests/convolutional_network_test.cpp | 80 +++++++++++++++----------
1 file changed, 48 insertions(+), 32 deletions(-)
diff --git a/src/mlpack/tests/convolutional_network_test.cpp b/src/mlpack/tests/convolutional_network_test.cpp
index 578fb8e..8851318 100644
--- a/src/mlpack/tests/convolutional_network_test.cpp
+++ b/src/mlpack/tests/convolutional_network_test.cpp
@@ -6,7 +6,7 @@
*/
#include <mlpack/core.hpp>
-#include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
+#include <mlpack/methods/ann/activation_functions/rectifier_function.hpp>
#include <mlpack/methods/ann/connections/full_connection.hpp>
#include <mlpack/methods/ann/connections/bias_connection.hpp>
@@ -14,12 +14,15 @@
#include <mlpack/methods/ann/connections/pooling_connection.hpp>
#include <mlpack/methods/ann/layer/neuron_layer.hpp>
+#include <mlpack/methods/ann/layer/softmax_layer.hpp>
#include <mlpack/methods/ann/layer/bias_layer.hpp>
-#include <mlpack/methods/ann/layer/binary_classification_layer.hpp>
+#include <mlpack/methods/ann/layer/one_hot_layer.hpp>
#include <mlpack/methods/ann/cnn.hpp>
#include <mlpack/methods/ann/trainer/trainer.hpp>
#include <mlpack/methods/ann/performance_functions/mse_function.hpp>
+#include <mlpack/methods/ann/optimizer/ada_delta.hpp>
+#include <mlpack/methods/ann/init_rules/zero_init.hpp>
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
@@ -51,11 +54,11 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
{
if (i < nPoints / 2)
{
- Y.col(i)(0) = 1;
+ Y.col(i)(1) = 1;
}
else
{
- Y.col(i)(1) = 1;
+ Y.col(i)(8) = 1;
}
}
@@ -77,16 +80,19 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
* +---+ +---+ +---+ +---+ +---+ +---+
*/
- NeuronLayer<LogisticFunction, arma::cube> inputLayer(28, 28, 1);
+ NeuronLayer<RectifierFunction, arma::cube> inputLayer(28, 28, 1);
- ConvLayer<LogisticFunction> convLayer0(24, 24, inputLayer.LayerSlices(), 6);
+ ConvLayer<RectifierFunction> convLayer0(24, 24, inputLayer.LayerSlices(), 6);
ConvConnection<decltype(inputLayer),
- decltype(convLayer0)>
+ decltype(convLayer0),
+ mlpack::ann::AdaDelta>
con1(inputLayer, convLayer0, 5);
BiasLayer<> biasLayer0(6);
BiasConnection<decltype(biasLayer0),
- decltype(convLayer0)>
+ decltype(convLayer0),
+ mlpack::ann::AdaDelta,
+ mlpack::ann::ZeroInitialization>
con1Bias(biasLayer0, convLayer0);
PoolingLayer<> poolingLayer0(12, 12, inputLayer.LayerSlices(), 6);
@@ -94,14 +100,17 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
decltype(poolingLayer0)>
con2(convLayer0, poolingLayer0);
- ConvLayer<LogisticFunction> convLayer1(8, 8, inputLayer.LayerSlices(), 12);
+ ConvLayer<RectifierFunction> convLayer1(8, 8, inputLayer.LayerSlices(), 12);
ConvConnection<decltype(poolingLayer0),
- decltype(convLayer1)>
+ decltype(convLayer1),
+ mlpack::ann::AdaDelta>
con3(poolingLayer0, convLayer1, 5);
BiasLayer<> biasLayer3(12);
BiasConnection<decltype(biasLayer3),
- decltype(convLayer1)>
+ decltype(convLayer1),
+ mlpack::ann::AdaDelta,
+ mlpack::ann::ZeroInitialization>
con3Bias(biasLayer3, convLayer1);
PoolingLayer<> poolingLayer1(4, 4, inputLayer.LayerSlices(), 12);
@@ -109,25 +118,28 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
decltype(poolingLayer1)>
con4(convLayer1, poolingLayer1);
- NeuronLayer<LogisticFunction, arma::mat> outputLayer(10,
- inputLayer.LayerSlices());
+ SoftmaxLayer<arma::mat> outputLayer(10,
+ inputLayer.LayerSlices());
FullConnection<decltype(poolingLayer1),
- decltype(outputLayer)>
+ decltype(outputLayer),
+ mlpack::ann::AdaDelta>
con5(poolingLayer1, outputLayer);
BiasLayer<> biasLayer1(1);
FullConnection<decltype(biasLayer1),
- decltype(outputLayer)>
+ decltype(outputLayer),
+ mlpack::ann::AdaDelta,
+ mlpack::ann::ZeroInitialization>
con5Bias(biasLayer1, outputLayer);
- BinaryClassificationLayer finalOutputLayer;
+ OneHotLayer finalOutputLayer;
auto module0 = std::tie(con1, con1Bias);
auto module1 = std::tie(con2);
- auto module2 = std::tie(con3);
+ auto module2 = std::tie(con3, con3Bias);
auto module3 = std::tie(con4);
- auto module4 = std::tie(con5);
+ auto module4 = std::tie(con5, con5Bias);
auto modules = std::tie(module0, module1, module2, module3, module4);
CNN<decltype(modules), decltype(finalOutputLayer),
@@ -135,7 +147,8 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
Trainer<decltype(net)> trainer(net, 1);
- for (size_t j = 0; j < 40; ++j)
+ size_t error = 0;
+ for (size_t j = 0; j < 30; ++j)
{
arma::Col<size_t> index = arma::linspace<arma::Col<size_t> >(0,
499, 500);
@@ -151,25 +164,28 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
trainer.Train(input, labels, input, labels);
}
- }
- size_t error = 0;
- for (size_t i = 0; i < 500; i++)
- {
- arma::cube input = arma::cube(X.colptr(i), 28, 28, 1);
- arma::mat labels = Y.col(i);
+ error = 0;
+ for (size_t p = 0; p < 500; p++)
+ {
+ arma::cube input = arma::cube(X.colptr(p), 28, 28, 1);
+ arma::mat labels = Y.col(p);
- arma::mat prediction;
- net.Predict(input, prediction);
+ arma::mat prediction;
+ net.Predict(input, prediction);
- bool b = arma::all(arma::abs(
- arma::vectorise(prediction) - arma::vectorise(labels)) < 0.1);
+ bool b = arma::all(arma::abs(
+ arma::vectorise(prediction) - arma::vectorise(labels)) < 0.1);
+
+ if (!b)
+ error++;
+ }
- if (!b)
- error++;
+ if (error <= 10)
+ break;
}
- BOOST_REQUIRE_LE(error, 90);
+ BOOST_REQUIRE_LE(error, 10);
}
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list