[mlpack-git] master: Add feedforward neural network dropout test. (f530120)
gitdub at big.cc.gt.atl.ga.us
Thu Jul 9 14:54:51 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/baeb3885b27d3b7dd552c638c605e034b1388cad...4a97187bbba7ce8a6191b714949dd818ef0f37d2
>---------------------------------------------------------------
commit f530120401e9b44fb6333c1166cd45026f69d8aa
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Mon Jul 6 22:01:26 2015 +0200
Add feedforward neural network dropout test.
>---------------------------------------------------------------
f530120401e9b44fb6333c1166cd45026f69d8aa
src/mlpack/tests/feedforward_network_test.cpp | 155 +++++++++++++++++++++++++-
1 file changed, 153 insertions(+), 2 deletions(-)
diff --git a/src/mlpack/tests/feedforward_network_test.cpp b/src/mlpack/tests/feedforward_network_test.cpp
index 10f1cd2..6fdbd01 100644
--- a/src/mlpack/tests/feedforward_network_test.cpp
+++ b/src/mlpack/tests/feedforward_network_test.cpp
@@ -13,9 +13,11 @@
#include <mlpack/methods/ann/layer/neuron_layer.hpp>
#include <mlpack/methods/ann/layer/bias_layer.hpp>
+#include <mlpack/methods/ann/layer/dropout_layer.hpp>
#include <mlpack/methods/ann/layer/binary_classification_layer.hpp>
#include <mlpack/methods/ann/connections/full_connection.hpp>
+#include <mlpack/methods/ann/connections/identity_connection.hpp>
#include <mlpack/methods/ann/trainer/trainer.hpp>
@@ -171,8 +173,157 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
// Vanilla neural net with tanh activation function.
BuildVanillaNetwork<TanhFunction,
- BinaryClassificationLayer,
- MeanSquaredErrorFunction>
+ BinaryClassificationLayer,
+ MeanSquaredErrorFunction>
+ (dataset, labels, dataset, labels, 10, 200, 0.6, 20);
+}
+
+/**
+ * Train and evaluate a Dropout network with the specified structure.
+ */
+template<
+ typename PerformanceFunction,
+ typename OutputLayerType,
+ typename PerformanceFunctionType,
+ typename MatType = arma::mat,
+ typename VecType = arma::colvec
+>
+void BuildDropoutNetwork(MatType& trainData,
+ MatType& trainLabels,
+ MatType& testData,
+ MatType& testLabels,
+ const size_t hiddenLayerSize,
+ const size_t maxEpochs,
+ const double classificationErrorThreshold,
+ const double ValidationErrorThreshold)
+{
+ /*
+ * Construct a feed forward network with trainData.n_rows input nodes,
+ * hiddenLayerSize hidden nodes and trainLabels.n_rows output nodes. The
+ * network structure looks like:
+ *
+ * Input Hidden Dropout Output
+ * Layer Layer Layer Layer
+ * +-----+ +-----+ +-----+ +-----+
+ * | | | | | | | |
+ * | +------>| +------>| +------>| |
+ * | | +>| | | | | |
+ * +-----+ | +--+--+ +-----+ +-----+
+ * |
+ * Bias |
+ * Layer |
+ * +-----+ |
+ * | | |
+ * | +-----+
+ * | |
+ * +-----+
+ */
+ BiasLayer<> biasLayer0(1);
+
+ NeuronLayer<PerformanceFunction> inputLayer(trainData.n_rows);
+ NeuronLayer<PerformanceFunction> hiddenLayer0(hiddenLayerSize);
+ DropoutLayer<> dropoutLayer0(hiddenLayerSize);
+ NeuronLayer<PerformanceFunction> hiddenLayer1(trainLabels.n_rows);
+
+ OutputLayerType outputLayer;
+
+ FullConnection<
+ decltype(inputLayer),
+ decltype(hiddenLayer0)>
+ layerCon0(inputLayer, hiddenLayer0);
+
+ FullConnection<
+ decltype(biasLayer0),
+ decltype(hiddenLayer0)>
+ layerCon1(biasLayer0, hiddenLayer0);
+
+ IdentityConnection<
+ decltype(hiddenLayer0),
+ decltype(dropoutLayer0),
+ mlpack::ann::RMSPROP,
+ arma::colvec>
+ layerCon1Dropout(hiddenLayer0, dropoutLayer0);
+
+ FullConnection<
+ decltype(dropoutLayer0),
+ decltype(hiddenLayer1)>
+ layerCon2(dropoutLayer0, hiddenLayer1);
+
+ auto module0 = std::tie(layerCon0, layerCon1);
+ auto module0Dropout = std::tie(layerCon1Dropout);
+ auto module1 = std::tie(layerCon2);
+ auto modules = std::tie(module0, module0Dropout, module1);
+
+ FFNN<decltype(modules), decltype(outputLayer), PerformanceFunctionType>
+ net(modules, outputLayer);
+
+ Trainer<decltype(net)> trainer(net, maxEpochs, 1, 0.001);
+ trainer.Train(trainData, trainLabels, testData, testLabels);
+
+ VecType prediction;
+ size_t error = 0;
+
+ for (size_t i = 0; i < testData.n_cols; i++)
+ {
+ net.Predict(testData.unsafe_col(i), prediction);
+ if (arma::sum(prediction - testLabels.unsafe_col(i)) == 0)
+ error++;
+ }
+
+ double classificationError = 1 - double(error) / testData.n_cols;
+
+ BOOST_REQUIRE_LE(classificationError, classificationErrorThreshold);
+ BOOST_REQUIRE_LE(trainer.ValidationError(), ValidationErrorThreshold);
+}
+
+/**
+ * Train the dropout network on a larger dataset.
+ */
+BOOST_AUTO_TEST_CASE(DropoutNetworkTest)
+{
+ // Load the dataset.
+ arma::mat dataset;
+ data::Load("thyroid_train.csv", dataset, true);
+
+ arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4,
+ dataset.n_cols - 1);
+ arma::mat trainLabels = dataset.submat(dataset.n_rows - 3, 0,
+ dataset.n_rows - 1, dataset.n_cols - 1);
+
+ data::Load("thyroid_test.csv", dataset, true);
+
+ arma::mat testData = dataset.submat(0, 0, dataset.n_rows - 4,
+ dataset.n_cols - 1);
+ arma::mat testLabels = dataset.submat(dataset.n_rows - 3, 0,
+ dataset.n_rows - 1, dataset.n_cols - 1);
+
+  // Dropout neural net with logistic activation function.
+  // Because 92 percent of the patients are not hyperthyroid, the neural
+  // network must be significantly better than 92%.
+ BuildDropoutNetwork<LogisticFunction,
+ BinaryClassificationLayer,
+ MeanSquaredErrorFunction>
+ (trainData, trainLabels, testData, testLabels, 4, 500, 0.1, 60);
+
+ dataset.load("mnist_first250_training_4s_and_9s.arm");
+
+ // Normalize each point since these are images.
+ for (size_t i = 0; i < dataset.n_cols; ++i)
+ dataset.col(i) /= norm(dataset.col(i), 2);
+
+ arma::mat labels = arma::zeros(1, dataset.n_cols);
+ labels.submat(0, labels.n_cols / 2, 0, labels.n_cols - 1).fill(1);
+
+ // Vanilla neural net with logistic activation function.
+ BuildVanillaNetwork<LogisticFunction,
+ BinaryClassificationLayer,
+ MeanSquaredErrorFunction>
+ (dataset, labels, dataset, labels, 100, 100, 0.6, 10);
+
+ // Vanilla neural net with tanh activation function.
+ BuildVanillaNetwork<TanhFunction,
+ BinaryClassificationLayer,
+ MeanSquaredErrorFunction>
(dataset, labels, dataset, labels, 10, 200, 0.6, 20);
}
More information about the mlpack-git
mailing list