[mlpack-git] master: Refactor recurrent network test for new network API. (7388de7)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Nov 13 12:45:59 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4e83dc9cc4dcdc315d2cceee32b23ebab114c2...7388de71d5398103ee3a0b32b4026902a40a67b3

>---------------------------------------------------------------

commit 7388de71d5398103ee3a0b32b4026902a40a67b3
Author: marcus <marcus.edel at fu-berlin.de>
Date:   Fri Nov 13 18:44:53 2015 +0100

    Refactor recurrent network test for new network API.


>---------------------------------------------------------------

7388de71d5398103ee3a0b32b4026902a40a67b3
 src/mlpack/tests/recurrent_network_test.cpp | 448 +++++-----------------------
 1 file changed, 77 insertions(+), 371 deletions(-)

diff --git a/src/mlpack/tests/recurrent_network_test.cpp b/src/mlpack/tests/recurrent_network_test.cpp
index 5649831..3fd96d9 100644
--- a/src/mlpack/tests/recurrent_network_test.cpp
+++ b/src/mlpack/tests/recurrent_network_test.cpp
@@ -6,33 +6,18 @@
  */
 #include <mlpack/core.hpp>
 
-#include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
-#include <mlpack/methods/ann/activation_functions/identity_function.hpp>
-#include <mlpack/methods/ann/activation_functions/softsign_function.hpp>
-#include <mlpack/methods/ann/activation_functions/tanh_function.hpp>
-#include <mlpack/methods/ann/activation_functions/rectifier_function.hpp>
-
-#include <mlpack/methods/ann/init_rules/random_init.hpp>
-#include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
-
-#include <mlpack/methods/ann/layer/neuron_layer.hpp>
+#include <mlpack/methods/ann/layer/linear_layer.hpp>
+#include <mlpack/methods/ann/layer/recurrent_layer.hpp>
+#include <mlpack/methods/ann/layer/base_layer.hpp>
 #include <mlpack/methods/ann/layer/lstm_layer.hpp>
-#include <mlpack/methods/ann/layer/bias_layer.hpp>
 #include <mlpack/methods/ann/layer/binary_classification_layer.hpp>
-#include <mlpack/methods/ann/layer/multiclass_classification_layer.hpp>
-
-#include <mlpack/methods/ann/connections/full_connection.hpp>
-#include <mlpack/methods/ann/connections/self_connection.hpp>
-#include <mlpack/methods/ann/connections/fullself_connection.hpp>
-#include <mlpack/methods/ann/connections/connection_traits.hpp>
 
 #include <mlpack/methods/ann/trainer/trainer.hpp>
-
-#include <mlpack/methods/ann/ffnn.hpp>
 #include <mlpack/methods/ann/rnn.hpp>
-
 #include <mlpack/methods/ann/performance_functions/mse_function.hpp>
-#include <mlpack/methods/ann/optimizer/rmsprop.hpp>
+#include <mlpack/methods/ann/optimizer/steepest_descent.hpp>
+#include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
+#include <mlpack/methods/ann/init_rules/random_init.hpp>
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
@@ -106,244 +91,38 @@ BOOST_AUTO_TEST_CASE(SequenceClassificationTest)
    *            .     .
    *            .......
    */
-  NeuronLayer<LogisticFunction> inputLayer(1);
-  NeuronLayer<LogisticFunction> hiddenLayer0(4);
-  NeuronLayer<LogisticFunction> recurrentLayer0(hiddenLayer0.InputSize());
-  NeuronLayer<LogisticFunction> hiddenLayer1(2);
-  BinaryClassificationLayer outputLayer;
-
-  RandomInitialization randInit(-0.5, 0.5);
-
-  FullConnection<
-      decltype(inputLayer),
-      decltype(hiddenLayer0),
-      mlpack::ann::RMSPROP,
-      decltype(randInit)>
-      layerCon0(inputLayer, hiddenLayer0, randInit);
-
-  SelfConnection<
-    decltype(recurrentLayer0),
-    decltype(hiddenLayer0),
-    mlpack::ann::RMSPROP,
-    decltype(randInit)>
-    layerCon2(recurrentLayer0, hiddenLayer0, randInit);
-
-  FullConnection<
-      decltype(hiddenLayer0),
-      decltype(hiddenLayer1),
-      mlpack::ann::RMSPROP,
-      decltype(randInit)>
-      layerCon4(hiddenLayer0, hiddenLayer1, randInit);
-
-  auto module0 = std::tie(layerCon0, layerCon2);
-  auto module1 = std::tie(layerCon4);
-  auto modules = std::tie(module0, module1);
-
-  RNN<decltype(modules),
-      decltype(outputLayer),
-      MeanSquaredErrorFunction> net(modules, outputLayer);
-
-  // Train the network for 1000 epochs.
-  Trainer<decltype(net)> trainer(net, 1000);
-  trainer.Train(input, labels, input, labels);
-
-  // Ask the network to classify the trained input data.
-  arma::colvec output;
-  for (size_t i = 0; i < input.n_cols; i++)
-  {
-    net.Predict(input.unsafe_col(i), output);
+  LinearLayer<SteepestDescent, RandomInitialization> linearLayer0(1, 4);
+  RecurrentLayer<SteepestDescent, RandomInitialization> recurrentLayer0(4);
+  BaseLayer<LogisticFunction> inputBaseLayer;
 
-    bool b = arma::all((output == labels.unsafe_col(i)) == 1);
-    BOOST_REQUIRE_EQUAL(b, 1);
-  }
-}
+  LinearLayer<SteepestDescent, RandomInitialization> hiddenLayer(4, 2);
+  BaseLayer<LogisticFunction> hiddenBaseLayer;
 
-/**
- * Train and evaluate a vanilla feed forward network and a recurrent network
- * with the specified structure and compare the two networks output and overall
- * error.
- */
-template<
-    typename WeightInitRule,
-    typename PerformanceFunction,
-    typename OutputLayerType,
-    typename PerformanceFunctionType,
-    typename MatType = arma::mat
->
-void CompareVanillaNetworks(MatType& trainData,
-                            MatType& trainLabels,
-                            MatType& testData,
-                            MatType& testLabels,
-                            const size_t hiddenLayerSize,
-                            const size_t maxEpochs,
-                            WeightInitRule weightInitRule = WeightInitRule())
-{
-  BiasLayer<> biasLayer0(1);
-
-  NeuronLayer<PerformanceFunction> inputLayer(trainData.n_rows);
-  NeuronLayer<PerformanceFunction> hiddenLayer0(hiddenLayerSize);
-  NeuronLayer<PerformanceFunction> hiddenLayer1(trainLabels.n_rows);
-
-  OutputLayerType outputLayer;
-
-  FullConnection<
-    decltype(inputLayer),
-    decltype(hiddenLayer0),
-    mlpack::ann::RMSPROP,
-    decltype(weightInitRule)>
-    ffnLayerCon0(inputLayer, hiddenLayer0, weightInitRule);
-
-  FullConnection<
-    decltype(inputLayer),
-    decltype(hiddenLayer0),
-    mlpack::ann::RMSPROP,
-    decltype(weightInitRule)>
-    rnnLayerCon0(inputLayer, hiddenLayer0, weightInitRule);
-
-  FullConnection<
-    decltype(biasLayer0),
-    decltype(hiddenLayer0),
-    mlpack::ann::RMSPROP,
-    decltype(weightInitRule)>
-    ffnLayerCon1(biasLayer0, hiddenLayer0, weightInitRule);
-
-  FullConnection<
-    decltype(biasLayer0),
-    decltype(hiddenLayer0),
-    mlpack::ann::RMSPROP,
-    decltype(weightInitRule)>
-    rnnLayerCon1(biasLayer0, hiddenLayer0, weightInitRule);
-
-  FullConnection<
-      decltype(hiddenLayer0),
-      decltype(hiddenLayer1),
-      mlpack::ann::RMSPROP,
-      decltype(weightInitRule)>
-      ffnLayerCon2(hiddenLayer0, hiddenLayer1, weightInitRule);
-
-  FullConnection<
-      decltype(hiddenLayer0),
-      decltype(hiddenLayer1),
-      mlpack::ann::RMSPROP,
-      decltype(weightInitRule)>
-      rnnLayerCon2(hiddenLayer0, hiddenLayer1, weightInitRule);
-
-  auto ffnModule0 = std::tie(ffnLayerCon0, ffnLayerCon1);
-  auto ffnModule1 = std::tie(ffnLayerCon2);
-  auto ffnModules = std::tie(ffnModule0, ffnModule1);
-
-  auto rnnModule0 = std::tie(rnnLayerCon0, rnnLayerCon1);
-  auto rnnModule1 = std::tie(rnnLayerCon2);
-  auto rnnModules = std::tie(rnnModule0, rnnModule1);
+  BinaryClassificationLayer classOutputLayer;
 
-  /*
-   * Construct a feed forward network with trainData.n_rows input units,
-   * hiddenLayerSize hidden units and trainLabels.n_rows output units. The
-   * network structure looks like:
-   *
-   *  Input         Hidden        Output
-   *  Layer         Layer         Layer
-   * +-----+       +-----+       +-----+
-   * |     |       |     |       |     |
-   * |     +------>|     +------>|     |
-   * |     |       |     |       |     |
-   * +-----+       +--+--+       +-----+
-   */
-  FFNN<decltype(ffnModules), decltype(outputLayer), PerformanceFunctionType>
-      ffn(ffnModules, outputLayer);
+  auto modules = std::tie(linearLayer0, recurrentLayer0, inputBaseLayer,
+                          hiddenLayer, hiddenBaseLayer);
 
-  /*
-   * Construct a recurrent network with trainData.n_rows input units,
-   * hiddenLayerSize hidden units and trainLabels.n_rows output units. The
-   * hidden layer is connected to itself. The network structure looks like:
-   *
-   *  Input         Hidden        Output
-   *  Layer         Layer         Layer
-   * +-----+       +-----+       +-----+
-   * |     |       |     |       |     |
-   * |     +------>|     +------>|     |
-   * |     |    ..>|     |       |     |
-   * +-----+    .  +--+--+       +-----+
-   *            .     .
-   *            .     .
-   *            .......
-   */
-  RNN<decltype(rnnModules), decltype(outputLayer), PerformanceFunctionType>
-      rnn(rnnModules, outputLayer);
+  RNN<decltype(modules), BinaryClassificationLayer, MeanSquaredErrorFunction>
+      net(modules, classOutputLayer);
 
-  // Train the network for maxEpochs epochs or until we reach a validation error
-  // of less then 0.001.
-  Trainer<decltype(ffn)> ffnTrainer(ffn, maxEpochs, 1, 0.001, false);
-  Trainer<decltype(rnn)> rnnTrainer(rnn, maxEpochs, 1, 0.001, false);
+  // Train for up to 200 epochs or until the validation error is below 0.01.
+  Trainer<decltype(net)> trainer(net, 200, 1, 0.01);
+  trainer.Train(input, labels, input, labels);
 
-  for (size_t i = 0; i < 5; i++)
+  // Ask the network to classify the trained input data.
+  arma::mat output;
+  for (size_t i = 0; i < input.n_cols; i++)
   {
-    rnnTrainer.Train(trainData, trainLabels, testData, testLabels);
-    ffnTrainer.Train(trainData, trainLabels, testData, testLabels);
-
-    if (!arma::is_finite(ffnTrainer.ValidationError()))
-      continue;
+    arma::mat inputSeq = input.unsafe_col(i);
+    net.Predict(inputSeq, output);
 
-    BOOST_REQUIRE_CLOSE(ffnTrainer.ValidationError(),
-        rnnTrainer.ValidationError(), 1e-3);
+    bool b = arma::all((output == labels.unsafe_col(i)) == 1);
+    BOOST_REQUIRE_EQUAL(b, 1);
   }
 }
 
 /**
- * Train a vanilla feed forward and recurrent network on a sequence with len
- * one. Ideally the recurrent network should produce the same output as the
- * recurrent network. The self connection shouldn't affect the output when using
- * a sequence with a length of one.
- */
-BOOST_AUTO_TEST_CASE(FeedForwardRecurrentNetworkTest)
-{
-  arma::mat input;
-  arma::mat labels;
-
-  RandomInitialization randInit(1, 1);
-
-  // Test on a non-linearly separable dataset (XOR).
-  input << 0 << 1 << 1 << 0 << arma::endr
-        << 1 << 0 << 1 << 0 << arma::endr;
-  labels << 0 << 0 << 1 << 1;
-
-  // Vanilla neural net with logistic activation function.
-  CompareVanillaNetworks<RandomInitialization,
-                      LogisticFunction,
-                      BinaryClassificationLayer,
-                      MeanSquaredErrorFunction>
-      (input, labels, input, labels, 10, 10, randInit);
-
-  // Vanilla neural net with identity activation function.
-  CompareVanillaNetworks<RandomInitialization,
-                      IdentityFunction,
-                      BinaryClassificationLayer,
-                      MeanSquaredErrorFunction>
-      (input, labels, input, labels, 1, 1, randInit);
-
-  // Vanilla neural net with rectifier activation function.
-  CompareVanillaNetworks<RandomInitialization,
-                    RectifierFunction,
-                    BinaryClassificationLayer,
-                    MeanSquaredErrorFunction>
-    (input, labels, input, labels, 10, 10, randInit);
-
-  // Vanilla neural net with softsign activation function.
-  CompareVanillaNetworks<RandomInitialization,
-                    SoftsignFunction,
-                    BinaryClassificationLayer,
-                    MeanSquaredErrorFunction>
-    (input, labels, input, labels, 10, 10, randInit);
-
-  // Vanilla neural net with tanh activation function.
-  CompareVanillaNetworks<RandomInitialization,
-                    TanhFunction,
-                    BinaryClassificationLayer,
-                    MeanSquaredErrorFunction>
-    (input, labels, input, labels, 10, 10, randInit);
-}
-
-/**
  * Generate a random Reber grammar.
  *
  * For more information, see the following thesis.
@@ -552,59 +331,27 @@ void ReberGrammarTestNetwork(HiddenLayerType& hiddenLayer0,
    *            .     .
    *            .......
    */
-  NeuronLayer<LogisticFunction> inputLayer(7);
-  NeuronLayer<IdentityFunction> recurrentLayer0(hiddenLayer0.OutputSize());
-  NeuronLayer<LogisticFunction> hiddenLayer1(7);
-  BinaryClassificationLayer outputLayer;
-
-  NguyenWidrowInitialization randInit;
-
-  FullConnection<
-      decltype(inputLayer),
-      decltype(hiddenLayer0),
-      mlpack::ann::RMSPROP,
-      decltype(randInit)>
-      layerCon0(inputLayer, hiddenLayer0, randInit);
-
-  FullselfConnection<
-    decltype(recurrentLayer0),
-    decltype(hiddenLayer0),
-    mlpack::ann::RMSPROP,
-    decltype(randInit)>
-    layerTypeLSTM(recurrentLayer0, hiddenLayer0, randInit);
-
-  SelfConnection<
-    decltype(recurrentLayer0),
-    decltype(hiddenLayer0),
-    mlpack::ann::RMSPROP,
-    decltype(randInit)>
-    layerTypeBasis(recurrentLayer0, hiddenLayer0, randInit);
-
-  typename std::conditional<LayerTraits<HiddenLayerType>::IsLSTMLayer,
-      typename std::remove_reference<decltype(layerTypeLSTM)>::type,
-      typename std::remove_reference<decltype(layerTypeBasis)>::type>::type
-      layerCon2(recurrentLayer0, hiddenLayer0, randInit);
-
-  FullConnection<
-      decltype(hiddenLayer0),
-      decltype(hiddenLayer1),
-      mlpack::ann::RMSPROP,
-      decltype(randInit)>
-      layerCon4(hiddenLayer0, hiddenLayer1, randInit);
-
-  auto module0 = std::tie(layerCon0, layerCon2);
-  auto module1 = std::tie(layerCon4);
-  auto modules = std::tie(module0, module1);
-
-  RNN<decltype(modules),
-      decltype(outputLayer),
-      MeanSquaredErrorFunction> net(modules, outputLayer);
-
-  // Train the network for (500 * trainReberGrammarCount) epochs.
+  const size_t lstmSize = 4 * 10;
+  LinearLayer<SteepestDescent, RandomInitialization> linearLayer0(7, lstmSize);
+  RecurrentLayer<SteepestDescent, RandomInitialization> recurrentLayer0(10,
+      lstmSize);
+
+  LinearLayer<SteepestDescent, RandomInitialization> hiddenLayer(10, 7);
+  BaseLayer<LogisticFunction> hiddenBaseLayer;
+
+  BinaryClassificationLayer classOutputLayer;
+
+  auto modules = std::tie(linearLayer0, recurrentLayer0, hiddenLayer0,
+                          hiddenLayer, hiddenBaseLayer);
+
+  RNN<decltype(modules), BinaryClassificationLayer, MeanSquaredErrorFunction>
+      net(modules, classOutputLayer);
+
+  // Train the network for (15 * trainReberGrammarCount) epochs.
   Trainer<decltype(net)> trainer(net, 1, 1, 0, false);
 
   arma::mat inputTemp, labelsTemp;
-  for (size_t i = 0; i < 100; i++)
+  for (size_t i = 0; i < 15; i++)
   {
     for (size_t j = 0; j < trainReberGrammarCount; j++)
     {
@@ -619,8 +366,8 @@ void ReberGrammarTestNetwork(HiddenLayerType& hiddenLayer0,
   // Ask the network to predict the next Reber grammar in the given sequence.
   for (size_t i = 0; i < testReberGrammarCount; i++)
   {
-    arma::colvec output;
-    arma::colvec input = testInput.at(0, i);
+    arma::mat output;
+    arma::mat input = testInput.at(0, i);
 
     net.Predict(input, output);
 
@@ -630,16 +377,16 @@ void ReberGrammarTestNetwork(HiddenLayerType& hiddenLayer0,
     size_t reberError = 0;
     for (size_t j = 0; j < (output.n_elem / reberGrammerSize); j++)
     {
-      if (arma::sum(output.subvec(j * reberGrammerSize, (j + 1) *
-          reberGrammerSize - 1)) != 1) break;
+      if (arma::sum(arma::sum(output.submat(j * reberGrammerSize, 0, (j + 1) *
+          reberGrammerSize - 1, 0))) != 1) break;
 
       char predictedSymbol, inputSymbol;
       std::string reberChoices;
 
-      ReberReverseTranslation(output.subvec(j * reberGrammerSize, (j + 1) *
-          reberGrammerSize - 1), predictedSymbol);
-      ReberReverseTranslation(input.subvec(j * reberGrammerSize, (j + 1) *
-          reberGrammerSize - 1), inputSymbol);
+      ReberReverseTranslation(output.submat(j * reberGrammerSize, 0, (j + 1) *
+          reberGrammerSize - 1, 0), predictedSymbol);
+      ReberReverseTranslation(input.submat(j * reberGrammerSize, 0, (j + 1) *
+          reberGrammerSize - 1, 0), inputSymbol);
       inputReber += inputSymbol;
 
       if (embedded)
@@ -667,9 +414,6 @@ BOOST_AUTO_TEST_CASE(ReberGrammarTest)
 {
   LSTMLayer<> hiddenLayerLSTM(10);
   ReberGrammarTestNetwork(hiddenLayerLSTM);
-
-  NeuronLayer<LogisticFunction> hiddenLayerLogistic(5);
-  ReberGrammarTestNetwork(hiddenLayerLogistic);
 }
 
 /**
@@ -679,9 +423,6 @@ BOOST_AUTO_TEST_CASE(EmbeddedReberGrammarTest)
 {
   LSTMLayer<> hiddenLayerLSTM(10);
   ReberGrammarTestNetwork(hiddenLayerLSTM, true);
-
-  LSTMLayer<> hiddenLayerLSTMPeephole(10, 1, true);
-  ReberGrammarTestNetwork(hiddenLayerLSTMPeephole, true);
 }
 
 /*
@@ -765,12 +506,12 @@ void DistractedSequenceRecallTestNetwork(HiddenLayerType& hiddenLayer0)
     GenerateDistractedSequence(testInput(0, i), testLabels(0, i));
 
   /*
-   * Construct a network with 7 input units, layerSize hidden units and 7 output
-   * units. The hidden layer is connected to itself. The network structure looks
-   * like:
+   * Construct a network with 10 input units, layerSize hidden units and 3
+   * output units. The hidden layer is connected to itself. The network
+   * structure looks like:
    *
    *  Input         Hidden        Output
-   * Layer(7)  Layer(layerSize)   Layer(7)
+   * Layer(10)  Layer(layerSize)   Layer(3)
    * +-----+       +-----+       +-----+
    * |     |       |     |       |     |
    * |     +------>|     +------>|     |
@@ -780,59 +521,27 @@ void DistractedSequenceRecallTestNetwork(HiddenLayerType& hiddenLayer0)
    *            .     .
    *            .......
    */
-  NeuronLayer<LogisticFunction> inputLayer(10);
-  NeuronLayer<IdentityFunction> recurrentLayer0(hiddenLayer0.OutputSize());
-  NeuronLayer<LogisticFunction> hiddenLayer1(3);
-  BinaryClassificationLayer outputLayer;
-
-  NguyenWidrowInitialization randInit;
-
-  FullConnection<
-      decltype(inputLayer),
-      decltype(hiddenLayer0),
-      mlpack::ann::RMSPROP,
-      decltype(randInit)>
-      layerCon0(inputLayer, hiddenLayer0, randInit);
-
-  FullselfConnection<
-    decltype(recurrentLayer0),
-    decltype(hiddenLayer0),
-    mlpack::ann::RMSPROP,
-    decltype(randInit)>
-    layerTypeLSTM(recurrentLayer0, hiddenLayer0, randInit);
-
-  SelfConnection<
-    decltype(recurrentLayer0),
-    decltype(hiddenLayer0),
-    mlpack::ann::RMSPROP,
-    decltype(randInit)>
-    layerTypeBasis(recurrentLayer0, hiddenLayer0, randInit);
-
-  typename std::conditional<LayerTraits<HiddenLayerType>::IsLSTMLayer,
-      typename std::remove_reference<decltype(layerTypeLSTM)>::type,
-      typename std::remove_reference<decltype(layerTypeBasis)>::type>::type
-      layerCon2(recurrentLayer0, hiddenLayer0, randInit);
-
-  FullConnection<
-      decltype(hiddenLayer0),
-      decltype(hiddenLayer1),
-      mlpack::ann::RMSPROP,
-      decltype(randInit)>
-      layerCon4(hiddenLayer0, hiddenLayer1, randInit);
-
-  auto module0 = std::tie(layerCon0, layerCon2);
-  auto module1 = std::tie(layerCon4);
-  auto modules = std::tie(module0, module1);
-
-  RNN<decltype(modules),
-      decltype(outputLayer),
-      MeanSquaredErrorFunction> net(modules, outputLayer);
+  const size_t lstmSize = 4 * 10;
+  LinearLayer<SteepestDescent, RandomInitialization> linearLayer0(10, lstmSize);
+  RecurrentLayer<SteepestDescent, RandomInitialization> recurrentLayer0(10,
+      lstmSize);
+
+  LinearLayer<SteepestDescent, RandomInitialization> hiddenLayer(10, 3);
+  BaseLayer<LogisticFunction> hiddenBaseLayer;
+
+  BinaryClassificationLayer classOutputLayer;
+
+  auto modules = std::tie(linearLayer0, recurrentLayer0, hiddenLayer0,
+                          hiddenLayer, hiddenBaseLayer);
+
+  RNN<decltype(modules), BinaryClassificationLayer, MeanSquaredErrorFunction>
+      net(modules, classOutputLayer);
 
   // Train the network for (500 * trainDistractedSequenceCount) epochs.
   Trainer<decltype(net)> trainer(net, 1, 1, 0, false);
 
   arma::mat inputTemp, labelsTemp;
-  for (size_t i = 0; i < 100; i++)
+  for (size_t i = 0; i < 15; i++)
   {
     for (size_t j = 0; j < trainDistractedSequenceCount; j++)
     {
@@ -849,12 +558,12 @@ void DistractedSequenceRecallTestNetwork(HiddenLayerType& hiddenLayer0)
   // prompts.
   for (size_t i = 0; i < testDistractedSequenceCount; i++)
   {
-    arma::colvec output;
-    arma::colvec input = testInput.at(0, i);
+    arma::mat output;
+    arma::mat input = testInput.at(0, i);
 
     net.Predict(input, output);
 
-    if (arma::sum(arma::abs(testLabels.at(0, i) - output)) != 0)
+    if (arma::accu(arma::abs(testLabels.at(0, i) - output)) != 0)
       error += 1;
   }
 
@@ -873,11 +582,8 @@ void DistractedSequenceRecallTestNetwork(HiddenLayerType& hiddenLayer0)
  */
 BOOST_AUTO_TEST_CASE(DistractedSequenceRecallTest)
 {
-  LSTMLayer<> hiddenLayerLSTM(10, 10);
-  DistractedSequenceRecallTestNetwork(hiddenLayerLSTM);
-
-  LSTMLayer<> hiddenLayerLSTMPeephole(10, 1, true);
-  DistractedSequenceRecallTestNetwork(hiddenLayerLSTM);
+  LSTMLayer<> hiddenLayerLSTMPeephole(10, true);
+  DistractedSequenceRecallTestNetwork(hiddenLayerLSTMPeephole);
 }
 
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list