[mlpack-git] master: Refactor conv connection to support 3rd order tensors. (47dd33d)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon May 4 15:14:48 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6caf31a493719a3a5edf2fdcde9b0eef9e165944...6137e52d32c1338b28853afd059b67cf68a50270

>---------------------------------------------------------------

commit 47dd33db7704fcb58effd8e8fcf3c58e1b69ef2f
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Mon May 4 19:31:29 2015 +0200

    Refactor conv connection to support 3rd order tensors.


>---------------------------------------------------------------

47dd33db7704fcb58effd8e8fcf3c58e1b69ef2f
 .../methods/ann/connections/conv_connection.hpp    | 172 +++++++++++++--------
 1 file changed, 108 insertions(+), 64 deletions(-)

diff --git a/src/mlpack/methods/ann/connections/conv_connection.hpp b/src/mlpack/methods/ann/connections/conv_connection.hpp
index 835fdc5..5a0a5b0 100644
--- a/src/mlpack/methods/ann/connections/conv_connection.hpp
+++ b/src/mlpack/methods/ann/connections/conv_connection.hpp
@@ -1,33 +1,35 @@
 /**
  * @file cnn_conv_connection.hpp
  * @author Shangtong Zhang
+ * @author Marcus Edel
  *
- * Implementation of the convolutional connection 
- * between input layer and output layer.
+ * Implementation of the convolutional connection between input layer and output
+ * layer.
  */
 #ifndef __MLPACK_METHODS_ANN_CONNECTIONS_CONV_CONNECTION_HPP
 #define __MLPACK_METHODS_ANN_CONNECTIONS_CONV_CONNECTION_HPP
 
 #include <mlpack/core.hpp>
-#include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
-#include <mlpack/methods/ann/convolution/valid_convolution.hpp>
-#include <mlpack/methods/ann/convolution/full_convolution.hpp>
-#include <mlpack/methods/ann/connections/connection_traits.hpp>
+#include <mlpack/methods/ann/init_rules/random_init.hpp>
+#include <mlpack/methods/ann/optimizer/steepest_descent.hpp>
+#include <mlpack/methods/ann/convolution_rules/border_modes.hpp>
+#include <mlpack/methods/ann/convolution_rules/naive_convolution.hpp>
 
 namespace mlpack{
 namespace ann  /** Artificial Neural Network. */{
+
 /**
- * Implementation of the convolutional connection class.
- * The convolutional connection performs convolution between input layer
- * and output layer.
- * Convolution is applied to every neuron in input layer.
- * The kernel used for convolution is stored in @weights.
+ * Implementation of the convolutional connection class. The convolutional
+ * connection performs convolution between input layer and output layer.
+ * Convolution is applied to every neuron in input layer. The kernel used for
+ * convolution is stored in @weights.
  *
  * Users can design their own convolution rule (ForwardConvolutionRule)
- * to perform forward process. But once user-defined forward convolution is used, 
- * users have to design special BackwardConvolutionRule and GradientConvolutionRule
- * to perform backward process and calculate gradient corresponding to 
- * ForwardConvolutionRule, aimed to guarantee the correctness of error flow.
+ * to perform forward process. But once user-defined forward convolution is
+ * used, users have to design special BackwardConvolutionRule and
+ * GradientConvolutionRule to perform backward process and calculate gradient
+ * corresponding to ForwardConvolutionRule, aimed to guarantee the correctness
+ * of error flow.
  *
  * @tparam InputLayerType Type of the connected input layer.
  * @tparam OutputLayerType Type of the connected output layer.
@@ -36,77 +38,124 @@ namespace ann  /** Artificial Neural Network. */{
  * @tparam ForwardConvolutionRule Convolution to perform forward process.
  * @tparam BackwardConvolutionRule Convolution to perform backward process.
  * @tparam GradientConvolutionRule Convolution to calculate gradient.
- * @tparam MatType Type of data (arma::mat or arma::sp_mat).
+ * @tparam DataType Type of data (arma::mat, arma::sp_mat or arma::cube).
  */
 template<
     typename InputLayerType,
     typename OutputLayerType,
-    typename OptimizerType,
-    class WeightInitRule = NguyenWidrowInitialization<>,
-    typename ForwardConvolutionRule = ValidConvolution,
-    typename BackwardConvolutionRule = ValidConvolution,
-    typename GradientConvolutionRule = RotatedKernelFullConvolution,
-    typename MatType = arma::mat
+    typename OptimizerType = SteepestDescent<>,
+    class WeightInitRule = RandomInitialization,
+    typename ForwardConvolutionRule = NaiveConvolution<ValidConvolution>,
+    typename BackwardConvolutionRule = NaiveConvolution<FullConvolution>,
+    typename GradientConvolutionRule = NaiveConvolution<ValidConvolution>,
+    typename DataType = arma::cube
 >
 class ConvConnection
 {
  public:
   /**
    * Create the ConvConnection object using the specified input layer, output
-   * layer, optimizer and weight initialize rule.
+   * layer, optimizer and weight initialization rule.
    *
    * @param InputLayerType The input layer which is connected with the output
    * layer.
    * @param OutputLayerType The output layer which is connected with the input
    * layer.
-   * @param OptimizerType The optimizer used to update the weights matrix.
-   * @param weightsRows The number of rows of convolutional kernel.
-   * @param weightsCols The number of cols of convolutional kernel.
-   * @param WeightInitRule The weights initialize rule used to initialize the
-   * weights matrix.
+   * @param filterRows The number of rows of the convolutional kernel.
+   * @param filterCols The number of columns of the convolutional kernel.
+   * @param optimizer The optimizer used to update the weights (not owned).
+   * @param weightInitRule The weights initialization rule used to initialize
+   * the kernel weights.
    */
   ConvConnection(InputLayerType& inputLayer,
                  OutputLayerType& outputLayer,
+                 const size_t filterRows,
+                 const size_t filterCols,
                  OptimizerType& optimizer,
-                 size_t weightsRows,
-                 size_t weightsCols,
                  WeightInitRule weightInitRule = WeightInitRule()) :
-      inputLayer(inputLayer), outputLayer(outputLayer), optimizer(optimizer)
+      inputLayer(inputLayer),
+      outputLayer(outputLayer),
+      optimizer(&optimizer),
+      ownsOptimizer(false)
+  {
+    weightInitRule.Initialize(weights, filterRows, filterCols,
+        outputLayer.LayerSlices());
+  }
+
+  /**
+   * Create the ConvConnection object using the specified input layer, output
+   * layer and weight initialization rule, creating and owning a new optimizer.
+   *
+   * @param inputLayer The input layer which is connected with the output
+   * layer.
+   * @param outputLayer The output layer which is connected with the input
+   * layer.
+   * @param filterRows The number of rows of the convolutional kernel.
+   * @param filterCols The number of columns of the convolutional kernel.
+   * @param weightInitRule The weights initialization rule used to initialize
+   * the kernel weights.
+   */
+  ConvConnection(InputLayerType& inputLayer,
+                 OutputLayerType& outputLayer,
+                 const size_t filterRows,
+                 const size_t filterCols,
+                 WeightInitRule weightInitRule = WeightInitRule()) :
+      inputLayer(inputLayer),
+      outputLayer(outputLayer),
+      optimizer(new OptimizerType()),
+      ownsOptimizer(true)
   {
-    weightInitRule.Initialize(weights, weightsRows, weightsCols);
-    gradient = arma::zeros<MatType>(weightsRows, weightsCols);
+    weightInitRule.Initialize(weights, filterRows, filterCols,
+        outputLayer.LayerSlices());
   }
 
   /**
-   * Ordinary feed forward pass of a neural network, 
-   * Apply convolution to every neuron in input layer and
-   * put the output in the output layer.
+   * Delete the conv connection object and, if it owns one, its optimizer.
    */
-  void FeedForward(const MatType& input)
+  ~ConvConnection()
   {
-    MatType output(outputLayer.InputActivation().n_rows,
-                   outputLayer.InputActivation().n_cols);
-    ForwardConvolutionRule::conv(input, weights, output);
+    if (ownsOptimizer)
+      delete optimizer;
+  }
+
+  /**
+   * Ordinary feed forward pass of a neural network. Convolve the input with the
+   * kernel weights and add the result to the output layer's input activation.
+   */
+  template<typename InputType>
+  void FeedForward(const InputType& input)
+  {
+    DataType output;
+    ForwardConvolutionRule::Convolution(input, weights, output);
     outputLayer.InputActivation() += output;
   }
 
   /**
-   * Ordinary feed backward pass of a neural network.
-   * Pass the error from output layer to input layer and 
-   * calculate the delta of kernel weights.
+   * Ordinary feed backward pass of a neural network. Pass the error from output
+   * layer to input layer by convolving it with the kernel, storing it in delta.
+   *
    * @param error The backpropagated error.
    */
-  void FeedBackward(const MatType& error)
+  void FeedBackward(const DataType& error)
+  {
+    BackwardConvolutionRule::Convolution(weights, error, delta);
+  }
+
+  /**
+   * Calculate the gradient using the output delta and the input activation.
+   *
+   * @param gradient Output: the calculated gradient of the kernel weights.
+   */
+  void Gradient(DataType& gradient)
   {
-    BackwardConvolutionRule::conv(inputLayer.InputActivation(), error, gradient);
-    GradientConvolutionRule::conv(weights, error, delta);
-    inputLayer.Delta() += delta;
+    GradientConvolutionRule::Convolution(inputLayer.InputActivation(),
+        outputLayer.Delta(), gradient);
   }
 
   //! Get the convolution kernel.
-  MatType& Weights() const { return weights; }
+  const DataType& Weights() const { return weights; }
   //! Modify the convolution kernel.
-  MatType& Weights() { return weights; }
+  DataType& Weights() { return weights; }
 
   //! Get the input layer.
   InputLayerType& InputLayer() const { return inputLayer; }
@@ -119,23 +168,18 @@ class ConvConnection
   OutputLayerType& OutputLayer() { return outputLayer; }
 
   //! Get the optimzer.
-  OptimizerType& Optimzer() const { return optimizer; }
+  OptimizerType& Optimzer() const { return *optimizer; }
   //! Modify the optimzer.
-  OptimizerType& Optimzer() { return optimizer; }
+  OptimizerType& Optimzer() { return *optimizer; }
 
   //! Get the passed error in backward propagation.
-  MatType& Delta() const { return delta; }
+  const DataType& Delta() const { return delta; }
   //! Modify the passed error in backward propagation.
-  MatType& Delta() { return delta; }
-  
-  //! Get the gradient of kernel.
-  MatType& Gradient() const { return gradient; }
-  //! Modify the gradient of kernel.
-  MatType& Gradient() { return gradient; }
+  DataType& Delta() { return delta; }
 
  private:
   //! Locally-stored kernel weights.
-  MatType weights;
+  DataType weights;
 
   //! Locally-stored inputlayer.
   InputLayerType& inputLayer;
@@ -144,13 +188,13 @@ class ConvConnection
   OutputLayerType& outputLayer;
 
   //! Locally-stored optimizer.
-  OptimizerType& optimizer;
+  OptimizerType* optimizer;
 
-  //! Locally-stored passed error in backward propagation.
-  MatType delta;
+  //! Whether this object owns the optimizer and must delete it on destruction.
+  bool ownsOptimizer;
 
-  //! Locally-stored gradient of kernel weights.
-  MatType gradient;
+  //! Locally-stored passed error in backward propagation.
+  DataType delta;
 };// class ConvConnection
 
 }; // namespace ann



More information about the mlpack-git mailing list