[mlpack-git] master: It isn't guaranteed that the recurrent network will converge (within a specified number of iterations), so we test whether the network is able to solve the task in at least one of 5 trials using different starting weights. (6147ed0)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Jul 6 18:17:57 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/e7b9b042d1d6e2d9895d5fa141e9c135b2d2ea57...6147ed01bab6eadcd6a5e796e259a6afacae4662
>---------------------------------------------------------------
commit 6147ed01bab6eadcd6a5e796e259a6afacae4662
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Thu Jul 7 00:17:57 2016 +0200
It isn't guaranteed that the recurrent network will converge (within a specified number of iterations), so we test whether the network is able to solve the task in at least one of 5 trials using different starting weights.
>---------------------------------------------------------------
6147ed01bab6eadcd6a5e796e259a6afacae4662
src/mlpack/tests/recurrent_network_test.cpp | 107 ++++++++++++++++------------
1 file changed, 61 insertions(+), 46 deletions(-)
diff --git a/src/mlpack/tests/recurrent_network_test.cpp b/src/mlpack/tests/recurrent_network_test.cpp
index 0ec1bcf..40d162c 100644
--- a/src/mlpack/tests/recurrent_network_test.cpp
+++ b/src/mlpack/tests/recurrent_network_test.cpp
@@ -68,58 +68,73 @@ void GenerateNoisySines(arma::mat& data,
*/
BOOST_AUTO_TEST_CASE(SequenceClassificationTest)
{
- // Generate 12 (2 * 6) noisy sines. A single sine contains 10 points/features.
- arma::mat input, labels;
- GenerateNoisySines(input, labels, 10, 6);
+ // It isn't guaranteed that the recurrent network will converge in the
+ // specified number of iterations using random weights. If this works 1 of 5
+ // times, I'm fine with that. All I want to know is that the network is able
+ // to escape from local minima and to solve the task.
+ size_t successes = 0;
- /*
- * Construct a network with 1 input unit, 4 hidden units and 2 output units.
- * The hidden layer is connected to itself. The network structure looks like:
- *
- * Input Hidden Output
- * Layer(1) Layer(4) Layer(2)
- * +-----+ +-----+ +-----+
- * | | | | | |
- * | +------>| +------>| |
- * | | ..>| | | |
- * +-----+ . +--+--+ +-----+
- * . .
- * . .
- * .......
- */
- LinearLayer<> linearLayer0(1, 4);
- RecurrentLayer<> recurrentLayer0(4);
- BaseLayer<LogisticFunction> inputBaseLayer;
-
- LinearLayer<> hiddenLayer(4, 2);
- BaseLayer<LogisticFunction> hiddenBaseLayer;
-
- BinaryClassificationLayer classOutputLayer;
-
- auto modules = std::tie(linearLayer0, recurrentLayer0, inputBaseLayer,
- hiddenLayer, hiddenBaseLayer);
-
- RNN<decltype(modules), BinaryClassificationLayer, RandomInitialization,
- MeanSquaredErrorFunction> net(modules, classOutputLayer);
-
- SGD<decltype(net)> opt(net, 0.5, 500 * input.n_cols, -100);
-
- net.Train(input, labels, opt);
-
- arma::mat prediction;
- net.Predict(input, prediction);
-
- size_t error = 0;
- for (size_t i = 0; i < labels.n_cols; i++)
+ for (size_t trial = 0; trial < 5; ++trial)
{
- if (arma::sum(arma::sum(arma::abs(prediction.col(i) - labels.col(i)))) == 0)
+ // Generate 12 (2 * 6) noisy sines. A single sine contains 10 points/features.
+ arma::mat input, labels;
+ GenerateNoisySines(input, labels, 10, 6);
+
+ /*
+ * Construct a network with 1 input unit, 4 hidden units and 2 output units.
+ * The hidden layer is connected to itself. The network structure looks like:
+ *
+ * Input Hidden Output
+ * Layer(1) Layer(4) Layer(2)
+ * +-----+ +-----+ +-----+
+ * | | | | | |
+ * | +------>| +------>| |
+ * | | ..>| | | |
+ * +-----+ . +--+--+ +-----+
+ * . .
+ * . .
+ * .......
+ */
+ LinearLayer<> linearLayer0(1, 4);
+ RecurrentLayer<> recurrentLayer0(4);
+ BaseLayer<LogisticFunction> inputBaseLayer;
+
+ LinearLayer<> hiddenLayer(4, 2);
+ BaseLayer<LogisticFunction> hiddenBaseLayer;
+
+ BinaryClassificationLayer classOutputLayer;
+
+ auto modules = std::tie(linearLayer0, recurrentLayer0, inputBaseLayer,
+ hiddenLayer, hiddenBaseLayer);
+
+ RNN<decltype(modules), BinaryClassificationLayer, RandomInitialization,
+ MeanSquaredErrorFunction> net(modules, classOutputLayer);
+
+ SGD<decltype(net)> opt(net, 0.5, 500 * input.n_cols, -100);
+
+ net.Train(input, labels, opt);
+
+ arma::mat prediction;
+ net.Predict(input, prediction);
+
+ size_t error = 0;
+ for (size_t i = 0; i < labels.n_cols; i++)
+ {
+ if (arma::sum(arma::sum(arma::abs(prediction.col(i) - labels.col(i)))) == 0)
+ {
+ error++;
+ }
+ }
+
+ double classificationError = 1 - double(error) / labels.n_cols;
+ if (classificationError <= 0.2)
{
- error++;
+ ++successes;
+ break;
}
}
- double classificationError = 1 - double(error) / labels.n_cols;
- BOOST_REQUIRE_LE(classificationError, 0.2);
+ BOOST_REQUIRE_GE(successes, 1);
}
/**
More information about the mlpack-git
mailing list