[mlpack-git] master: Add distracted sequence recall test (LSTM architecture). (5e2f4f2)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Mar 9 10:55:16 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/3760035ca6b3ccbc07f5bdd19fa81f1436bab1d1...5e2f4f2e6b531af5791e185755d7517bd3e80a62

>---------------------------------------------------------------

commit 5e2f4f2e6b531af5791e185755d7517bd3e80a62
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Mon Mar 9 15:55:10 2015 +0100

    Add distracted sequence recall test (LSTM architecture).


>---------------------------------------------------------------

5e2f4f2e6b531af5791e185755d7517bd3e80a62
 src/mlpack/tests/recurrent_network_test.cpp | 241 +++++++++++++++++++++++++---
 1 file changed, 220 insertions(+), 21 deletions(-)

diff --git a/src/mlpack/tests/recurrent_network_test.cpp b/src/mlpack/tests/recurrent_network_test.cpp
index b95e325..d428e17 100644
--- a/src/mlpack/tests/recurrent_network_test.cpp
+++ b/src/mlpack/tests/recurrent_network_test.cpp
@@ -2,7 +2,7 @@
  * @file recurrent_network_test.cpp
  * @author Marcus Edel
  *
- * Tests the feed forward network.
+ * Tests the recurrent network.
  */
 #include <mlpack/core.hpp>
 
@@ -13,13 +13,10 @@
 #include <mlpack/methods/ann/activation_functions/rectifier_function.hpp>
 
 #include <mlpack/methods/ann/init_rules/random_init.hpp>
-#include <mlpack/methods/ann/layer/lstm_layer.hpp>
-#include <mlpack/methods/ann/init_rules/orthogonal_init.hpp>
-#include <mlpack/methods/ann/init_rules/oivs_init.hpp>
-#include <mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp>
 #include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
 
 #include <mlpack/methods/ann/layer/neuron_layer.hpp>
+#include <mlpack/methods/ann/layer/lstm_layer.hpp>
 #include <mlpack/methods/ann/layer/bias_layer.hpp>
 #include <mlpack/methods/ann/layer/binary_classification_layer.hpp>
 #include <mlpack/methods/ann/layer/multiclass_classification_layer.hpp>
@@ -35,11 +32,7 @@
 #include <mlpack/methods/ann/rnn.hpp>
 
 #include <mlpack/methods/ann/performance_functions/mse_function.hpp>
-#include <mlpack/methods/ann/performance_functions/sse_function.hpp>
-#include <mlpack/methods/ann/performance_functions/cee_function.hpp>
-
 #include <mlpack/methods/ann/optimizer/steepest_descent.hpp>
-#include <mlpack/methods/ann/optimizer/rpropp.hpp>
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
@@ -509,7 +502,8 @@ void GenerateNextEmbeddedReber(const arma::Mat<char>& transitions,
  * Construct a Reber grammar dataset and train the specified network.
  */
 template<typename HiddenLayerType>
-void ReberGrammarTestNetwork(const size_t layerSize, bool embedded = false)
+void ReberGrammarTestNetwork(HiddenLayerType& hiddenLayer0,
+                             bool embedded = false)
 {
   // Reber state transition matrix. (The last two columns are the indices to the
   // next path).
@@ -528,7 +522,6 @@ void ReberGrammarTestNetwork(const size_t layerSize, bool embedded = false)
   arma::field<arma::mat> trainInput(1, trainReberGrammarCount);
   arma::field<arma::mat> trainLabels(1, trainReberGrammarCount);
   arma::field<arma::mat> testInput(1, testReberGrammarCount);
-  arma::field<arma::mat> testLabels(1, testReberGrammarCount);
   arma::colvec translation;
 
   // Generate the training data.
@@ -561,15 +554,13 @@ void ReberGrammarTestNetwork(const size_t layerSize, bool embedded = false)
     {
       ReberTranslation(testReber[j], translation);
       testInput(0, i) = arma::join_cols(testInput(0, i), translation);
-
-      ReberTranslation(testReber[j + 1], translation);
-      testLabels(0, i) = arma::join_cols(testLabels(0, i), translation);
     }
   }
 
   /*
-   * Construct a network with 7 input units, 5 hidden units and 7 output units.
-   * The hidden layer is connected to itself. The network structure looks like:
+   * Construct a network with 7 input units, layerSize hidden units and 7 output
+   * units. The hidden layer is connected to itself. The network structure looks
+   * like:
    *
    *  Input         Hidden        Output
    * Layer(7)  Layer(layerSize)   Layer(7)
@@ -583,7 +574,6 @@ void ReberGrammarTestNetwork(const size_t layerSize, bool embedded = false)
    *            .......
    */
   NeuronLayer<LogisticFunction> inputLayer(7);
-  HiddenLayerType hiddenLayer0(layerSize);
   NeuronLayer<IdentityFunction> recurrentLayer0(hiddenLayer0.OutputSize());
   NeuronLayer<LogisticFunction> hiddenLayer1(7);
   BinaryClassificationLayer<> outputLayer;
@@ -630,7 +620,6 @@ void ReberGrammarTestNetwork(const size_t layerSize, bool embedded = false)
       decltype(randInit)>
       layerCon4(hiddenLayer0, hiddenLayer1, conOptimizer3, randInit);
 
-  // auto module0 = std::tie(layerCon0, layerCon2);
   auto module0 = std::tie(layerCon0, layerCon2);
   auto module1 = std::tie(layerCon4);
   auto modules = std::tie(module0, module1);
@@ -704,8 +693,11 @@ void ReberGrammarTestNetwork(const size_t layerSize, bool embedded = false)
  */
 BOOST_AUTO_TEST_CASE(ReberGrammarTest)
 {
-  ReberGrammarTestNetwork<LSTMLayer<> >(10);
-  ReberGrammarTestNetwork<NeuronLayer<LogisticFunction> >(5);
+  LSTMLayer<> hiddenLayerLSTM(10);
+  ReberGrammarTestNetwork(hiddenLayerLSTM);
+
+  NeuronLayer<LogisticFunction> hiddenLayerLogistic(5);
+  ReberGrammarTestNetwork(hiddenLayerLogistic);
 }
 
 /**
@@ -713,7 +705,214 @@ BOOST_AUTO_TEST_CASE(ReberGrammarTest)
  */
 BOOST_AUTO_TEST_CASE(EmbeddedReberGrammarTest)
 {
-  ReberGrammarTestNetwork<LSTMLayer<> >(10, true);
+  LSTMLayer<> hiddenLayerLSTM(10);
+  ReberGrammarTestNetwork(hiddenLayerLSTM, true);
+
+  LSTMLayer<> hiddenLayerLSTMPeephole(10, 1, true);
+  ReberGrammarTestNetwork(hiddenLayerLSTMPeephole, true);
+}
+
+/**
+ * This sample is a simplified version of Derek D. Monner's Distracted Sequence
+ * Recall task, which involves 10 symbols:
+ *
+ * Targets: must be recognized and remembered by the network.
+ * Distractors: never need to be remembered.
+ * Prompts: direct the network to give an answer.
+ *
+ * A single trial consists of a temporal sequence of 10 input symbols. The
+ * first 8 are 2 randomly chosen target symbols and 6 randomly chosen
+ * distractor symbols, in random order. The final two symbols are prompts,
+ * which direct the network to produce the first and second target in the
+ * sequence, in that order.
+ *
+ * For more information, see the following paper.
+ *
+ * @code
+ * @misc{Monner2012,
+ *   author = {Monner, Derek and Reggia, James A},
+ *   title = {A generalized LSTM-like training algorithm for second-order
+ *   recurrent neural networks},
+ *   year = {2012}
+ * }
+ * @endcode
+ *
+ * @param input The generated input sequence.
+ * @param output The generated output sequence.
+ */
+void GenerateDistractedSequence(arma::mat& input, arma::mat& output)
+{
+  input = arma::zeros<arma::mat>(10, 10);
+  output = arma::zeros<arma::mat>(3, 10);
+
+  arma::Col<size_t> index = arma::shuffle(arma::linspace<arma::Col<size_t> >(
+      0, 7, 8));
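+
+  // 'index' now holds the first eight time steps (0, ..., 7) in random
+  // order; the first two entries become the positions of the two targets.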
+
+  // Set the target in the input sequence and the corresponding targets in the
+  // output sequence by following the correct order.
+  for (size_t i = 0; i < 2; i++)
+  {
+    size_t idx = rand() % 2;
+    input(idx, index(i)) = 1;
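+    // The target that occurs earlier in the sequence is answered at the
+    // first prompt (step 8), the later one at the second prompt (step 9).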
+    output(idx, index(i) > index(i == 0) ? 9 : 8) = 1;
+  }
+
+  for (size_t i = 2; i < 8; i++)
+    input(2 + rand() % 6, index(i)) = 1;
+
+  // Set the prompts which direct the network to give an answer.
+  input(8, 8) = 1;
+  input(9, 9) = 1;
+
+  input.reshape(input.n_elem, 1);
+  output.reshape(output.n_elem, 1);
+}
+
+/**
+ * Construct a distracted sequence recall dataset and train the specified
+ * network.
+ */
+template<typename HiddenLayerType>
+void DistractedSequenceRecallTestNetwork(HiddenLayerType& hiddenLayer0)
+{
+  const size_t trainDistractedSequenceCount = 1000;
+  const size_t testDistractedSequenceCount = 1000;
+
+  arma::field<arma::mat> trainInput(1, trainDistractedSequenceCount);
+  arma::field<arma::mat> trainLabels(1, trainDistractedSequenceCount);
+  arma::field<arma::mat> testInput(1, testDistractedSequenceCount);
+  arma::field<arma::mat> testLabels(1, testDistractedSequenceCount);
+
+  // Generate the training data.
+  for (size_t i = 0; i < trainDistractedSequenceCount; i++)
+    GenerateDistractedSequence(trainInput(0, i), trainLabels(0, i));
+
+  // Generate the test data.
+  for (size_t i = 0; i < testDistractedSequenceCount; i++)
+    GenerateDistractedSequence(testInput(0, i), testLabels(0, i));
+
+  /*
+   * Construct a network with 10 input units, layerSize hidden units and 3
+   * output units. The hidden layer is connected to itself. The network
+   * structure looks like:
+   *
+   *  Input         Hidden        Output
+   * Layer(10) Layer(layerSize)   Layer(3)
+   * +-----+       +-----+       +-----+
+   * |     |       |     |       |     |
+   * |     +------>|     +------>|     |
+   * |     |    ..>|     |       |     |
+   * +-----+    .  +--+--+       +-----+
+   *            .     .
+   *            .     .
+   *            .......
+   */
+  NeuronLayer<LogisticFunction> inputLayer(10);
+  NeuronLayer<IdentityFunction> recurrentLayer0(hiddenLayer0.OutputSize());
+  NeuronLayer<LogisticFunction> hiddenLayer1(3);
+  BinaryClassificationLayer<> outputLayer;
+
+  SteepestDescent< > conOptimizer0(inputLayer.OutputSize(),
+      hiddenLayer0.InputSize(), 0.1);
+  SteepestDescent< > conOptimizer2(recurrentLayer0.OutputSize(),
+      hiddenLayer0.InputSize(), 0.1);
+  SteepestDescent< > conOptimizer3(hiddenLayer0.OutputSize(),
+      hiddenLayer1.InputSize(), 0.1);
+
+  NguyenWidrowInitialization<> randInit;
+
+  FullConnection<
+      decltype(inputLayer),
+      decltype(hiddenLayer0),
+      decltype(conOptimizer0),
+      decltype(randInit)>
+      layerCon0(inputLayer, hiddenLayer0, conOptimizer0, randInit);
+
+  FullselfConnection<
+    decltype(recurrentLayer0),
+    decltype(hiddenLayer0),
+    decltype(conOptimizer2),
+    decltype(randInit)>
+    layerTypeLSTM(recurrentLayer0, hiddenLayer0, conOptimizer2, randInit);
+
+  SelfConnection<
+    decltype(recurrentLayer0),
+    decltype(hiddenLayer0),
+    decltype(conOptimizer2),
+    decltype(randInit)>
+    layerTypeBasis(recurrentLayer0, hiddenLayer0, conOptimizer2, randInit);
+
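+  // Select the recurrent connection type at compile time: an LSTM hidden
+  // layer gets a full self connection, any other layer a self connection.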
+  typename std::conditional<LayerTraits<HiddenLayerType>::IsLSTMLayer,
+      typename std::remove_reference<decltype(layerTypeLSTM)>::type,
+      typename std::remove_reference<decltype(layerTypeBasis)>::type>::type
+      layerCon2(recurrentLayer0, hiddenLayer0, conOptimizer2, randInit);
+
+  FullConnection<
+      decltype(hiddenLayer0),
+      decltype(hiddenLayer1),
+      decltype(conOptimizer3),
+      decltype(randInit)>
+      layerCon4(hiddenLayer0, hiddenLayer1, conOptimizer3, randInit);
+
+  auto module0 = std::tie(layerCon0, layerCon2);
+  auto module1 = std::tie(layerCon4);
+  auto modules = std::tie(module0, module1);
+
+  RNN<decltype(modules),
+      decltype(outputLayer),
+      MeanSquaredErrorFunction<> > net(modules, outputLayer);
+
+  // Train the network for 500 epochs over the full training set, i.e.
+  // 500 * trainDistractedSequenceCount sequence presentations.
+  Trainer<decltype(net)> trainer(net, 1, 1, 0, false);
+
+  arma::mat inputTemp, labelsTemp;
+  for (size_t i = 0; i < 500; i++)
+  {
+    for (size_t j = 0; j < trainDistractedSequenceCount; j++)
+    {
+      inputTemp = trainInput.at(0, j);
+      labelsTemp = trainLabels.at(0, j);
+
+      trainer.Train(inputTemp, labelsTemp, inputTemp, labelsTemp);
+    }
+  }
+
+  double error = 0;
+
+  // Ask the network to predict the targets in the given sequence at the
+  // prompts.
+  for (size_t i = 0; i < testDistractedSequenceCount; i++)
+  {
+    arma::colvec output;
+    arma::colvec input = testInput.at(0, i);
+
+    net.Predict(input, output);
+
+    if (arma::sum(arma::abs(testLabels.at(0, i) - output)) != 0)
+      error += 1;
+  }
+
+  error /= testDistractedSequenceCount;
+
+  // Check whether we can reproduce the results from the paper, which reports
+  // 95% accuracy on a test set of 1000 randomly selected sequences. Require
+  // at most 10% error here, which leaves a little extra room for noise.
+  BOOST_REQUIRE_LE(error, 0.1);
+}
+
+/**
+ * Train the specified networks on Derek D. Monner's distracted sequence
+ * recall task.
+ */
+BOOST_AUTO_TEST_CASE(DistractedSequenceRecallTest)
+{
+  LSTMLayer<> hiddenLayerLSTM(10, 10);
+  DistractedSequenceRecallTestNetwork(hiddenLayerLSTM);
+
+  LSTMLayer<> hiddenLayerLSTMPeephole(10, 1, true);
+  DistractedSequenceRecallTestNetwork(hiddenLayerLSTMPeephole);
 }
 
 BOOST_AUTO_TEST_SUITE_END();
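
To see what a single generated trial looks like, the new data generator can be
exercised on its own. The following is a minimal sketch, not part of the
patch; it assumes GenerateDistractedSequence() from the test file above is
visible in the same translation unit:

  #include <mlpack/core.hpp>

  int main()
  {
    arma::mat input, output;

    // One trial: 10 symbols over 10 time steps (input) and the recalled
    // targets at the two prompt steps (output).
    GenerateDistractedSequence(input, output);

    // The generator flattens both matrices into column vectors, so undo
    // that to inspect them as (symbol x time step) grids.
    input.reshape(10, 10);
    output.reshape(3, 10);

    // Rows 0-1: targets, rows 2-7: distractors, rows 8-9: prompts.
    input.print("input (one-hot, symbol x time step):");
    output.print("output (targets at prompt steps 8 and 9):");

    return 0;
  }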

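The layerCon2 declaration selects the recurrent connection type at compile
time via std::conditional, keyed on LayerTraits<HiddenLayerType>::IsLSTMLayer.
The same pattern in isolation, using hypothetical stand-in types (LSTMish,
FullSelf, etc.) rather than mlpack's actual classes:

  #include <iostream>
  #include <type_traits>

  // Stand-ins for LSTMLayer / NeuronLayer and their LayerTraits flag.
  struct LSTMish { static const bool IsLSTMLayer = true; };
  struct Plain { static const bool IsLSTMLayer = false; };

  // Stand-ins for FullselfConnection / SelfConnection.
  struct FullSelf { const char* Name() const { return "FullselfConnection"; } };
  struct Self { const char* Name() const { return "SelfConnection"; } };

  template<typename HiddenLayerType>
  void BuildRecurrentConnection()
  {
    // Resolved at compile time; only the selected type is instantiated here.
    typename std::conditional<HiddenLayerType::IsLSTMLayer,
        FullSelf, Self>::type connection;

    std::cout << connection.Name() << std::endl;
  }

  int main()
  {
    BuildRecurrentConnection<LSTMish>();  // prints FullselfConnection
    BuildRecurrentConnection<Plain>();    // prints SelfConnection
    return 0;
  }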


More information about the mlpack-git mailing list