[mlpack-git] master: Refactor for new optimizer API. (7df836c)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Jun 16 14:50:52 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/9264f7544f7c4d93ff735f00f35b0f5287abf59d...7df836c2f5a2287cda82801ca20f4b4b410cf4e1

>---------------------------------------------------------------

commit 7df836c2f5a2287cda82801ca20f4b4b410cf4e1
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Tue Jun 16 14:41:15 2015 +0200

    Refactor for new optimizer API.


>---------------------------------------------------------------

7df836c2f5a2287cda82801ca20f4b4b410cf4e1
 src/mlpack/tests/feedforward_network_test.cpp | 103 ++++++++------------------
 1 file changed, 31 insertions(+), 72 deletions(-)

diff --git a/src/mlpack/tests/feedforward_network_test.cpp b/src/mlpack/tests/feedforward_network_test.cpp
index be58169..10f1cd2 100644
--- a/src/mlpack/tests/feedforward_network_test.cpp
+++ b/src/mlpack/tests/feedforward_network_test.cpp
@@ -25,7 +25,7 @@
 #include <mlpack/methods/ann/performance_functions/sse_function.hpp>
 #include <mlpack/methods/ann/performance_functions/cee_function.hpp>
 
-#include <mlpack/methods/ann/optimizer/steepest_descent.hpp>
+#include <mlpack/methods/ann/optimizer/rmsprop.hpp>
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
@@ -40,9 +40,7 @@ BOOST_AUTO_TEST_SUITE(FeedForwardNetworkTest);
  * Train and evaluate a vanilla network with the specified structure.
  */
 template<
-    typename WeightInitRule,
     typename PerformanceFunction,
-    typename OptimizerType,
     typename OutputLayerType,
     typename PerformanceFunctionType,
     typename MatType = arma::mat,
@@ -55,8 +53,7 @@ void BuildVanillaNetwork(MatType& trainData,
                          const size_t hiddenLayerSize,
                          const size_t maxEpochs,
                          const double classificationErrorThreshold,
-                         const double ValidationErrorThreshold,
-                         WeightInitRule weightInitRule = WeightInitRule())
+                         const double ValidationErrorThreshold)
 {
   /*
    * Construct a feed forward network with trainData.n_rows input nodes,
@@ -87,30 +84,20 @@ void BuildVanillaNetwork(MatType& trainData,
 
   OutputLayerType outputLayer;
 
-  OptimizerType conOptimizer0(trainData.n_rows, hiddenLayerSize);
-  OptimizerType conOptimizer1(1, hiddenLayerSize);
-  OptimizerType conOptimizer2(hiddenLayerSize, trainLabels.n_rows);
-
   FullConnection<
     decltype(inputLayer),
-    decltype(hiddenLayer0),
-    decltype(conOptimizer0),
-    decltype(weightInitRule)>
-    layerCon0(inputLayer, hiddenLayer0, conOptimizer0, weightInitRule);
+    decltype(hiddenLayer0)>
+    layerCon0(inputLayer, hiddenLayer0);
 
   FullConnection<
     decltype(biasLayer0),
-    decltype(hiddenLayer0),
-    decltype(conOptimizer1),
-    decltype(weightInitRule)>
-    layerCon1(biasLayer0, hiddenLayer0, conOptimizer1, weightInitRule);
+    decltype(hiddenLayer0)>
+    layerCon1(biasLayer0, hiddenLayer0);
 
   FullConnection<
       decltype(hiddenLayer0),
-      decltype(hiddenLayer1),
-      decltype(conOptimizer2),
-      decltype(weightInitRule)>
-      layerCon2(hiddenLayer0, hiddenLayer1, conOptimizer2, weightInitRule);
+      decltype(hiddenLayer1)>
+      layerCon2(hiddenLayer0, hiddenLayer1);
 
   auto module0 = std::tie(layerCon0, layerCon1);
   auto module1 = std::tie(layerCon2);
@@ -159,23 +146,16 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
   arma::mat testLabels = dataset.submat(dataset.n_rows - 3, 0,
       dataset.n_rows - 1, dataset.n_cols - 1);
 
-  RandomInitialization randInitA(1, 2);
-
   // Vanilla neural net with logistic activation function.
   // Because 92 percent of the patients are not hyperthyroid the neural
   // network must be significant better than 92%.
-  BuildVanillaNetwork<RandomInitialization,
-                      LogisticFunction,
-                      SteepestDescent<>,
+  BuildVanillaNetwork<LogisticFunction,
                       BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
-      (trainData, trainLabels, testData, testLabels, 4, 500,
-          0.1, 60, randInitA);
+      (trainData, trainLabels, testData, testLabels, 4, 500, 0.1, 60);
 
   dataset.load("mnist_first250_training_4s_and_9s.arm");
 
-  RandomInitialization randInitB(-0.5, 0.5);
-
   // Normalize each point since these are images.
   for (size_t i = 0; i < dataset.n_cols; ++i)
     dataset.col(i) /= norm(dataset.col(i), 2);
@@ -184,20 +164,16 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
   labels.submat(0, labels.n_cols / 2, 0, labels.n_cols - 1).fill(1);
 
   // Vanilla neural net with logistic activation function.
-  BuildVanillaNetwork<RandomInitialization,
-                      LogisticFunction,
-                      SteepestDescent<>,
+  BuildVanillaNetwork<LogisticFunction,
                       BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
-      (dataset, labels, dataset, labels, 100, 100, 0.6, 10, randInitB);
+      (dataset, labels, dataset, labels, 100, 100, 0.6, 10);
 
   // Vanilla neural net with tanh activation function.
-  BuildVanillaNetwork<RandomInitialization,
-                    TanhFunction,
-                    SteepestDescent<>,
+  BuildVanillaNetwork<TanhFunction,
                     BinaryClassificationLayer,
                     MeanSquaredErrorFunction>
-    (dataset, labels, dataset, labels, 10, 200, 0.6, 20, randInitB);
+    (dataset, labels, dataset, labels, 10, 200, 0.6, 20);
 }
 
 /**
@@ -208,28 +184,22 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
   arma::mat input;
   arma::mat labels;
 
-  RandomInitialization randInit(0.5, 1);
-
   // Test on a non-linearly separable dataset (XOR).
   input << 0 << 1 << 1 << 0 << arma::endr
         << 1 << 0 << 1 << 0 << arma::endr;
   labels << 0 << 0 << 1 << 1;
 
   // Vanilla neural net with logistic activation function.
-  BuildVanillaNetwork<RandomInitialization,
-                      LogisticFunction,
-                      SteepestDescent<>,
+  BuildVanillaNetwork<LogisticFunction,
                       BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
-      (input, labels, input, labels, 4, 0, 0, 0.01, randInit);
+      (input, labels, input, labels, 4, 0, 0, 0.01);
 
   // Vanilla neural net with tanh activation function.
-  BuildVanillaNetwork<RandomInitialization,
-                      TanhFunction,
-                      SteepestDescent<>,
+  BuildVanillaNetwork<TanhFunction,
                       BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
-      (input, labels, input, labels, 4, 0, 0, 0.01, randInit);
+      (input, labels, input, labels, 4, 0, 0, 0.01);
 
   // Test on a linearly separable dataset (AND).
   input << 0 << 1 << 1 << 0 << arma::endr
@@ -237,20 +207,16 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
   labels << 0 << 0 << 1 << 0;
 
   // vanilla neural net with sigmoid activation function.
-  BuildVanillaNetwork<RandomInitialization,
-                    LogisticFunction,
-                    SteepestDescent<>,
-                    BinaryClassificationLayer,
-                    MeanSquaredErrorFunction>
-    (input, labels, input, labels, 4, 0, 0, 0.01, randInit);
+  BuildVanillaNetwork<LogisticFunction,
+                      BinaryClassificationLayer,
+                      MeanSquaredErrorFunction>
+    (input, labels, input, labels, 4, 0, 0, 0.01);
 
   // Vanilla neural net with tanh activation function.
-  BuildVanillaNetwork<RandomInitialization,
-                      TanhFunction,
-                      SteepestDescent<>,
+  BuildVanillaNetwork<TanhFunction,
                       BinaryClassificationLayer,
                       MeanSquaredErrorFunction>
-      (input, labels, input, labels, 4, 0, 0, 0.01, randInit);
+      (input, labels, input, labels, 4, 0, 0, 0.01);
 }
 
 /**
@@ -260,7 +226,6 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkConvergenceTest)
 template<
     typename WeightInitRule,
     typename PerformanceFunction,
-    typename OptimizerType,
     typename OutputLayerType,
     typename PerformanceFunctionType,
     typename MatType = arma::mat
@@ -273,8 +238,7 @@ void BuildNetworkOptimzer(MatType& trainData,
                           size_t epochs,
                           WeightInitRule weightInitRule = WeightInitRule())
 {
-  /*
-   * Construct a feed forward network with trainData.n_rows input nodes,
+   /* Construct a feed forward network with trainData.n_rows input nodes,
    * hiddenLayerSize hidden nodes and trainLabels.n_rows output nodes. The
    * network structure looks like:
    *
@@ -302,30 +266,26 @@ void BuildNetworkOptimzer(MatType& trainData,
 
   OutputLayerType outputLayer;
 
-  OptimizerType conOptimizer0(trainData.n_rows, hiddenLayerSize);
-  OptimizerType conOptimizer1(1, hiddenLayerSize);
-  OptimizerType conOptimizer2(hiddenLayerSize, trainLabels.n_rows);
-
   FullConnection<
     decltype(inputLayer),
     decltype(hiddenLayer0),
-    decltype(conOptimizer0),
+    mlpack::ann::RMSPROP,
     decltype(weightInitRule)>
-    layerCon0(inputLayer, hiddenLayer0, conOptimizer0, weightInitRule);
+    layerCon0(inputLayer, hiddenLayer0, weightInitRule);
 
   FullConnection<
     decltype(biasLayer0),
     decltype(hiddenLayer0),
-    decltype(conOptimizer1),
+    mlpack::ann::RMSPROP,
     decltype(weightInitRule)>
-    layerCon1(biasLayer0, hiddenLayer0, conOptimizer1, weightInitRule);
+    layerCon1(biasLayer0, hiddenLayer0, weightInitRule);
 
   FullConnection<
       decltype(hiddenLayer0),
       decltype(hiddenLayer1),
-      decltype(conOptimizer2),
+      mlpack::ann::RMSPROP,
       decltype(weightInitRule)>
-      layerCon2(hiddenLayer0, hiddenLayer1, conOptimizer2, weightInitRule);
+      layerCon2(hiddenLayer0, hiddenLayer1, weightInitRule);
 
   auto module0 = std::tie(layerCon0, layerCon1);
   auto module1 = std::tie(layerCon2);
@@ -370,7 +330,6 @@ BOOST_AUTO_TEST_CASE(NetworkDecreasingErrorTest)
   // Vanilla neural net with logistic activation function.
   BuildNetworkOptimzer<RandomInitialization,
                        LogisticFunction,
-                       SteepestDescent<>,
                        BinaryClassificationLayer,
                        MeanSquaredErrorFunction>
       (dataset, labels, dataset, labels, 100, 50, randInitB);



More information about the mlpack-git mailing list