[mlpack-git] master: Refactor for new optimizer API. (7df836c)
gitdub at big.cc.gt.atl.ga.us
Tue Jun 16 14:50:52 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/9264f7544f7c4d93ff735f00f35b0f5287abf59d...7df836c2f5a2287cda82801ca20f4b4b410cf4e1
>---------------------------------------------------------------
commit 7df836c2f5a2287cda82801ca20f4b4b410cf4e1
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Tue Jun 16 14:41:15 2015 +0200
Refactor for new optimizer API.
>---------------------------------------------------------------
7df836c2f5a2287cda82801ca20f4b4b410cf4e1
src/mlpack/tests/feedforward_network_test.cpp | 103 ++++++++------------------
1 file changed, 31 insertions(+), 72 deletions(-)
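In short: under the old API the caller had to construct an optimizer object
(e.g. SteepestDescent<>) and a weight-initialization rule for every
connection and pass both into FullConnection. Under the new API the
optimizer is a template parameter with a default (or pinned explicitly to
mlpack::ann::RMSPROP, as in BuildNetworkOptimzer below), and the connection
manages it internally. A minimal before/after sketch pieced together from
the hunks below; the layer objects (inputLayer, hiddenLayer0) are declared
outside the hunk context and assumed here:

    // Old API: per-connection optimizer and weight-init rule, built and
    // threaded through by the caller.
    SteepestDescent<> conOptimizer0(trainData.n_rows, hiddenLayerSize);
    RandomInitialization randInit(1, 2);
    FullConnection<decltype(inputLayer), decltype(hiddenLayer0),
        decltype(conOptimizer0), decltype(randInit)>
      layerCon0(inputLayer, hiddenLayer0, conOptimizer0, randInit);

    // New API: the optimizer template parameter defaults, so the
    // two-argument form suffices.
    FullConnection<decltype(inputLayer), decltype(hiddenLayer0)>
      layerCon0(inputLayer, hiddenLayer0);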
diff --git a/src/mlpack/tests/feedforward_network_test.cpp b/src/mlpack/tests/feedforward_network_test.cpp
index be58169..10f1cd2 100644
--- a/src/mlpack/tests/feedforward_network_test.cpp
+++ b/src/mlpack/tests/feedforward_network_test.cpp
@@ -25,7 +25,7 @@
#include <mlpack/methods/ann/performance_functions/sse_function.hpp>
#include <mlpack/methods/ann/performance_functions/cee_function.hpp>
-#include <mlpack/methods/ann/optimizer/steepest_descent.hpp>
+#include <mlpack/methods/ann/optimizer/rmsprop.hpp>
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
@@ -40,9 +40,7 @@ BOOST_AUTO_TEST_SUITE(FeedForwardNetworkTest);
* Train and evaluate a vanilla network with the specified structure.
*/
template<
- typename WeightInitRule,
typename PerformanceFunction,
- typename OptimizerType,
typename OutputLayerType,
typename PerformanceFunctionType,
typename MatType = arma::mat,
@@ -55,8 +53,7 @@ void BuildVanillaNetwork(MatType& trainData,
const size_t hiddenLayerSize,
const size_t maxEpochs,
const double classificationErrorThreshold,
- const double ValidationErrorThreshold,
- WeightInitRule weightInitRule = WeightInitRule())
+ const double ValidationErrorThreshold)
{
/*
* Construct a feed forward network with trainData.n_rows input nodes,
@@ -87,30 +84,20 @@ void BuildVanillaNetwork(MatType& trainData,
OutputLayerType outputLayer;
- OptimizerType conOptimizer0(trainData.n_rows, hiddenLayerSize);
- OptimizerType conOptimizer1(1, hiddenLayerSize);
- OptimizerType conOptimizer2(hiddenLayerSize, trainLabels.n_rows);
-
FullConnection<
decltype(inputLayer),
- decltype(hiddenLayer0),
- decltype(conOptimizer0),
- decltype(weightInitRule)>
- layerCon0(inputLayer, hiddenLayer0, conOptimizer0, weightInitRule);
+ decltype(hiddenLayer0)>
+ layerCon0(inputLayer, hiddenLayer0);
FullConnection<
decltype(biasLayer0),
- decltype(hiddenLayer0),
- decltype(conOptimizer1),
- decltype(weightInitRule)>
- layerCon1(biasLayer0, hiddenLayer0, conOptimizer1, weightInitRule);
+ decltype(hiddenLayer0)>
+ layerCon1(biasLayer0, hiddenLayer0);
FullConnection<
decltype(hiddenLayer0),
- decltype(hiddenLayer1),
- decltype(conOptimizer2),
- decltype(weightInitRule)>
- layerCon2(hiddenLayer0, hiddenLayer1, conOptimizer2, weightInitRule);
+ decltype(hiddenLayer1)>
+ layerCon2(hiddenLayer0, hiddenLayer1);
auto module0 = std::tie(layerCon0, layerCon1);
auto module1 = std::tie(layerCon2);
@@ -159,23 +146,16 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
arma::mat testLabels = dataset.submat(dataset.n_rows - 3, 0,
dataset.n_rows - 1, dataset.n_cols - 1);
- RandomInitialization randInitA(1, 2);
-
// Vanilla neural net with logistic activation function.
// Because 92 percent of the patients are not hyperthyroid, the neural
// network must be significantly better than 92%.
- BuildVanillaNetwork<RandomInitialization,
- LogisticFunction,
- SteepestDescent<>,
+ BuildVanillaNetwork<LogisticFunction,
BinaryClassificationLayer,
MeanSquaredErrorFunction>
- (trainData, trainLabels, testData, testLabels, 4, 500,
- 0.1, 60, randInitA);
+ (trainData, trainLabels, testData, testLabels, 4, 500, 0.1, 60);
dataset.load("mnist_first250_training_4s_and_9s.arm");
- RandomInitialization randInitB(-0.5, 0.5);
-
// Normalize each point since these are images.
for (size_t i = 0; i < dataset.n_cols; ++i)
dataset.col(i) /= norm(dataset.col(i), 2);
@@ -184,20 +164,16 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
labels.submat(0, labels.n_cols / 2, 0, labels.n_cols - 1).fill(1);
// Vanilla neural net with logistic activation function.
- BuildVanillaNetwork<RandomInitialization,
- LogisticFunction,
- SteepestDescent<>,
+ BuildVanillaNetwork<LogisticFunction,
BinaryClassificationLayer,
MeanSquaredErrorFunction>
- (dataset, labels, dataset, labels, 100, 100, 0.6, 10, randInitB);
+ (dataset, labels, dataset, labels, 100, 100, 0.6, 10);
// Vanilla neural net with tanh activation function.
- BuildVanillaNetwork<RandomInitialization,
- TanhFunction,
- SteepestDescent<>,
+ BuildVanillaNetwork<TanhFunction,
BinaryClassificationLayer,
MeanSquaredErrorFunction>
- (dataset, labels, dataset, labels, 10, 200, 0.6, 20, randInitB);
+ (dataset, labels, dataset, labels, 10, 200, 0.6, 20);
}
/**
@@ -208,28 +184,22 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
arma::mat input;
arma::mat labels;
- RandomInitialization randInit(0.5, 1);
-
// Test on a non-linearly separable dataset (XOR).
input << 0 << 1 << 1 << 0 << arma::endr
<< 1 << 0 << 1 << 0 << arma::endr;
labels << 0 << 0 << 1 << 1;
// Vanilla neural net with logistic activation function.
- BuildVanillaNetwork<RandomInitialization,
- LogisticFunction,
- SteepestDescent<>,
+ BuildVanillaNetwork<LogisticFunction,
BinaryClassificationLayer,
MeanSquaredErrorFunction>
- (input, labels, input, labels, 4, 0, 0, 0.01, randInit);
+ (input, labels, input, labels, 4, 0, 0, 0.01);
// Vanilla neural net with tanh activation function.
- BuildVanillaNetwork<RandomInitialization,
- TanhFunction,
- SteepestDescent<>,
+ BuildVanillaNetwork<TanhFunction,
BinaryClassificationLayer,
MeanSquaredErrorFunction>
- (input, labels, input, labels, 4, 0, 0, 0.01, randInit);
+ (input, labels, input, labels, 4, 0, 0, 0.01);
// Test on a linearly separable dataset (AND).
input << 0 << 1 << 1 << 0 << arma::endr
@@ -237,20 +207,16 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
labels << 0 << 0 << 1 << 0;
// Vanilla neural net with logistic activation function.
- BuildVanillaNetwork<RandomInitialization,
- LogisticFunction,
- SteepestDescent<>,
- BinaryClassificationLayer,
- MeanSquaredErrorFunction>
- (input, labels, input, labels, 4, 0, 0, 0.01, randInit);
+ BuildVanillaNetwork<LogisticFunction,
+ BinaryClassificationLayer,
+ MeanSquaredErrorFunction>
+ (input, labels, input, labels, 4, 0, 0, 0.01);
// Vanilla neural net with tanh activation function.
- BuildVanillaNetwork<RandomInitialization,
- TanhFunction,
- SteepestDescent<>,
+ BuildVanillaNetwork<TanhFunction,
BinaryClassificationLayer,
MeanSquaredErrorFunction>
- (input, labels, input, labels, 4, 0, 0, 0.01, randInit);
+ (input, labels, input, labels, 4, 0, 0, 0.01);
}
/**
@@ -260,7 +226,6 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
template<
typename WeightInitRule,
typename PerformanceFunction,
- typename OptimizerType,
typename OutputLayerType,
typename PerformanceFunctionType,
typename MatType = arma::mat
@@ -273,8 +238,7 @@ void BuildNetworkOptimzer(MatType& trainData,
size_t epochs,
WeightInitRule weightInitRule = WeightInitRule())
{
- /*
- * Construct a feed forward network with trainData.n_rows input nodes,
+ /* Construct a feed forward network with trainData.n_rows input nodes,
* hiddenLayerSize hidden nodes and trainLabels.n_rows output nodes. The
* network structure looks like:
*
@@ -302,30 +266,26 @@ void BuildNetworkOptimzer(MatType& trainData,
OutputLayerType outputLayer;
- OptimizerType conOptimizer0(trainData.n_rows, hiddenLayerSize);
- OptimizerType conOptimizer1(1, hiddenLayerSize);
- OptimizerType conOptimizer2(hiddenLayerSize, trainLabels.n_rows);
-
FullConnection<
decltype(inputLayer),
decltype(hiddenLayer0),
- decltype(conOptimizer0),
+ mlpack::ann::RMSPROP,
decltype(weightInitRule)>
- layerCon0(inputLayer, hiddenLayer0, conOptimizer0, weightInitRule);
+ layerCon0(inputLayer, hiddenLayer0, weightInitRule);
FullConnection<
decltype(biasLayer0),
decltype(hiddenLayer0),
- decltype(conOptimizer1),
+ mlpack::ann::RMSPROP,
decltype(weightInitRule)>
- layerCon1(biasLayer0, hiddenLayer0, conOptimizer1, weightInitRule);
+ layerCon1(biasLayer0, hiddenLayer0, weightInitRule);
FullConnection<
decltype(hiddenLayer0),
decltype(hiddenLayer1),
- decltype(conOptimizer2),
+ mlpack::ann::RMSPROP,
decltype(weightInitRule)>
- layerCon2(hiddenLayer0, hiddenLayer1, conOptimizer2, weightInitRule);
+ layerCon2(hiddenLayer0, hiddenLayer1, weightInitRule);
auto module0 = std::tie(layerCon0, layerCon1);
auto module1 = std::tie(layerCon2);
@@ -370,7 +330,6 @@ BOOST_AUTO_TEST_CASE(NetworkDecreasingErrorTest)
// Vanilla neural net with logistic activation function.
BuildNetworkOptimzer<RandomInitialization,
LogisticFunction,
- SteepestDescent<>,
BinaryClassificationLayer,
MeanSquaredErrorFunction>
(dataset, labels, dataset, labels, 100, 50, randInitB);
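Net effect in BuildNetworkOptimzer: the optimizer is now selected at the
template level rather than constructed per connection, while the
weight-initialization rule is still passed through. A sketch of the
resulting construction, with the same assumed layer objects as above:

    FullConnection<decltype(inputLayer), decltype(hiddenLayer0),
        mlpack::ann::RMSPROP, decltype(weightInitRule)>
      layerCon0(inputLayer, hiddenLayer0, weightInitRule);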