[mlpack-git] master: Refactor feedforward network test for new network API. (967adc4)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sat Aug 29 09:03:18 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/3413a77502387c942a12ac47a7e2cc966ed2ddcd...967adc4746d7e638422cdf1c373d2865ff4e8d4c

>---------------------------------------------------------------

commit 967adc4746d7e638422cdf1c373d2865ff4e8d4c
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Sat Aug 29 15:01:34 2015 +0200

    Refactor feedforward network test for new network API.


>---------------------------------------------------------------

967adc4746d7e638422cdf1c373d2865ff4e8d4c
 src/mlpack/methods/ann/ffn.hpp                 |   9 +-
 src/mlpack/methods/ann/optimizer/rmsprop.hpp   |  10 +-
 src/mlpack/methods/ann/trainer/trainer.hpp     |  15 +-
 src/mlpack/tests/activation_functions_test.cpp | 333 +++++++++++--------------
 4 files changed, 160 insertions(+), 207 deletions(-)

diff --git a/src/mlpack/methods/ann/ffn.hpp b/src/mlpack/methods/ann/ffn.hpp
index 4eab5fe..05d430d 100644
--- a/src/mlpack/methods/ann/ffn.hpp
+++ b/src/mlpack/methods/ann/ffn.hpp
@@ -52,7 +52,6 @@ class FFN
      * @param input Input data used to evaluate the network.
      * @param target Target data used to calculate the network error.
      * @param error The calulated error of the output layer.
-     * @tparam DataType Type of data (arma::colvec, arma::mat or arma::sp_mat).
      */
     template <typename InputType, typename TargetType, typename ErrorType>
     void FeedForward(const InputType& input,
@@ -68,7 +67,6 @@ class FFN
      * error of the output layer.
      *
      * @param error The calulated error of the output layer.
-     * @tparam DataType Type of data (arma::colvec, arma::mat or arma::sp_mat).
      */
     template <typename InputType, typename ErrorType>
     void FeedBackward(const InputType& /* unused */, const ErrorType& error)
@@ -94,7 +92,6 @@ class FFN
      *
      * @param input Input data used to evaluate the network.
      * @param output Output data used to store the output activation
-     * @tparam DataType Type of data (arma::colvec, arma::mat or arma::sp_mat).
      */
     template <typename DataType>
     void Predict(const DataType& input, DataType& output)
@@ -113,7 +110,6 @@ class FFN
      * @param input Input data used to evaluate the trained network.
      * @param target Target data used to calculate the network error.
      * @param error The calulated error of the output layer.
-     * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
      */
     template <typename InputType, typename TargetType, typename ErrorType>
     double Evaluate(const InputType& input,
@@ -243,8 +239,8 @@ class FFN
                        const std::tuple<Tp...>& t)
     {
       // Calculate and store the output error.
-      outputLayer.CalculateError(std::get<sizeof...(Tp) - 1>(t).OutputParameter(),
-          target, error);
+      outputLayer.CalculateError(
+          std::get<sizeof...(Tp) - 1>(t).OutputParameter(), target, error);
 
       // Masures the network's performance with the specified performance
       // function.
@@ -394,7 +390,6 @@ class FFN
     bool deterministic;
 }; // class FFN
 
-
 //! Network traits for the FFN network.
 template <
   typename LayerTypes,
diff --git a/src/mlpack/methods/ann/optimizer/rmsprop.hpp b/src/mlpack/methods/ann/optimizer/rmsprop.hpp
index a29be5a..063236d 100644
--- a/src/mlpack/methods/ann/optimizer/rmsprop.hpp
+++ b/src/mlpack/methods/ann/optimizer/rmsprop.hpp
@@ -80,13 +80,12 @@ class RMSPROP
   {
     if (gradient.n_elem != 0)
     {
-      DataType outputGradient;
-      function.Gradient(outputGradient);
+      DataType outputGradient = function.Gradient();
       gradient += outputGradient;
     }
     else
     {
-      function.Gradient(gradient);
+      gradient = function.Gradient();
     }
   }
 
@@ -98,6 +97,11 @@ class RMSPROP
     gradient.zeros();
   }
 
+  //! Get the accumulated gradient (read-only).
+  const DataType& Gradient() const { return gradient; }
+  //! Modify the accumulated gradient.
+  DataType& Gradient() { return gradient; }
+
  private:
   /**
    * Optimize the given function using RmsProp.
diff --git a/src/mlpack/methods/ann/trainer/trainer.hpp b/src/mlpack/methods/ann/trainer/trainer.hpp
index 1f1aa4d..cbb6521 100644
--- a/src/mlpack/methods/ann/trainer/trainer.hpp
+++ b/src/mlpack/methods/ann/trainer/trainer.hpp
@@ -24,12 +24,10 @@ namespace ann /** Artificial Neural Network. */ {
  * @tparam NetworkType The type of network which should be trained and
  * evaluated.
  * @tparam MaType Type of the error type (arma::mat or arma::sp_mat).
- * @tparam VecType Type of error type (arma::colvec, arma::mat or arma::sp_mat).
  */
 template<
   typename NetworkType,
-  typename MatType = arma::mat,
-  typename VecType = arma::colvec
+  typename MatType = arma::mat
 >
 class Trainer
 {
@@ -86,6 +84,8 @@ class Trainer
           ElementCount(trainingData) - 1, ElementCount(trainingData));
       epoch = 0;
 
+      size_t foo = 0;
+
       while(true)
       {
         if (shuffle)
@@ -147,7 +147,7 @@ class Trainer
             Element(target, index(i)), error);
 
         trainingError += net.Error();
-        net.FeedBackward(error);
+        net.FeedBackward(Element(data, index(i)), error);
 
         if (((i + 1) % batchSize) == 0)
           net.ApplyGradients();
@@ -189,10 +189,10 @@ class Trainer
      */
     template<typename eT>
     typename std::enable_if<!NetworkTraits<NetworkType>::IsCNN,
-        arma::Col<eT> >::type
+        arma::Mat<eT> >::type
     Element(arma::Mat<eT>& input, const size_t colNum)
     {
-      return arma::Col<eT>(input.colptr(colNum), input.n_rows, false, true);
+      return arma::Mat<eT>(input.colptr(colNum), input.n_rows, 1, false, true);
     }
 
     /*
@@ -248,8 +248,7 @@ class Trainer
     NetworkType& net;
 
     //! The current network error of a single input.
-    typename std::conditional<NetworkTraits<NetworkType>::IsFNN,
-        VecType, MatType>::type error;
+    MatType error;
 
     //! The current epoch if maxEpochs is set.
     size_t epoch;
diff --git a/src/mlpack/tests/activation_functions_test.cpp b/src/mlpack/tests/activation_functions_test.cpp
index c852684..f1d8557 100644
--- a/src/mlpack/tests/activation_functions_test.cpp
+++ b/src/mlpack/tests/activation_functions_test.cpp
@@ -12,14 +12,15 @@
 #include <mlpack/methods/ann/activation_functions/tanh_function.hpp>
 #include <mlpack/methods/ann/activation_functions/rectifier_function.hpp>
 
-#include <mlpack/methods/ann/ffnn.hpp>
+#include <mlpack/methods/ann/ffn.hpp>
 #include <mlpack/methods/ann/init_rules/random_init.hpp>
-#include <mlpack/methods/ann/layer/neuron_layer.hpp>
+#include <mlpack/methods/ann/optimizer/rmsprop.hpp>
+#include <mlpack/methods/ann/performance_functions/mse_function.hpp>
+
 #include <mlpack/methods/ann/layer/bias_layer.hpp>
-#include <mlpack/methods/ann/layer/multiclass_classification_layer.hpp>
-#include <mlpack/methods/ann/connections/full_connection.hpp>
-#include <mlpack/methods/ann/connections/self_connection.hpp>
-#include <mlpack/methods/ann/optimizer/irpropp.hpp>
+#include <mlpack/methods/ann/layer/linear_layer.hpp>
+#include <mlpack/methods/ann/layer/base_layer.hpp>
+#include <mlpack/methods/ann/layer/binary_classification_layer.hpp>
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
@@ -198,188 +199,142 @@ BOOST_AUTO_TEST_CASE(RectifierFunctionTest)
       desiredDerivatives);
 }
 
-// /*
-//  * Implementation of the numerical gradient checking.
-//  *
-//  * @param input Input data used for evaluating the network.
-//  * @param target Target data used to calculate the network error.
-//  * @param perturbation Constant perturbation value.
-//  * @param threshold Threshold used as bounding check.
-//  *
-//  * @tparam ActivationFunction Activation function used for the gradient check.
-//  */
-// template<class ActivationFunction>
-// void CheckGradientNumericallyCorrect(const arma::colvec input,
-//                                      const arma::colvec target,
-//                                      const double perturbation,
-//                                      const double threshold)
-// {
-//   // Specify the structure of the feed forward neural network.
-//   RandomInitialization randInit(-0.5, 0.5);
-//   arma::colvec error;
-
-//   NeuronLayer<ActivationFunction> inputLayer(input.n_elem);
-
-//   BiasLayer<> biasLayer0(1);
-//   BiasLayer<> biasLayer1(1);
-//   BiasLayer<> biasLayer2(1);
-
-//   NeuronLayer<ActivationFunction> hiddenLayer0(4);
-//   NeuronLayer<ActivationFunction> hiddenLayer1(2);
-//   NeuronLayer<ActivationFunction> hiddenLayer2(target.n_elem);
-
-//   iRPROPp< > conOptimizer0(input.n_elem, hiddenLayer0.InputSize());
-//   iRPROPp< > conOptimizer1(1, 4);
-//   iRPROPp< > conOptimizer2(4, 2);
-//   iRPROPp< > conOptimizer3(1, 2);
-//   iRPROPp< > conOptimizer4(2, target.n_elem);
-//   iRPROPp< > conOptimizer5(1, target.n_elem);
-
-//   ClassificationLayer outputLayer;
-
-//   FullConnection<
-//       decltype(inputLayer),
-//       decltype(hiddenLayer0),
-//       decltype(conOptimizer0),
-//       decltype(randInit)>
-//       layerCon0(inputLayer, hiddenLayer0, conOptimizer0, randInit);
-
-//   FullConnection<
-//     decltype(biasLayer0),
-//     decltype(hiddenLayer0),
-//     decltype(conOptimizer1),
-//     decltype(randInit)>
-//     layerCon1(biasLayer0, hiddenLayer0, conOptimizer1, randInit);
-
-//   FullConnection<
-//       decltype(hiddenLayer0),
-//       decltype(hiddenLayer1),
-//       decltype(conOptimizer2),
-//       decltype(randInit)>
-//       layerCon2(hiddenLayer0, hiddenLayer1, conOptimizer2, randInit);
-
-//   FullConnection<
-//     decltype(biasLayer1),
-//     decltype(hiddenLayer1),
-//     decltype(conOptimizer3),
-//     decltype(randInit)>
-//     layerCon3(biasLayer1, hiddenLayer1, conOptimizer3, randInit);
-
-//   FullConnection<
-//       decltype(hiddenLayer1),
-//       decltype(hiddenLayer2),
-//       decltype(conOptimizer4),
-//       decltype(randInit)>
-//       layerCon4(hiddenLayer1, hiddenLayer2, conOptimizer4, randInit);
-
-//   FullConnection<
-//     decltype(biasLayer2),
-//     decltype(hiddenLayer2),
-//     decltype(conOptimizer5),
-//     decltype(randInit)>
-//     layerCon5(biasLayer2, hiddenLayer2, conOptimizer5, randInit);
-
-//   auto module0 = std::tie(layerCon0, layerCon1);
-//   auto module1 = std::tie(layerCon2, layerCon3);
-//   auto module2 = std::tie(layerCon4, layerCon5);
-//   auto modules = std::tie(module0, module1, module2);
-
-//   FFNN<decltype(modules), decltype(outputLayer)> net(modules, outputLayer);
-
-//   // Initialize the feed forward neural network.
-//   net.FeedForward(input, target, error);
-//   net.FeedBackward(error);
-
-//   std::vector<std::reference_wrapper<
-//       FullConnection<
-//       decltype(inputLayer),
-//       decltype(hiddenLayer0),
-//       decltype(conOptimizer0),
-//       decltype(randInit)> > > layer {layerCon0, layerCon2, layerCon4};
-
-//   std::vector<arma::mat> gradient {
-//       hiddenLayer0.Delta() * inputLayer.InputActivation().t(),
-//       hiddenLayer1.Delta() * hiddenLayer0.InputActivation().t(),
-//       hiddenLayer2.Delta() * hiddenLayer1.InputActivation().t() };
-
-//   double weight, mLoss, pLoss, dW, e;
-
-//   for (size_t l = 0; l < layer.size(); ++l)
-//   {
-//     for (size_t i = 0; i < layer[l].get().Weights().n_rows; ++i)
-//     {
-//       for (size_t j = 0; j < layer[l].get().Weights().n_cols; ++j)
-//       {
-//         // Store original weight.
-//         weight = layer[l].get().Weights()(i, j);
-
-//         // Add negative perturbation and compute error.
-//         layer[l].get().Weights().at(i, j) -= perturbation;
-//         net.FeedForward(input, target, error);
-//         mLoss = arma::as_scalar(0.5 * arma::sum(arma::pow(error, 2)));
-
-//         // Add positive perturbation and compute error.
-//         layer[l].get().Weights().at(i, j) += (2 * perturbation);
-//         net.FeedForward(input, target, error);
-//         pLoss = arma::as_scalar(0.5 * arma::sum(arma::pow(error, 2)));
-
-//         // Compute symmetric difference.
-//         dW = (pLoss - mLoss) / (2 * perturbation);
-//         e = std::abs(dW - gradient[l].at(i, j));
-
-//         bool b = e < threshold;
-//         BOOST_REQUIRE_EQUAL(b, 1);
-
-//         // Restore original weight.
-//         layer[l].get().Weights().at(i, j) = weight;
-//       }
-//     }
-//   }
-// }
-
-// /**
-//  * The following test implements numerical gradient checking. It computes the
-//  * numerical gradient, a numerical approximation of the partial derivative of J
-//  * with respect to the i-th input argument, evaluated at g. The numerical
-//  * gradient should be approximately the partial derivative of J with respect to
-//  * g(i).
-//  *
-//  * Given a function g(\theta) that is supposedly computing:
-//  *
-//  * @f[
-//  * \frac{\partial}{\partial \theta} J(\theta)
-//  * @f]
-//  *
-//  * we can now numerically verify its correctness by checking:
-//  *
-//  * @f[
-//  * g(\theta) \approx \frac{J(\theta + eps) - J(\theta - eps)}{2 * eps}
-//  * @f]
-//  */
-// BOOST_AUTO_TEST_CASE(GradientNumericallyCorrect)
-// {
-//   // Initialize dataset.
-//   const arma::colvec input = arma::randu<arma::colvec>(10);
-//   const arma::colvec target("0 1;");
-
-//   // Perturbation and threshold constant.
-//   const double perturbation = 1e-6;
-//   const double threshold = 1e-7;
-
-//   CheckGradientNumericallyCorrect<LogisticFunction>(input, target,
-//       perturbation, threshold);
-
-//   CheckGradientNumericallyCorrect<IdentityFunction>(input, target,
-//       perturbation, threshold);
-
-//   CheckGradientNumericallyCorrect<RectifierFunction>(input, target,
-//       perturbation, threshold);
-
-//   CheckGradientNumericallyCorrect<SoftsignFunction>(input, target,
-//       perturbation, threshold);
-
-//   CheckGradientNumericallyCorrect<TanhFunction>(input, target,
-//       perturbation, threshold);
-// }
+/**
+ * Numerically check the backpropagated weight gradients of a small FFN.
+ *
+ * @param input Input data used for evaluating the network.
+ * @param target Target data used to calculate the network error.
+ * @param perturbation Constant perturbation value.
+ * @param threshold Maximum allowed absolute gradient difference per weight.
+ *
+ * @tparam ActivationFunction Activation function used for the gradient check.
+ */
+template<class ActivationFunction>
+void CheckGradientNumericallyCorrect(const arma::mat& input,
+                                     const arma::mat& target,
+                                     const double perturbation,
+                                     const double threshold)
+{
+  // Specify the structure of the feed forward neural network.
+  RandomInitialization randInit(-0.5, 0.5);
+  arma::mat error;
+
+  // Number of hidden layer units.
+  const size_t hiddenLayerSize = 4;
+
+  LinearLayer<mlpack::ann::RMSPROP, RandomInitialization> linearLayer0(
+      input.n_rows, hiddenLayerSize, randInit);
+  BiasLayer<> biasLayer0(hiddenLayerSize);
+  BaseLayer<ActivationFunction> baseLayer0;
+
+  LinearLayer<mlpack::ann::RMSPROP, RandomInitialization> linearLayer1(
+      hiddenLayerSize, hiddenLayerSize, randInit);
+  BiasLayer<> biasLayer1(hiddenLayerSize);
+  BaseLayer<ActivationFunction> baseLayer1;
+
+  LinearLayer<mlpack::ann::RMSPROP, RandomInitialization> linearLayer2(
+      hiddenLayerSize, target.n_rows, randInit);
+  BiasLayer<> biasLayer2(target.n_rows);
+  BaseLayer<ActivationFunction> baseLayer2;
+
+  BinaryClassificationLayer classOutputLayer;
+
+  auto modules = std::tie(linearLayer0, biasLayer0, baseLayer0,
+                          linearLayer1, biasLayer1, baseLayer1,
+                          linearLayer2, biasLayer2, baseLayer2);
+
+  FFN<decltype(modules), decltype(classOutputLayer), MeanSquaredErrorFunction>
+      net(modules, classOutputLayer);
+
+  // Forward and backward pass before reading the analytic gradients below.
+  net.FeedForward(input, target, error);
+  net.FeedBackward(input, error);
+
+  std::vector<std::reference_wrapper<decltype(linearLayer0)> > layer {
+      linearLayer0, linearLayer1, linearLayer2 };
+
+  std::vector<arma::mat> gradient {linearLayer0.Gradient(),
+                                   linearLayer1.Gradient(),
+                                   linearLayer2.Gradient()};
+
+  double weight, mLoss, pLoss, dW, e;
+  for (size_t l = 0; l < layer.size(); ++l)
+  {
+    auto& weights = layer[l].get().Weights();
+    for (size_t i = 0; i < weights.n_rows; ++i)
+    {
+      for (size_t j = 0; j < weights.n_cols; ++j)
+      {
+        // Store original weight.
+        weight = weights(i, j);
+
+        // Add negative perturbation and compute error.
+        weights(i, j) -= perturbation;
+        net.FeedForward(input, target, error);
+        mLoss = arma::as_scalar(0.5 * arma::sum(arma::pow(error, 2)));
+
+        // Add positive perturbation and compute error.
+        weights(i, j) += (2 * perturbation);
+        net.FeedForward(input, target, error);
+        pLoss = arma::as_scalar(0.5 * arma::sum(arma::pow(error, 2)));
+
+        // Compute symmetric difference.
+        dW = (pLoss - mLoss) / (2 * perturbation);
+        e = std::abs(dW - gradient[l].at(i, j));
+
+        // BOOST_REQUIRE_LT reports e and threshold if the check fails.
+        BOOST_REQUIRE_LT(e, threshold);
+
+        // Restore original weight.
+        weights(i, j) = weight;
+      }
+    }
+  }
+}
+
+/**
+ * Numerical gradient check. For every weight theta_i of the network, the
+ * backpropagated partial derivative g(theta_i) is compared against a
+ * central-difference approximation of the cost J. The check must hold
+ * within a fixed tolerance for each activation function exercised below.
+ *
+ * If g(\theta) is supposed to compute
+ *
+ * @f[
+ * \frac{\partial}{\partial \theta} J(\theta)
+ * @f]
+ *
+ * its correctness can be verified numerically via
+ *
+ * @f[
+ * g(\theta) \approx
+ * \frac{J(\theta + \epsilon) - J(\theta - \epsilon)}{2 \epsilon}
+ * @f]
+ */
+BOOST_AUTO_TEST_CASE(GradientNumericallyCorrect)
+{
+  // Random 10-dimensional input with a fixed two-element target.
+  const arma::colvec input = arma::randu<arma::colvec>(10);
+  const arma::colvec target("0 1;");
+
+  // Central-difference step size and maximum per-weight deviation.
+  const double epsilon = 1e-6;
+  const double tolerance = 1e-5;
+
+  CheckGradientNumericallyCorrect<LogisticFunction>(
+      input, target, epsilon, tolerance);
+
+  CheckGradientNumericallyCorrect<IdentityFunction>(
+      input, target, epsilon, tolerance);
+
+  CheckGradientNumericallyCorrect<RectifierFunction>(
+      input, target, epsilon, tolerance);
+
+  CheckGradientNumericallyCorrect<SoftsignFunction>(
+      input, target, epsilon, tolerance);
+
+  CheckGradientNumericallyCorrect<TanhFunction>(
+      input, target, epsilon, tolerance);
+}
 
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list