[mlpack-git] master: Simplify the training process. (c326043)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Jul 9 14:54:49 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/baeb3885b27d3b7dd552c638c605e034b1388cad...4a97187bbba7ce8a6191b714949dd818ef0f37d2
>---------------------------------------------------------------
commit c326043d90bebaf8610c651d7390c6c4442d84fc
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Wed Jul 8 17:05:11 2015 +0200
Simplify the training process.
>---------------------------------------------------------------
c326043d90bebaf8610c651d7390c6c4442d84fc
src/mlpack/tests/convolutional_network_test.cpp | 104 +++++-------------------
1 file changed, 22 insertions(+), 82 deletions(-)
diff --git a/src/mlpack/tests/convolutional_network_test.cpp b/src/mlpack/tests/convolutional_network_test.cpp
index 3c29031..206d8e1 100644
--- a/src/mlpack/tests/convolutional_network_test.cpp
+++ b/src/mlpack/tests/convolutional_network_test.cpp
@@ -144,50 +144,18 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
auto module4 = std::tie(con5, con5Bias);
auto modules = std::tie(module0, module1, module2, module3, module4);
- CNN<decltype(modules), decltype(finalOutputLayer),
- MeanSquaredErrorFunction> net(modules, finalOutputLayer);
+ CNN<decltype(modules), decltype(finalOutputLayer)>
+ net(modules, finalOutputLayer);
- Trainer<decltype(net)> trainer(net, 1);
+ Trainer<decltype(net)> trainer(net, 50, 1, 0.03);
- size_t error = 0;
- for (size_t j = 0; j < 30; ++j)
- {
- arma::Col<size_t> index = arma::linspace<arma::Col<size_t> >(0,
- 499, 500);
- index = arma::shuffle(index);
-
- for (size_t i = 0; i < 500; i++)
- {
- arma::cube input = arma::cube(28, 28, 1);
- input.slice(0) = arma::mat(X.colptr(index(i)), 28, 28);
-
- arma::mat labels = arma::mat(10, 1);
- labels.col(0) = Y.col(index(i));
-
- trainer.Train(input, labels, input, labels);
- }
-
- error = 0;
- for (size_t p = 0; p < 500; p++)
- {
- arma::cube input = arma::cube(X.colptr(p), 28, 28, 1);
- arma::mat labels = Y.col(p);
-
- arma::mat prediction;
- net.Predict(input, prediction);
-
- bool b = arma::all(arma::abs(
- arma::vectorise(prediction) - arma::vectorise(labels)) < 0.1);
-
- if (!b)
- error++;
- }
+ arma::cube input = arma::cube(28, 28, nPoints);
+ for (size_t i = 0; i < nPoints; i++)
+ input.slice(i) = arma::mat(X.colptr(i), 28, 28);
- if (error <= 10)
- break;
- }
+ trainer.Train(input, Y, input, Y);
- BOOST_REQUIRE_LE(error, 10);
+ BOOST_REQUIRE_LE(trainer.ValidationError(), 0.03);
}
/**
@@ -254,7 +222,9 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkDropoutTest)
DropoutLayer<arma::cube> dropoutLayer0(24, 24, inputLayer.LayerSlices(), 6);
IdentityConnection<decltype(convLayer0),
- decltype(dropoutLayer0)>
+ decltype(dropoutLayer0),
+ mlpack::ann::AdaDelta,
+ arma::cube>
con1Dropout(convLayer0, dropoutLayer0);
PoolingLayer<> poolingLayer0(12, 12, inputLayer.LayerSlices(), 6);
@@ -277,7 +247,9 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkDropoutTest)
DropoutLayer<arma::cube> dropoutLayer3(8, 8, inputLayer.LayerSlices(), 12);
IdentityConnection<decltype(convLayer1),
- decltype(dropoutLayer3)>
+ decltype(dropoutLayer3),
+ mlpack::ann::AdaDelta,
+ arma::cube>
con3Dropout(convLayer1, dropoutLayer3);
PoolingLayer<> poolingLayer1(4, 4, inputLayer.LayerSlices(), 12);
@@ -312,50 +284,18 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkDropoutTest)
auto modules = std::tie(module0, module0Dropout, module1, module2,
module2Dropout, module3, module4);
- CNN<decltype(modules), decltype(finalOutputLayer),
- MeanSquaredErrorFunction> net(modules, finalOutputLayer);
+ CNN<decltype(modules), decltype(finalOutputLayer)>
+ net(modules, finalOutputLayer);
- Trainer<decltype(net)> trainer(net, 1);
+ Trainer<decltype(net)> trainer(net, 50, 1, 0.03);
- size_t error = 0;
- for (size_t j = 0; j < 30; ++j)
- {
- arma::Col<size_t> index = arma::linspace<arma::Col<size_t> >(0,
- 499, 500);
- index = arma::shuffle(index);
-
- for (size_t i = 0; i < 500; i++)
- {
- arma::cube input = arma::cube(28, 28, 1);
- input.slice(0) = arma::mat(X.colptr(index(i)), 28, 28);
-
- arma::mat labels = arma::mat(10, 1);
- labels.col(0) = Y.col(index(i));
-
- trainer.Train(input, labels, input, labels);
- }
-
- error = 0;
- for (size_t p = 0; p < 500; p++)
- {
- arma::cube input = arma::cube(X.colptr(p), 28, 28, 1);
- arma::mat labels = Y.col(p);
-
- arma::mat prediction;
- net.Predict(input, prediction);
-
- bool b = arma::all(arma::abs(
- arma::vectorise(prediction) - arma::vectorise(labels)) < 0.1);
-
- if (!b)
- error++;
- }
+ arma::cube input = arma::cube(28, 28, nPoints);
+ for (size_t i = 0; i < nPoints; i++)
+ input.slice(i) = arma::mat(X.colptr(i), 28, 28);
- if (error <= 10)
- break;
- }
+ trainer.Train(input, Y, input, Y);
- BOOST_REQUIRE_LE(error, 10);
+ BOOST_REQUIRE_LE(trainer.ValidationError(), 0.03);
}
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list