[mlpack-git] master: Adjust the network test; Use the simplified layer structure. (0aeb767)

gitdub at mlpack.org gitdub at mlpack.org
Fri Feb 19 09:01:37 EST 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f6dd2f7a9752a7db8ec284a938b3e84a13d0bfb2...6205f3e0b62b56452b2a4afc4da24fce5b21e72f

>---------------------------------------------------------------

commit 0aeb767485ad69fa199d5a601fd34ebd1e1b4b67
Author: marcus <marcus.edel at fu-berlin.de>
Date:   Fri Feb 19 15:01:37 2016 +0100

    Adjust the network test; Use the simplified layer structure.


>---------------------------------------------------------------

0aeb767485ad69fa199d5a601fd34ebd1e1b4b67
 src/mlpack/tests/convolutional_network_test.cpp | 141 ++++-----------
 src/mlpack/tests/feedforward_network_test.cpp   | 217 ++++--------------------
 src/mlpack/tests/recurrent_network_test.cpp     |  76 ++++-----
 3 files changed, 103 insertions(+), 331 deletions(-)

diff --git a/src/mlpack/tests/convolutional_network_test.cpp b/src/mlpack/tests/convolutional_network_test.cpp
index bd20116..c34f579 100644
--- a/src/mlpack/tests/convolutional_network_test.cpp
+++ b/src/mlpack/tests/convolutional_network_test.cpp
@@ -6,7 +6,6 @@
  */
 #include <mlpack/core.hpp>
 
-#include <mlpack/methods/ann/activation_functions/rectifier_function.hpp>
 #include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
 
 #include <mlpack/methods/ann/layer/one_hot_layer.hpp>
@@ -16,19 +15,19 @@
 #include <mlpack/methods/ann/layer/bias_layer.hpp>
 #include <mlpack/methods/ann/layer/linear_layer.hpp>
 #include <mlpack/methods/ann/layer/base_layer.hpp>
-#include <mlpack/methods/ann/layer/dropout_layer.hpp>
 
+#include <mlpack/methods/ann/performance_functions/mse_function.hpp>
+#include <mlpack/core/optimizers/rmsprop/rmsprop.hpp>
+
+#include <mlpack/methods/ann/init_rules/random_init.hpp>
 #include <mlpack/methods/ann/cnn.hpp>
-#include <mlpack/methods/ann/trainer/trainer.hpp>
-#include <mlpack/methods/ann/optimizer/ada_delta.hpp>
-#include <mlpack/methods/ann/optimizer/rmsprop.hpp>
-#include <mlpack/methods/ann/init_rules/zero_init.hpp>
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
 
 using namespace mlpack;
 using namespace mlpack::ann;
+using namespace mlpack::optimization;
 
 
 BOOST_AUTO_TEST_SUITE(ConvolutionalNetworkTest);
@@ -87,21 +86,18 @@ void BuildVanillaNetwork()
    * +---+        +---+        +---+        +---+        +---+    +---+
    */
 
-  ConvLayer<RMSPROP> convLayer0(1, 8, 5, 5);
-  BiasLayer2D<RMSPROP, ZeroInitialization> biasLayer0(8);
-  BaseLayer2D<PerformanceFunction> baseLayer0;
+  ConvLayer<> convLayer0(1, 8, 5, 5);
+  BiasLayer2D<> biasLayer0(8);
+  BaseLayer2D<> baseLayer0;
   PoolingLayer<> poolingLayer0(2);
 
-
-
-
-  ConvLayer<RMSPROP> convLayer1(8, 12, 5, 5);
-  BiasLayer2D<RMSPROP, ZeroInitialization> biasLayer1(12);
-  BaseLayer2D<PerformanceFunction> baseLayer1;
+  ConvLayer<> convLayer1(8, 12, 5, 5);
+  BiasLayer2D<> biasLayer1(12);
+  BaseLayer2D<> baseLayer1;
   PoolingLayer<> poolingLayer1(2);
 
-  LinearMappingLayer<RMSPROP> linearLayer0(192, 10);
-  BiasLayer<RMSPROP> biasLayer2(10);
+  LinearMappingLayer<> linearLayer0(192, 10);
+  BiasLayer<> biasLayer2(10);
   SoftmaxLayer<> softmaxLayer0;
 
   OneHotLayer outputLayer;
@@ -110,115 +106,38 @@ void BuildVanillaNetwork()
                           convLayer1, biasLayer1, baseLayer1, poolingLayer1,
                           linearLayer0, biasLayer2, softmaxLayer0);
 
-  CNN<decltype(modules), decltype(outputLayer)>
-      net(modules, outputLayer);
+  CNN<decltype(modules), decltype(outputLayer),
+      RandomInitialization, MeanSquaredErrorFunction> net(modules, outputLayer);
+  biasLayer0.Weights().zeros();
+  biasLayer1.Weights().zeros();
 
-  Trainer<decltype(net)> trainer(net, 50, 1, 0.7);
-  trainer.Train(input, Y, input, Y);
+  RMSprop<decltype(net)> opt(net, 0.01, 0.88, 1e-8, 10 * input.n_slices, 0);
 
-  BOOST_REQUIRE_LE(trainer.ValidationError(), 0.7);
-}
-
-/**
- * Train the vanilla network on a larger dataset.
- */
-BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
-{
-  BuildVanillaNetwork<LogisticFunction>();
-}
-
-/**
- * Train and evaluate a vanilla network with the specified structure.
- */
-template<
-    typename PerformanceFunction
->
-void BuildVanillaDropoutNetwork()
-{
-  arma::mat X;
-  X.load("mnist_first250_training_4s_and_9s.arm");
+  net.Train(input, Y, opt);
 
-  // Normalize each point since these are images.
-  arma::uword nPoints = X.n_cols;
-  for (arma::uword i = 0; i < nPoints; i++)
-  {
-    X.col(i) /= norm(X.col(i), 2);
-  }
+  arma::mat prediction;
+  net.Predict(input, prediction);
 
-  // Build the target matrix.
-  arma::mat Y = arma::zeros<arma::mat>(10, nPoints);
+  size_t error = 0;
   for (size_t i = 0; i < nPoints; i++)
   {
-    if (i < nPoints / 2)
-    {
-      Y.col(i)(5) = 1;
-    }
-    else
+    if (arma::sum(arma::sum(
+        arma::abs(prediction.col(i) - Y.col(i)))) == 0)
     {
-      Y.col(i)(8) = 1;
+      error++;
     }
   }
 
-  arma::cube input = arma::cube(28, 28, nPoints);
-  for (size_t i = 0; i < nPoints; i++)
-    input.slice(i) = arma::mat(X.colptr(i), 28, 28);
-
-  /*
-   * Construct a convolutional neural network with a 28x28x1 input layer,
-   * 24x24x4 convolution layer, 24x24x4 dropout layer, 12x12x4 pooling layer,
-   * 8x8x8 convolution layer,8x8x8 Dropout Layer and a 4x4x12 pooling layer
-   * which is fully connected with the output layer. The network structure looks
-   * like:
-   *
-   * Input    Convolution  Dropout      Pooling     Convolution,     Output
-   * Layer    Layer        Layer        Layer       Dropout,         Layer
-   *                                                Pooling Layer
-   *          +---+        +---+        +---+
-   *          | +---+      | +---+      | +---+
-   * +---+    | | +---+    | | +---+    | | +---+                    +---+
-   * |   |    | | |   |    | | |   |    | | |   |                    |   |
-   * |   +--> +-+ |   +--> +-+ |   +--> +-+ |   +--> ............--> |   |
-   * |   |      +-+   |      +-+   |      +-+   |                    |   |
-   * +---+        +---+        +---+        +---+                    +---+
-   */
-
-  ConvLayer<AdaDelta> convLayer0(1, 4, 5, 5);
-  BiasLayer2D<AdaDelta, ZeroInitialization> biasLayer0(4);
-  DropoutLayer2D<> dropoutLayer0;
-  BaseLayer2D<PerformanceFunction> baseLayer0;
-  PoolingLayer<> poolingLayer0(2);
-
-  ConvLayer<AdaDelta> convLayer1(4, 8, 5, 5);
-  BiasLayer2D<AdaDelta, ZeroInitialization> biasLayer1(8);
-  BaseLayer2D<PerformanceFunction> baseLayer1;
-  PoolingLayer<> poolingLayer1(2);
-
-  LinearMappingLayer<AdaDelta> linearLayer0(128, 10);
-  BiasLayer<AdaDelta> biasLayer2(10);
-  SoftmaxLayer<> softmaxLayer0;
-
-  OneHotLayer outputLayer;
-
-  auto modules = std::tie(convLayer0, biasLayer0, dropoutLayer0, baseLayer0,
-                          poolingLayer0, convLayer1, biasLayer1, baseLayer1,
-                          poolingLayer1, linearLayer0, biasLayer2,
-                          softmaxLayer0);
-
-  CNN<decltype(modules), decltype(outputLayer)>
-      net(modules, outputLayer);
-
-  Trainer<decltype(net)> trainer(net, 50, 1, 0.7);
-  trainer.Train(input, Y, input, Y);
-
-  BOOST_REQUIRE_LE(trainer.ValidationError(), 0.7);
+  double classificationError = 1 - double(error) / nPoints;
+  BOOST_REQUIRE_LE(classificationError, 0.6);
 }
 
 /**
- * Train the network on a larger dataset using dropout.
+ * Train the vanilla network and check its classification error.
  */
-BOOST_AUTO_TEST_CASE(VanillaNetworkDropoutTest)
+BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
 {
-  BuildVanillaDropoutNetwork<RectifierFunction>();
+  BuildVanillaNetwork<LogisticFunction>();
 }
 
 BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/feedforward_network_test.cpp b/src/mlpack/tests/feedforward_network_test.cpp
index 4c54e77..e8b6636 100644
--- a/src/mlpack/tests/feedforward_network_test.cpp
+++ b/src/mlpack/tests/feedforward_network_test.cpp
@@ -17,17 +17,16 @@
 #include <mlpack/methods/ann/layer/dropout_layer.hpp>
 #include <mlpack/methods/ann/layer/binary_classification_layer.hpp>
 
-#include <mlpack/methods/ann/trainer/trainer.hpp>
 #include <mlpack/methods/ann/ffn.hpp>
 #include <mlpack/methods/ann/performance_functions/mse_function.hpp>
-#include <mlpack/methods/ann/optimizer/rmsprop.hpp>
+#include <mlpack/core/optimizers/rmsprop/rmsprop.hpp>
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
 
 using namespace mlpack;
 using namespace mlpack::ann;
-
+using namespace mlpack::optimization;
 
 BOOST_AUTO_TEST_SUITE(FeedForwardNetworkTest);
 
@@ -46,8 +45,7 @@ void BuildVanillaNetwork(MatType& trainData,
                          MatType& testLabels,
                          const size_t hiddenLayerSize,
                          const size_t maxEpochs,
-                         const double classificationErrorThreshold,
-                         const double ValidationErrorThreshold)
+                         const double classificationErrorThreshold)
 {
   /*
    * Construct a feed forward network with trainData.n_rows input nodes,
@@ -84,30 +82,29 @@ void BuildVanillaNetwork(MatType& trainData,
   auto modules = std::tie(inputLayer, inputBiasLayer, inputBaseLayer,
                           hiddenLayer1, hiddenBiasLayer1, outputLayer);
 
-  FFN<decltype(modules), decltype(classOutputLayer), PerformanceFunctionType>
-      net(modules, classOutputLayer);
+  FFN<decltype(modules), decltype(classOutputLayer), RandomInitialization,
+      PerformanceFunctionType> net(modules, classOutputLayer);
+
+  RMSprop<decltype(net)> opt(net, 0.01, 0.88, 1e-8,
+      maxEpochs * trainData.n_cols, 1e-18);
 
-  Trainer<decltype(net)> trainer(net, maxEpochs, 1, 0.01);
-  trainer.Train(trainData, trainLabels, testData, testLabels);
+  net.Train(trainData, trainLabels, opt);
 
   MatType prediction;
-  size_t error = 0;
+  net.Predict(testData, prediction);
 
+  size_t error = 0;
   for (size_t i = 0; i < testData.n_cols; i++)
   {
-    MatType predictionInput = testData.unsafe_col(i);
-    MatType targetOutput = testLabels.unsafe_col(i);
-
-    net.Predict(predictionInput, prediction);
-
-    if (arma::sum(arma::sum(arma::abs(prediction - targetOutput))) == 0)
+    if (arma::sum(arma::sum(
+        arma::abs(prediction.col(i) - testLabels.col(i)))) == 0)
+    {
       error++;
+    }
   }
 
   double classificationError = 1 - double(error) / testData.n_cols;
-
   BOOST_REQUIRE_LE(classificationError, classificationErrorThreshold);
-  BOOST_REQUIRE_LE(trainer.ValidationError(), ValidationErrorThreshold);
 }
 
 /**
@@ -137,7 +134,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
   BuildVanillaNetwork<LogisticFunction,
                       BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
-      (trainData, trainLabels, testData, testLabels, 4, 500, 0.1, 60);
+      (trainData, trainLabels, testData, testLabels, 8, 200, 0.1);
   
   dataset.load("mnist_first250_training_4s_and_9s.arm");
 
@@ -152,13 +149,13 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
   BuildVanillaNetwork<LogisticFunction,
                       BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
-      (dataset, labels, dataset, labels, 30, 100, 0.6, 10);
+      (dataset, labels, dataset, labels, 30, 30, 0.4);
 
   // Vanilla neural net with tanh activation function.
   BuildVanillaNetwork<TanhFunction,
                       BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
-    (dataset, labels, dataset, labels, 10, 200, 0.6, 20);
+    (dataset, labels, dataset, labels, 10, 30, 0.4);
 }
 
 /**
@@ -176,8 +173,7 @@ void BuildDropoutNetwork(MatType& trainData,
                          MatType& testLabels,
                          const size_t hiddenLayerSize,
                          const size_t maxEpochs,
-                         const double classificationErrorThreshold,
-                         const double ValidationErrorThreshold)
+                         const double classificationErrorThreshold)
 {
   /*
    * Construct a feed forward network with trainData.n_rows input nodes,
@@ -214,28 +210,29 @@ void BuildDropoutNetwork(MatType& trainData,
   auto modules = std::tie(inputLayer, biasLayer, hiddenLayer0, dropoutLayer0,
                           hiddenLayer1, outputLayer);
 
-  FFN<decltype(modules), decltype(classOutputLayer), PerformanceFunctionType>
-      net(modules, classOutputLayer);
+  FFN<decltype(modules), decltype(classOutputLayer), RandomInitialization,
+      PerformanceFunctionType> net(modules, classOutputLayer);
 
-  Trainer<decltype(net)> trainer(net, maxEpochs, 1, 0.001);
-  trainer.Train(trainData, trainLabels, testData, testLabels);
+  RMSprop<decltype(net)> opt(net, 0.01, 0.88, 1e-8,
+      maxEpochs * trainData.n_cols, 1e-18);
+
+  net.Train(trainData, trainLabels, opt);
 
   MatType prediction;
-  size_t error = 0;
+  net.Predict(testData, prediction);
 
+  size_t error = 0;
   for (size_t i = 0; i < testData.n_cols; i++)
   {
-    MatType input = testData.unsafe_col(i);
-    net.Predict(input, prediction);
-    if (arma::sum(arma::sum(arma::abs(
-      prediction - testLabels.unsafe_col(i)))) == 0)
+    if (arma::sum(arma::sum(
+        arma::abs(prediction.col(i) - testLabels.col(i)))) == 0)
+    {
       error++;
+    }
   }
 
   double classificationError = 1 - double(error) / testData.n_cols;
-
   BOOST_REQUIRE_LE(classificationError, classificationErrorThreshold);
-  BOOST_REQUIRE_LE(trainer.ValidationError(), ValidationErrorThreshold);
 }
 
 /**
@@ -265,7 +262,7 @@ BOOST_AUTO_TEST_CASE(DropoutNetworkTest)
   BuildDropoutNetwork<LogisticFunction,
                       BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
-      (trainData, trainLabels, testData, testLabels, 4, 100, 0.1, 60);
+      (trainData, trainLabels, testData, testLabels, 4, 100, 0.1);
 
   dataset.load("mnist_first250_training_4s_and_9s.arm");
 
@@ -277,158 +274,16 @@ BOOST_AUTO_TEST_CASE(DropoutNetworkTest)
   labels.submat(0, labels.n_cols / 2, 0, labels.n_cols - 1).fill(1);
 
   // Vanilla neural net with logistic activation function.
-  BuildVanillaNetwork<LogisticFunction,
-                      BinaryClassificationLayer,
-                      MeanSquaredErrorFunction>
-      (dataset, labels, dataset, labels, 8, 100, 0.6, 10);
-
-  // Vanilla neural net with tanh activation function.
-  BuildVanillaNetwork<TanhFunction,
-                      BinaryClassificationLayer,
-                      MeanSquaredErrorFunction>
-    (dataset, labels, dataset, labels, 8, 100, 0.6, 20);
-}
-
-/**
- * Train the network until the validation error converge.
- */
-BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
-{
-  arma::mat input;
-  arma::mat labels;
-
-  // Test on a non-linearly separable dataset (XOR).
-  input << 0 << 1 << 1 << 0 << arma::endr
-        << 1 << 0 << 1 << 0 << arma::endr;
-  labels << 0 << 0 << 1 << 1;
-
-  // Vanilla neural net with logistic activation function.
-  BuildVanillaNetwork<LogisticFunction,
-                      BinaryClassificationLayer,
-                      MeanSquaredErrorFunction>
-      (input, labels, input, labels, 4, 5000, 0, 0.01);
-
-  // Vanilla neural net with tanh activation function.
-  BuildVanillaNetwork<TanhFunction,
-                      BinaryClassificationLayer,
-                      MeanSquaredErrorFunction>
-      (input, labels, input, labels, 4, 5000, 0, 0.01);
-
-  // Test on a linearly separable dataset (AND).
-  input << 0 << 1 << 1 << 0 << arma::endr
-        << 1 << 0 << 1 << 0 << arma::endr;
-  labels << 0 << 0 << 1 << 0;
-
-  // vanilla neural net with sigmoid activation function.
-  BuildVanillaNetwork<LogisticFunction,
+  BuildDropoutNetwork<LogisticFunction,
                       BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
-    (input, labels, input, labels, 4, 5000, 0, 0.01);
+      (dataset, labels, dataset, labels, 8, 30, 0.4);
 
   // Vanilla neural net with tanh activation function.
-  BuildVanillaNetwork<TanhFunction,
+  BuildDropoutNetwork<TanhFunction,
                       BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
-      (input, labels, input, labels, 4, 5000, 0, 0.01);
-}
-
-/**
- * Train a vanilla network with the specified structure step by step and
- * evaluate the network.
- */
-template<
-    typename PerformanceFunction,
-    typename OutputLayerType,
-    typename PerformanceFunctionType,
-    typename MatType = arma::mat
->
-void BuildNetworkOptimzer(MatType& trainData,
-                          MatType& trainLabels,
-                          MatType& testData,
-                          MatType& testLabels,
-                          size_t hiddenLayerSize,
-                          size_t epochs)
-{
-  /*
-   * Construct a feed forward network with trainData.n_rows input nodes,
-   * hiddenLayerSize hidden nodes and trainLabels.n_rows output nodes. The
-   * network structure looks like:
-   *
-   *  Input         Hidden        Output
-   *  Layer         Layer         Layer
-   * +-----+       +-----+       +-----+
-   * |     |       |     |       |     |
-   * |     +------>|     +------>|     |
-   * |     |     +>|     |     +>|     |
-   * +-----+     | +--+--+     | +-----+
-   *             |             |
-   *  Bias       |  Bias       |
-   *  Layer      |  Layer      |
-   * +-----+     | +-----+     |
-   * |     |     | |     |     |
-   * |     +-----+ |     +-----+
-   * |     |       |     |
-   * +-----+       +-----+
-   */
-
-  RandomInitialization randInit(0.5, 0.5);
-
-  LinearLayer<RMSPROP, RandomInitialization> inputLayer(trainData.n_rows,
-      hiddenLayerSize, randInit);
-  BiasLayer<RMSPROP, RandomInitialization> inputBiasLayer(hiddenLayerSize,
-      1, randInit);
-  BaseLayer<PerformanceFunction> inputBaseLayer;
-
-  LinearLayer<RMSPROP, RandomInitialization> hiddenLayer1(hiddenLayerSize,
-      trainLabels.n_rows, randInit);
-  BiasLayer<RMSPROP, RandomInitialization> hiddenBiasLayer1(trainLabels.n_rows,
-      1, randInit);
-  BaseLayer<PerformanceFunction> outputLayer;
-
-  OutputLayerType classOutputLayer;
-
-  auto modules = std::tie(inputLayer, inputBiasLayer, inputBaseLayer,
-                hiddenLayer1, hiddenBiasLayer1, outputLayer);
-
-  FFN<decltype(modules), OutputLayerType, PerformanceFunctionType>
-      net(modules, classOutputLayer);
-
-  Trainer<decltype(net)> trainer(net, epochs, 1, 0.0001, false);
-
-  double error = DBL_MAX;
-  for (size_t i = 0; i < 5; i++)
-  {
-    trainer.Train(trainData, trainLabels, testData, testLabels);
-    double validationError = trainer.ValidationError();
-
-    bool b = validationError < error || validationError == 0;
-    BOOST_REQUIRE_EQUAL(b, 1);
-
-    error = validationError;
-  }
-}
-
-/**
- * Train the network with different optimzer and check if the error decreases
- * over time.
- */
-BOOST_AUTO_TEST_CASE(NetworkDecreasingErrorTest)
-{
-  arma::mat dataset;
-  dataset.load("mnist_first250_training_4s_and_9s.arm");
-
-  // Normalize each point since these are images.
-  for (size_t i = 0; i < dataset.n_cols; ++i)
-    dataset.col(i) /= norm(dataset.col(i), 2);
-
-  arma::mat labels = arma::zeros(1, dataset.n_cols);
-  labels.submat(0, labels.n_cols / 2, 0, labels.n_cols - 1) += 1;
-
-  // Vanilla neural net with logistic activation function.
-  BuildNetworkOptimzer<LogisticFunction,
-                       BinaryClassificationLayer,
-                       MeanSquaredErrorFunction>
-      (dataset, labels, dataset, labels, 20, 15);
+    (dataset, labels, dataset, labels, 8, 30, 0.4);
 }
 
 BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/recurrent_network_test.cpp b/src/mlpack/tests/recurrent_network_test.cpp
index cab0d3c..4f0c410 100644
--- a/src/mlpack/tests/recurrent_network_test.cpp
+++ b/src/mlpack/tests/recurrent_network_test.cpp
@@ -12,10 +12,9 @@
 #include <mlpack/methods/ann/layer/lstm_layer.hpp>
 #include <mlpack/methods/ann/layer/binary_classification_layer.hpp>
 
-#include <mlpack/methods/ann/trainer/trainer.hpp>
 #include <mlpack/methods/ann/rnn.hpp>
 #include <mlpack/methods/ann/performance_functions/mse_function.hpp>
-#include <mlpack/methods/ann/optimizer/steepest_descent.hpp>
+#include <mlpack/core/optimizers/sgd/sgd.hpp>
 #include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
 #include <mlpack/methods/ann/init_rules/random_init.hpp>
 
@@ -24,7 +23,7 @@
 
 using namespace mlpack;
 using namespace mlpack::ann;
-
+using namespace mlpack::optimization;
   
 BOOST_AUTO_TEST_SUITE(RecurrentNetworkTest);
 
@@ -91,11 +90,11 @@ BOOST_AUTO_TEST_CASE(SequenceClassificationTest)
    *            .     .
    *            .......
    */
-  LinearLayer<SteepestDescent, RandomInitialization> linearLayer0(1, 4);
-  RecurrentLayer<SteepestDescent, RandomInitialization> recurrentLayer0(4);
+  LinearLayer<> linearLayer0(1, 4);
+  RecurrentLayer<> recurrentLayer0(4);
   BaseLayer<LogisticFunction> inputBaseLayer;
 
-  LinearLayer<SteepestDescent, RandomInitialization> hiddenLayer(4, 2);
+  LinearLayer<> hiddenLayer(4, 2);
   BaseLayer<LogisticFunction> hiddenBaseLayer;
 
   BinaryClassificationLayer classOutputLayer;
@@ -103,23 +102,27 @@ BOOST_AUTO_TEST_CASE(SequenceClassificationTest)
   auto modules = std::tie(linearLayer0, recurrentLayer0, inputBaseLayer,
                           hiddenLayer, hiddenBaseLayer);
 
-  RNN<decltype(modules), BinaryClassificationLayer, MeanSquaredErrorFunction>
-      net(modules, classOutputLayer);
+  RNN<decltype(modules), BinaryClassificationLayer, RandomInitialization,
+      MeanSquaredErrorFunction> net(modules, classOutputLayer);
 
-  // Train the network for 200 epochs.
-  Trainer<decltype(net)> trainer(net, 400, 1, 0.01);
-  trainer.Train(input, labels, input, labels);
+  SGD<decltype(net)> opt(net, 0.5, 400 * input.n_cols, -100);
 
-  // Ask the network to classify the trained input data.
-  arma::mat output;
-  for (size_t i = 0; i < input.n_cols; i++)
-  {
-    arma::mat inputSeq = input.unsafe_col(i);
-    net.Predict(inputSeq, output);
+  net.Train(input, labels, opt);
+
+  arma::mat prediction;
+  net.Predict(input, prediction);
 
-    bool b = arma::all((output == labels.unsafe_col(i)) == 1);
-    BOOST_REQUIRE_EQUAL(b, 1);
+  size_t error = 0;
+  for (size_t i = 0; i < labels.n_cols; i++)
+  {
+    if (arma::sum(arma::sum(arma::abs(prediction.col(i) - labels.col(i)))) == 0)
+    {
+      error++;
+    }
   }
+
+  double classificationError = 1 - double(error) / labels.n_cols;
+  BOOST_REQUIRE_LE(classificationError, 0.2);
 }
 
 /**
@@ -332,11 +335,10 @@ void ReberGrammarTestNetwork(HiddenLayerType& hiddenLayer0,
    *            .......
    */
   const size_t lstmSize = 4 * 10;
-  LinearLayer<SteepestDescent, RandomInitialization> linearLayer0(7, lstmSize);
-  RecurrentLayer<SteepestDescent, RandomInitialization> recurrentLayer0(10,
-      lstmSize);
+  LinearLayer<> linearLayer0(7, lstmSize);
+  RecurrentLayer<> recurrentLayer0(10, lstmSize);
 
-  LinearLayer<SteepestDescent, RandomInitialization> hiddenLayer(10, 7);
+  LinearLayer<> hiddenLayer(10, 7);
   BaseLayer<LogisticFunction> hiddenBaseLayer;
 
   BinaryClassificationLayer classOutputLayer;
@@ -344,11 +346,10 @@ void ReberGrammarTestNetwork(HiddenLayerType& hiddenLayer0,
   auto modules = std::tie(linearLayer0, recurrentLayer0, hiddenLayer0,
                           hiddenLayer, hiddenBaseLayer);
 
-  RNN<decltype(modules), BinaryClassificationLayer, MeanSquaredErrorFunction>
-      net(modules, classOutputLayer);
+  RNN<decltype(modules), BinaryClassificationLayer, RandomInitialization,
+      MeanSquaredErrorFunction> net(modules, classOutputLayer);
 
-  // Train the network for (100 * trainReberGrammarCount) epochs.
-  Trainer<decltype(net)> trainer(net, 1, 1, 0, false);
+  SGD<decltype(net)> opt(net, 0.5, 2, -200);
 
   arma::mat inputTemp, labelsTemp;
   for (size_t i = 0; i < 15; i++)
@@ -357,7 +358,7 @@ void ReberGrammarTestNetwork(HiddenLayerType& hiddenLayer0,
     {
       inputTemp = trainInput.at(0, j);
       labelsTemp = trainLabels.at(0, j);
-      trainer.Train(inputTemp, labelsTemp, inputTemp, labelsTemp);
+      net.Train(inputTemp, labelsTemp, opt);
     }
   }
 
@@ -403,7 +404,6 @@ void ReberGrammarTestNetwork(HiddenLayerType& hiddenLayer0,
   }
 
   error /= testReberGrammarCount;
-
   BOOST_REQUIRE_LE(error, 0.2);
 }
 
@@ -522,11 +522,10 @@ void DistractedSequenceRecallTestNetwork(HiddenLayerType& hiddenLayer0)
    *            .......
    */
   const size_t lstmSize = 4 * 10;
-  LinearLayer<SteepestDescent, RandomInitialization> linearLayer0(10, lstmSize);
-  RecurrentLayer<SteepestDescent, RandomInitialization> recurrentLayer0(10,
-      lstmSize);
+  LinearLayer<> linearLayer0(10, lstmSize);
+  RecurrentLayer<> recurrentLayer0(10, lstmSize);
 
-  LinearLayer<SteepestDescent, RandomInitialization> hiddenLayer(10, 3);
+  LinearLayer<> hiddenLayer(10, 3);
   BaseLayer<LogisticFunction> hiddenBaseLayer;
 
   BinaryClassificationLayer classOutputLayer;
@@ -534,21 +533,20 @@ void DistractedSequenceRecallTestNetwork(HiddenLayerType& hiddenLayer0)
   auto modules = std::tie(linearLayer0, recurrentLayer0, hiddenLayer0,
                           hiddenLayer, hiddenBaseLayer);
 
-  RNN<decltype(modules), BinaryClassificationLayer, MeanSquaredErrorFunction>
-      net(modules, classOutputLayer);
+  RNN<decltype(modules), BinaryClassificationLayer, RandomInitialization,
+      MeanSquaredErrorFunction> net(modules, classOutputLayer);
 
-  // Train the network for (500 * trainDistractedSequenceCount) epochs.
-  Trainer<decltype(net)> trainer(net, 1, 1, 0, false);
+  SGD<decltype(net)> opt(net, 0.03, 2, -200);
 
   arma::mat inputTemp, labelsTemp;
-  for (size_t i = 0; i < 15; i++)
+  for (size_t i = 0; i < 30; i++)
   {
     for (size_t j = 0; j < trainDistractedSequenceCount; j++)
     {
       inputTemp = trainInput.at(0, j);
       labelsTemp = trainLabels.at(0, j);
 
-      trainer.Train(inputTemp, labelsTemp, inputTemp, labelsTemp);
+      net.Train(inputTemp, labelsTemp, opt);
     }
   }
 




More information about the mlpack-git mailing list