[mlpack-git] master: Simplify the training process. (c326043)

gitdub at big.cc.gt.atl.ga.us
Thu Jul 9 14:54:49 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/baeb3885b27d3b7dd552c638c605e034b1388cad...4a97187bbba7ce8a6191b714949dd818ef0f37d2

>---------------------------------------------------------------

commit c326043d90bebaf8610c651d7390c6c4442d84fc
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Wed Jul 8 17:05:11 2015 +0200

    Simplify the training process.


>---------------------------------------------------------------

c326043d90bebaf8610c651d7390c6c4442d84fc
 src/mlpack/tests/convolutional_network_test.cpp | 104 +++++-------------------
 1 file changed, 22 insertions(+), 82 deletions(-)

diff --git a/src/mlpack/tests/convolutional_network_test.cpp b/src/mlpack/tests/convolutional_network_test.cpp
index 3c29031..206d8e1 100644
--- a/src/mlpack/tests/convolutional_network_test.cpp
+++ b/src/mlpack/tests/convolutional_network_test.cpp
@@ -144,50 +144,18 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
   auto module4 = std::tie(con5, con5Bias);
   auto modules = std::tie(module0, module1, module2, module3, module4);
 
-  CNN<decltype(modules), decltype(finalOutputLayer),
-      MeanSquaredErrorFunction> net(modules, finalOutputLayer);
+  CNN<decltype(modules), decltype(finalOutputLayer)>
+      net(modules, finalOutputLayer);
 
-  Trainer<decltype(net)> trainer(net, 1);
+  Trainer<decltype(net)> trainer(net, 50, 1, 0.03);
 
-  size_t error = 0;
-  for (size_t j = 0; j < 30; ++j)
-  {
-    arma::Col<size_t> index = arma::linspace<arma::Col<size_t> >(0,
-        499, 500);
-    index = arma::shuffle(index);
-
-    for (size_t i = 0; i < 500; i++)
-    {
-      arma::cube input = arma::cube(28, 28, 1);
-      input.slice(0) = arma::mat(X.colptr(index(i)), 28, 28);
-
-      arma::mat labels = arma::mat(10, 1);
-      labels.col(0) = Y.col(index(i));
-
-      trainer.Train(input, labels, input, labels);
-    }
-
-    error = 0;
-    for (size_t p = 0; p < 500; p++)
-    {
-      arma::cube input = arma::cube(X.colptr(p), 28, 28, 1);
-      arma::mat labels = Y.col(p);
-
-      arma::mat prediction;
-      net.Predict(input, prediction);
-
-      bool b = arma::all(arma::abs(
-          arma::vectorise(prediction) - arma::vectorise(labels)) < 0.1);
-
-      if (!b)
-        error++;
-    }
+  arma::cube input = arma::cube(28, 28, nPoints);
+  for (size_t i = 0; i < nPoints; i++)
+    input.slice(i) = arma::mat(X.colptr(i), 28, 28);
 
-    if (error <= 10)
-      break;
-  }
+  trainer.Train(input, Y, input, Y);
 
-  BOOST_REQUIRE_LE(error, 10);
+  BOOST_REQUIRE_LE(trainer.ValidationError(), 0.03);
 }
 
 /**
@@ -254,7 +222,9 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkDropoutTest)
 
   DropoutLayer<arma::cube> dropoutLayer0(24, 24, inputLayer.LayerSlices(), 6);
   IdentityConnection<decltype(convLayer0),
-                     decltype(dropoutLayer0)>
+                     decltype(dropoutLayer0),
+                     mlpack::ann::AdaDelta,
+                     arma::cube>
       con1Dropout(convLayer0, dropoutLayer0);
 
   PoolingLayer<> poolingLayer0(12, 12, inputLayer.LayerSlices(), 6);
@@ -277,7 +247,9 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkDropoutTest)
 
   DropoutLayer<arma::cube> dropoutLayer3(8, 8, inputLayer.LayerSlices(), 12);
   IdentityConnection<decltype(convLayer1),
-                     decltype(dropoutLayer3)>
+                     decltype(dropoutLayer3),
+                     mlpack::ann::AdaDelta,
+                     arma::cube>
       con3Dropout(convLayer1, dropoutLayer3);
 
   PoolingLayer<> poolingLayer1(4, 4, inputLayer.LayerSlices(), 12);
@@ -312,50 +284,18 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkDropoutTest)
   auto modules = std::tie(module0, module0Dropout, module1, module2,
       module2Dropout, module3, module4);
 
-  CNN<decltype(modules), decltype(finalOutputLayer),
-      MeanSquaredErrorFunction> net(modules, finalOutputLayer);
+  CNN<decltype(modules), decltype(finalOutputLayer)>
+      net(modules, finalOutputLayer);
 
-  Trainer<decltype(net)> trainer(net, 1);
+  Trainer<decltype(net)> trainer(net, 50, 1, 0.03);
 
-  size_t error = 0;
-  for (size_t j = 0; j < 30; ++j)
-  {
-    arma::Col<size_t> index = arma::linspace<arma::Col<size_t> >(0,
-        499, 500);
-    index = arma::shuffle(index);
-
-    for (size_t i = 0; i < 500; i++)
-    {
-      arma::cube input = arma::cube(28, 28, 1);
-      input.slice(0) = arma::mat(X.colptr(index(i)), 28, 28);
-
-      arma::mat labels = arma::mat(10, 1);
-      labels.col(0) = Y.col(index(i));
-
-      trainer.Train(input, labels, input, labels);
-    }
-
-    error = 0;
-    for (size_t p = 0; p < 500; p++)
-    {
-      arma::cube input = arma::cube(X.colptr(p), 28, 28, 1);
-      arma::mat labels = Y.col(p);
-
-      arma::mat prediction;
-      net.Predict(input, prediction);
-
-      bool b = arma::all(arma::abs(
-          arma::vectorise(prediction) - arma::vectorise(labels)) < 0.1);
-
-      if (!b)
-        error++;
-    }
+  arma::cube input = arma::cube(28, 28, nPoints);
+  for (size_t i = 0; i < nPoints; i++)
+    input.slice(i) = arma::mat(X.colptr(i), 28, 28);
 
-    if (error <= 10)
-      break;
-  }
+  trainer.Train(input, Y, input, Y);
 
-  BOOST_REQUIRE_LE(error, 10);
+  BOOST_REQUIRE_LE(trainer.ValidationError(), 0.03);
 }
 
 BOOST_AUTO_TEST_SUITE_END();
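
For context, both tests replace the hand-rolled training loop (30 passes over 500 shuffled points with a manual misclassification count) with a single Trainer run: the images are packed into one arma::cube up front, the trainer drives the epochs itself, and the assertion moves to trainer.ValidationError(). Below is a minimal sketch of the new pattern, reusing the names from the test (X, Y, nPoints, modules, finalOutputLayer); the calls are taken verbatim from the diff, while the reading of the extra Trainer arguments as maximum epochs, batch size, and target error is an assumption, not something stated in the commit.

    // The MeanSquaredErrorFunction template argument is dropped, so the CNN
    // presumably falls back to a default performance function.
    CNN<decltype(modules), decltype(finalOutputLayer)>
        net(modules, finalOutputLayer);

    // Assumed meaning of the arguments: at most 50 epochs, batch size 1,
    // stop once the validation error reaches 0.03.
    Trainer<decltype(net)> trainer(net, 50, 1, 0.03);

    // Pack all nPoints 28x28 images into a single cube instead of feeding
    // one slice per iteration of a manual loop.
    arma::cube input = arma::cube(28, 28, nPoints);
    for (size_t i = 0; i < nPoints; i++)
      input.slice(i) = arma::mat(X.colptr(i), 28, 28);

    // Train and validate on the same data; the trainer tracks the validation
    // error, so the test only needs to check it afterwards.
    trainer.Train(input, Y, input, Y);
    BOOST_REQUIRE_LE(trainer.ValidationError(), 0.03);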


