[mlpack-git] master: Refactor pooling connection to support 3rd order tensors. (609ee9c)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon May 4 15:14:52 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6caf31a493719a3a5edf2fdcde9b0eef9e165944...6137e52d32c1338b28853afd059b67cf68a50270

>---------------------------------------------------------------

commit 609ee9c7ad83a7cbef088125f9138c40018e3423
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Mon May 4 19:19:33 2015 +0200

    Refactor pooling connection to support 3rd order tensors.


>---------------------------------------------------------------

609ee9c7ad83a7cbef088125f9138c40018e3423
 .../methods/ann/connections/pooling_connection.hpp | 239 +++++++++++++--------
 1 file changed, 151 insertions(+), 88 deletions(-)

diff --git a/src/mlpack/methods/ann/connections/pooling_connection.hpp b/src/mlpack/methods/ann/connections/pooling_connection.hpp
index 76a3532..f8ef8dc 100644
--- a/src/mlpack/methods/ann/connections/pooling_connection.hpp
+++ b/src/mlpack/methods/ann/connections/pooling_connection.hpp
@@ -1,46 +1,47 @@
 /**
  * @file cnn_pooling_connection.hpp
  * @author Shangtong Zhang
+ * @author Marcus Edel
  *
- * Implementation of the pooling connection between input layer
- * and output layer for CNN.
+ * Implementation of the pooling connection between the input layer and the
+ * output layer of a convolutional neural network.
  */
 #ifndef __MLPACK_METHODS_ANN_CONNECTIONS_POOLING_CONNECTION_HPP
 #define __MLPACK_METHODS_ANN_CONNECTIONS_POOLING_CONNECTION_HPP
 
 #include <mlpack/core.hpp>
+#include <mlpack/methods/ann/optimizer/steepest_descent.hpp>
 #include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
-#include <mlpack/methods/ann/pooling/max_pooling.hpp>
+#include <mlpack/methods/ann/pooling_rules/max_pooling.hpp>
 #include <mlpack/methods/ann/connections/connection_traits.hpp>
 
 namespace mlpack{
 namespace ann /** Artificial Neural Network. */ {
 
 /**
- * Implementation of the pooling connection class for CNN.
- * The pooling connection connects
- * input layer with the output layer by pooling.
- * output = factor * pooling_value + bias
+ * Implementation of the pooling connection class for the convolutional neural
+ * network. The pooling connection connects the input layer with the output
+ * layer using the specified pooling rule.
  *
  * @tparam InputLayerType Type of the connected input layer.
  * @tparam OutputLayerType Type of the connected output layer.
+ * @tparam PoolingRule Type of the pooling strategy.
  * @tparam OptimizerType Type of the optimizer used to update the weights.
- * @tparam PoolingRule Type of pooling strategy.
- * @tparam MatType Type of data (arma::mat or arma::sp_mat).
+ * @tparam DataType Type of data (arma::mat, arma::sp_mat or arma::cube).
  */
 template<
     typename InputLayerType,
     typename OutputLayerType,
-    typename OptimizerType,
     typename PoolingRule = MaxPooling,
-    typename MatType = arma::mat
+    template<typename> class OptimizerType = SteepestDescent,
+    typename DataType = arma::cube
 >
 class PoolingConnection
 {
  public:
   /**
    * Create the PoolingConnection object using the specified input layer, output
-   * layer, optimizer, factor, bias and pooling strategy.
+   * layer and pooling strategy.
    * The factor and bias is stored in @weights.
    *
    * @param InputLayerType The input layer which is connected with the output
@@ -52,80 +53,86 @@ class PoolingConnection
    */
   PoolingConnection(InputLayerType& inputLayer,
                     OutputLayerType& outputLayer,
-                    OptimizerType& optimizer,
-                    double factor = 1.0,
-                    double bias = 0,
                     PoolingRule pooling = PoolingRule()) :
-      inputLayer(inputLayer), outputLayer(outputLayer), optimizer(optimizer),
-      weights(2), pooling(pooling),
-      rawOutput(outputLayer.InputActivation().n_rows,
-                outputLayer.InputActivation().n_cols)
+      inputLayer(inputLayer),
+      outputLayer(outputLayer),
+      optimizer(0),
+      weights(0),
+      pooling(pooling),
+      delta(inputLayer.Delta().n_rows, inputLayer.Delta().n_cols,
+            inputLayer.Delta().n_slices)
   {
-    delta = arma::zeros<MatType>(inputLayer.InputActivation().n_rows,
-                                 inputLayer.InputActivation().n_cols);
-    gradient = arma::zeros<arma::colvec>(2);
-    weights(0) = factor;
-    weights(1) = bias;
+    // Nothing to do here.
   }
 
   /**
-   * Ordinary feed forward pass of a neural network, 
-   * apply pooling to the neurons in the input layer.
+   * Ordinary feed forward pass of a neural network: apply pooling to the
+   * neurons (dense matrix) of the input layer.
    *
    * @param input Input data used for pooling.
    */
-  void FeedForward(const MatType& input)
+  template<typename eT>
+  void FeedForward(const arma::Mat<eT>& input)
   {
-    size_t r_step = input.n_rows / outputLayer.InputActivation().n_rows;
-    size_t c_step = input.n_cols / outputLayer.InputActivation().n_cols;
-    for (size_t j = 0; j < input.n_cols; j += c_step)
-    {
-      for (size_t i = 0; i < input.n_rows; i += r_step)
-      {
-        double value = 0;
-        pooling.pooling(input(arma::span(i, i + r_step -1),
-                              arma::span(j, j + c_step - 1)), value);
-        rawOutput(i / r_step, j / c_step) = value;
-      }
-    }
-    outputLayer.InputActivation() += rawOutput * weights(0) + weights(1);
+    Pooling(input, outputLayer.InputActivation());
+  }
+
+  /**
+   * Ordinary feed forward pass of a neural network: apply pooling to the
+   * neurons (3rd order tensor) of the input layer.
+   *
+   * @param input Input data used for pooling.
+   */
+  template<typename eT>
+  void FeedForward(const arma::Cube<eT>& input)
+  {
+    for (size_t s = 0; s < input.n_slices; s++)
+      Pooling(input.slice(s), outputLayer.InputActivation().slice(s));
   }
 
   /**
-   * Ordinary feed backward pass of a neural network.
-   * Apply unsampling to the error in output layer to 
-   * pass the error to input layer.
+   * Ordinary feed backward pass of a neural network. Apply unpooling to the
+   * error of the output layer (dense matrix) to pass it to the input layer.
+   *
    * @param error The backpropagated error.
    */
-  void FeedBackward(const MatType& error)
+  template<typename eT>
+  void FeedBackward(const arma::Mat<eT>& error)
   {
-    gradient(1) = arma::sum(arma::sum(error));
-    gradient(0) = arma::sum(arma::sum(rawOutput % error));
-    MatType weightedError = error * weights(0);
-    size_t r_step = inputLayer.InputActivation().n_rows / error.n_rows;
-    size_t c_step = inputLayer.InputActivation().n_cols / error.n_cols;
-    const MatType& input = inputLayer.InputActivation();
-    MatType newError;
-    for (size_t j = 0; j < input.n_cols; j += c_step)
+    Unpooling(inputLayer.InputActivation(), error, inputLayer.Delta());
+  }
+
+  /**
+   * Ordinary feed backward pass of a neural network. Apply unpooling to the
+   * error of the output layer (3rd order tensor) to pass it to the input layer.
+   *
+   * @param error The backpropagated error.
+   */
+  template<typename eT>
+  void FeedBackward(const arma::Cube<eT>& error)
+  {
+    for (size_t s = 0; s < error.n_slices; s++)
     {
-      for (size_t i = 0; i < input.n_rows; i += r_step)
-      {
-        const MatType& inputArea = input(arma::span(i, i + r_step -1),
-                                         arma::span(j, j + c_step - 1));
-        pooling.unpooling(inputArea,
-                          weightedError(i / r_step, j / c_step),
-                          newError);
-        delta(arma::span(i, i + r_step -1),
-              arma::span(j, j + c_step - 1)) = newError;
-      }
+      Unpooling(inputLayer.InputActivation().slice(s), error.slice(s),
+          delta.slice(s));
     }
-    inputLayer.Delta() += delta;
+  }
+
+  /**
+   * Gradient hook; the pooling connection computes no gradient (no-op).
+   *
+   * @param gradient Unused; left unchanged.
+   */
+  template<typename GradientType>
+  void Gradient(GradientType& /* unused */)
+  {
+    // Nothing to do here.
   }
 
   //! Get the weights.
-  MatType& Weights() const { return weights; }
+  DataType& Weights() const { return *weights; }
   //! Modify the weights.
-  MatType& Weights() { return weights; }
+  DataType& Weights() { return *weights; }
 
   //! Get the input layer.
   InputLayerType& InputLayer() const { return inputLayer; }
@@ -138,19 +145,14 @@ class PoolingConnection
   OutputLayerType& OutputLayer() { return outputLayer; }
 
   //! Get the optimizer.
-  OptimizerType& Optimzer() const { return optimizer; }
+  OptimizerType<DataType>& Optimzer() const { return *optimizer; }
+  //! Modify the optimizer.
-  OptimizerType& Optimzer() { return optimizer; }
+  OptimizerType<DataType>& Optimzer() { return *optimizer; }
 
   //! Get the passed error in backward propagation.
-  MatType& Delta() const { return delta; }
+  DataType& Delta() const { return delta; }
   //! Modify the passed error in backward propagation.
-  MatType& Delta() { return delta; }
-  
-  //! Get the gradient of weights.
-  MatType& Gradient() const { return gradient; }
-  //! Modify the delta of weights.
-  MatType& Gradient() { return gradient; }
+  DataType& Delta() { return delta; }
 
   //! Get the pooling strategy.
   PoolingRule& Pooling() const { return pooling; }
@@ -158,6 +160,59 @@ class PoolingConnection
   PoolingRule& Pooling() { return pooling; }
 
  private:
+  /**
+   * Apply pooling to the input and add the results to the output.
+   *
+   * @param input The input to which the pooling rule is applied.
+   * @param output Matrix to which the pooled results are added.
+   */
+  template<typename eT>
+  void Pooling(const arma::Mat<eT>& input, arma::Mat<eT>& output)
+  {
+    const size_t rStep = input.n_rows / outputLayer.LayerRows();
+    const size_t cStep = input.n_cols / outputLayer.LayerCols();
+
+    for (size_t j = 0; j < input.n_cols; j += cStep)
+    {
+      for (size_t i = 0; i < input.n_rows; i += rStep)
+      {
+        output(i / rStep, j / cStep) += pooling.Pooling(
+            input(arma::span(i, i + rStep -1), arma::span(j, j + cStep - 1)));
+      }
+    }
+  }
+
+  /**
+   * Apply unpooling to the error and add the results to the output.
+   * @param input The input to which the unpooling rule is applied.
+   * @param error The error to be unpooled.
+   * @param output Matrix to which the unpooled error is added.
+   */
+  template<typename eT>
+  void Unpooling(const arma::Mat<eT>& input,
+                 const arma::Mat<eT>& error,
+                 arma::Mat<eT>& output)
+  {
+    const size_t rStep = input.n_rows / error.n_rows;
+    const size_t cStep = input.n_cols / error.n_cols;
+
+    arma::Mat<eT> unpooledError;
+    for (size_t j = 0; j < input.n_cols; j += cStep)
+    {
+      for (size_t i = 0; i < input.n_rows; i += rStep)
+      {
+        const arma::Mat<eT>& inputArea = input(arma::span(i, i + rStep -1),
+                                               arma::span(j, j + cStep - 1));
+
+        pooling.Unpooling(inputArea, error(i / rStep, j / cStep),
+            unpooledError);
+
+        output(arma::span(i, i + rStep - 1),
+            arma::span(j, j + cStep - 1)) += unpooledError;
+      }
+    }
+  }
+
   //! Locally-stored input layer.
   InputLayerType& inputLayer;
 
@@ -165,26 +220,34 @@ class PoolingConnection
   OutputLayerType& outputLayer;
 
   //! Locally-stored optimizer.
-  OptimizerType& optimizer;
+  OptimizerType<DataType>* optimizer;
 
-  //! Locally-stored weights, only two value, factor and bias.
-  arma::colvec weights;
-  
-  //! Locally-stored passed error in backward propagation.
-  MatType delta;
+  //! Locally-stored weight object.
+  DataType* weights;
 
   //! Locally-stored pooling strategy.
   PoolingRule pooling;
 
-  //! Locally-stored gradient of weights.
-  MatType gradient;
+  //! Locally-stored passed error in backward propagation.
+  DataType delta;
+}; // PoolingConnection class.
 
-  /**
-   * Locally-stored raw result of pooling,
-   * before multiplied by factor and added by bias.
-   * Cache it to speed up when performing backward propagation.
-   */
-  MatType rawOutput;
+//! Connection traits for the pooling connection.
+template<
+    typename InputLayerType,
+    typename OutputLayerType,
+    typename PoolingRule,
+    template<typename> class OptimizerType,
+    typename DataType
+>
+class ConnectionTraits<
+    PoolingConnection<InputLayerType, OutputLayerType, PoolingRule,
+    OptimizerType, DataType> >
+{
+ public:
+  static const bool IsSelfConnection = false;
+  static const bool IsFullselfConnection = false;
+  static const bool IsPoolingConnection = true;
 };
 
 }; // namespace ann



More information about the mlpack-git mailing list