[mlpack-git] master: Add bias connection that works with convolutional neural networks. (f4d8e70)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Jun 11 17:10:18 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/3ade9299e8f3c1e73ba30bff276b51813ede87b5...f4d8e7075ff00ac483b7aeaa01c3ffe4645e9bfc

>---------------------------------------------------------------

commit f4d8e7075ff00ac483b7aeaa01c3ffe4645e9bfc
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Thu Jun 11 23:10:06 2015 +0200

    Add bias connection that works with convolutional neural networks.


>---------------------------------------------------------------

f4d8e7075ff00ac483b7aeaa01c3ffe4645e9bfc
 .../{full_connection.hpp => bias_connection.hpp}   | 160 ++++++++++-----------
 .../methods/ann/connections/full_connection.hpp    |   6 +-
 2 files changed, 77 insertions(+), 89 deletions(-)

diff --git a/src/mlpack/methods/ann/connections/full_connection.hpp b/src/mlpack/methods/ann/connections/bias_connection.hpp
similarity index 63%
copy from src/mlpack/methods/ann/connections/full_connection.hpp
copy to src/mlpack/methods/ann/connections/bias_connection.hpp
index 0b851fb..eb0b3df 100644
--- a/src/mlpack/methods/ann/connections/full_connection.hpp
+++ b/src/mlpack/methods/ann/connections/bias_connection.hpp
@@ -1,28 +1,31 @@
 /**
- * @file full_connection.hpp
+ * @file bias_connection.hpp
+ * @author Shangtong Zhang
  * @author Marcus Edel
  *
- * Implementation of the full connection class.
+ * Implementation of the connection between a bias layer and another layer.
  */
-#ifndef __MLPACK_METHODS_ANN_CONNECTIONS_FULL_CONNECTION_HPP
-#define __MLPACK_METHODS_ANN_CONNECTIONS_FULL_CONNECTION_HPP
+#ifndef __MLPACK_METHODS_ANN_CONNECTIONS_BIAS_CONNECTION_HPP
+#define __MLPACK_METHODS_ANN_CONNECTIONS_BIAS_CONNECTION_HPP
 
 #include <mlpack/core.hpp>
 #include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
+#include <mlpack/methods/ann/layer/layer_traits.hpp>
+#include <mlpack/methods/ann/connections/connection_traits.hpp>
 #include <mlpack/methods/ann/optimizer/steepest_descent.hpp>
 
 namespace mlpack {
 namespace ann /** Artificial Neural Network. */ {
 
 /**
- * Implementation of the full connection class. The full connection connects
- * every neuron from the input layer with the output layer in a matrix
- * multiplicative way.
+ * Implementation of the bias connection class. The bias connection connects
+ * a bias layer to another layer.
  *
- * @tparam InputLayerType Type of the connected input layer.
+ * @tparam InputLayerType Type of the connected input layer. It must be a bias
+ * layer.
  * @tparam OutputLayerType Type of the connected output layer.
  * @tparam OptimizerType Type of the optimizer used to update the weights.
- * @tparam WeightInitRule Rule used to initialize the weight matrix.
+ * @tparam WeightInitRule Rule used to initialize the weight matrix.
  * @tparam MatType Type of data (arma::mat or arma::sp_mat).
  */
 template<
@@ -32,22 +35,22 @@ template<
     class WeightInitRule = NguyenWidrowInitialization,
     typename MatType = arma::mat
 >
-class FullConnection
+class BiasConnection
 {
  public:
   /**
-   * Create the FullConnection object using the specified input layer, output
-   * layer, optimizer and weight initialization rule.
+   * Create the BiasConnection object using the specified input layer, output
+   * layer, optimizer and weight initialization rule.
    *
    * @param InputLayerType The input layer which is connected with the output
    * layer.
    * @param OutputLayerType The output layer which is connected with the input
    * layer.
    * @param OptimizerType The optimizer used to update the weight matrix.
-   * @param WeightInitRule The weights initialization rule used to initialize the
-   * weights matrix.
+   * @param WeightInitRule The weight initialization rule used to initialize
+   * the weight matrix.
    */
-  FullConnection(InputLayerType& inputLayer,
+  BiasConnection(InputLayerType& inputLayer,
                  OutputLayerType& outputLayer,
                  OptimizerType& optimizer,
                  WeightInitRule weightInitRule = WeightInitRule()) :
@@ -56,24 +59,21 @@ class FullConnection
       optimizer(&optimizer),
       ownsOptimizer(false)
   {
-    weightInitRule.Initialize(weights, outputLayer.InputSize(),
-        inputLayer.LayerRows() * inputLayer.LayerCols() *
-        inputLayer.LayerSlices() * inputLayer.OutputMaps() /
-        outputLayer.LayerCols());
+    weightInitRule.Initialize(weights, outputLayer.OutputMaps(), 1);
   }
 
   /**
-   * Create the FullConnection object using the specified input layer, output
-   * layer and weight initialization rule.
+   * Create the BiasConnection object using the specified input layer, output
+   * layer and weight initialization rule.
    *
    * @param InputLayerType The input layer which is connected with the output
    * layer.
    * @param OutputLayerType The output layer which is connected with the input
    * layer.
-   * @param WeightInitRule The weights initialization rule used to initialize the
-   * weights matrix.
+   * @param WeightInitRule The weight initialization rule used to initialize
+   * the weight matrix.
    */
-  FullConnection(InputLayerType& inputLayer,
+  BiasConnection(InputLayerType& inputLayer,
                OutputLayerType& outputLayer,
                WeightInitRule weightInitRule = WeightInitRule()) :
     inputLayer(inputLayer),
@@ -83,19 +83,7 @@ class FullConnection
         inputLayer.OutputMaps() / outputLayer.LayerCols())),
     ownsOptimizer(true)
   {
-    weightInitRule.Initialize(weights, outputLayer.InputSize(),
-        inputLayer.LayerRows() * inputLayer.LayerCols() *
-        inputLayer.LayerSlices() * inputLayer.OutputMaps() /
-        outputLayer.LayerCols());
-  }
-
-  /**
-   * Delete the full connection object and its optimizer.
-   */
-  ~FullConnection()
-  {
-    if (ownsOptimizer)
-      delete optimizer;
+    weightInitRule.Initialize(weights, outputLayer.OutputMaps(), 1);
   }
 
   /**
@@ -108,50 +96,35 @@ class FullConnection
   template<typename eT>
   void FeedForward(const arma::Mat<eT>& input)
   {
-    outputLayer.InputActivation() += (weights * input);
+    Forward(outputLayer.InputActivation(), input);
   }
 
   /**
-   * Ordinary feed forward pass of a neural network, evaluating the function
-   * f(x) by propagating the activity forward through f using a 3rd order tensor
-   * as input.
+   * Ordinary feed backward pass of a neural network, calculating the function
+   * f(x) by propagating x backwards through f, using the results from the
+   * feed forward pass.
    *
-   * @param input Input data used for evaluating the specified activity function.
+   * @param error The backpropagated error.
    */
   template<typename eT>
-  void FeedForward(const arma::Cube<eT>& input)
+  void FeedBackward(const arma::Cube<eT>& error)
   {
-    MatType data(input.n_elem / outputLayer.LayerCols(),
-        outputLayer.LayerCols());
-
-    for (size_t s = 0, c = 0; s < input.n_slices / data.n_cols; s++)
+    delta = MatType(outputLayer.OutputMaps(), 1);
+    for (size_t s = 0; s < error.n_slices; s++)
     {
-      for (size_t i = 0; i < data.n_cols; i++, c++)
-      {
-        data.col(i).subvec(s * input.n_rows * input.n_cols, (s + 1) *
-            input.n_rows * input.n_cols - 1) = arma::vectorise(input.slice(c));
-      }
+      delta(s, 0) = weights(s, 0) * arma::accu(error.slice(s));
     }
-
-    outputLayer.InputActivation() += (weights * data);
   }
 
-  /**
-   * Ordinary feed backward pass of a neural network, calculating the function
-   * f(x) by propagating x backwards trough f. Using the results from the feed
-   * forward pass.
-   *
-   * @param error The backpropagated error.
-   */
-  template<typename ErrorType>
-  void FeedBackward(const ErrorType& error)
+  template<typename eT>
+  void FeedBackward(const arma::Mat<eT>& error)
   {
-    delta = (weights.t() * error);
+    delta = weights.t() * error;
   }
 
-  /*
-   * Calculate the gradient (dense matrix) using the output delta (dense matrix)
-   * and the input activation (dense matrix).
+  /**
+   * Calculate the gradient (dense matrix) using the output delta and the input
+   * activation.
    *
    * @param gradient The calculated gradient.
    */
@@ -170,7 +143,7 @@ class FullConnection
   template<typename eT>
   void Gradient(arma::Cube<eT>& gradient)
   {
-     GradientDelta(inputLayer.InputActivation(), gradient);
+    GradientDelta(outputLayer.Delta(), gradient);
   }
 
   //! Get the weights.
@@ -195,11 +168,11 @@ class FullConnection
 
   //! Get the delta.
   MatType& Delta() const { return delta; }
- //  //! Modify the delta.
+  //! Modify the delta.
   MatType& Delta() { return delta; }
 
  private:
-   /*
+  /*
    * Calculate the gradient using the output delta (3rd order tensor) and the
    * input activation (3rd order tensor).
    *
@@ -209,23 +182,11 @@ class FullConnection
   void GradientDelta(arma::Cube<eT>& /* unused */, arma::Cube<eT>& gradient)
   {
     gradient = arma::Cube<eT>(weights.n_rows, weights.n_cols, 1);
-    arma::Mat<eT> data = arma::Mat<eT>(outputLayer.Delta().n_cols,
-        inputLayer.InputActivation().n_elem / outputLayer.Delta().n_cols);
-
-    for (size_t s = 0, c = 0; s < inputLayer.InputActivation().n_slices /
-        data.n_rows; s++)
+    for (size_t s = 0; s < outputLayer.OutputMaps(); s++)
     {
-      for (size_t i = 0; i < data.n_rows; i++, c++)
-      {
-        data.row(i).subvec(s * inputLayer.InputActivation().n_rows *
-            inputLayer.InputActivation().n_cols, (s + 1) *
-            inputLayer.InputActivation().n_rows *
-            inputLayer.InputActivation().n_cols - 1) = arma::vectorise(
-                inputLayer.InputActivation().slice(c), 1);
-      }
+      gradient.slice(0)(s, 0) = arma::accu(outputLayer.Delta().slice(s)) *
+          inputLayer.InputActivation()(s, 0);
     }
-
-    gradient.slice(0) = outputLayer.Delta() * data / outputLayer.Delta().n_cols;
   }
 
   /*
@@ -241,6 +202,33 @@ class FullConnection
     Gradient(gradient.slice(0));
   }
 
+  /**
+   * Ordinary feed forward pass of a neural network, evaluating the function
+   * f(x) by propagating the activity forward through f, writing the result
+   * into the 3rd order tensor output activation.
+   *
+   * @param input Input data used for evaluating the specified activity function.
+   */
+  template<typename eT>
+  void Forward(const arma::Cube<eT>& /* unused */, const arma::Mat<eT>& input)
+  {
+    for (size_t s = 0; s < outputLayer.OutputMaps(); s++)
+      outputLayer.InputActivation().slice(s) += (weights(s, 0) * input(s, 0));
+  }
+
+  /**
+   * Ordinary feed forward pass of a neural network, evaluating the function
+   * f(x) by propagating the activity forward through f using a dense matrix as
+   * input.
+   *
+   * @param input Input data used for evaluating the specified activity function.
+   */
+  template<typename eT>
+  void Forward(const arma::Mat<eT>& /* unused */, const arma::Mat<eT>& input)
+  {
+    outputLayer.InputActivation() += weights % input;
+  }
+
   //! Locally-stored weight object.
   MatType weights;
 
@@ -258,7 +246,7 @@ class FullConnection
 
   //! Locally-stored delta object that holds the calculated delta.
   MatType delta;
-}; // class FullConnection
+}; // class BiasConnection
 
 }; // namespace ann
 }; // namespace mlpack
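
To make the per-map arithmetic above concrete: the cube overloads broadcast one
scalar bias over each feature map on the forward pass, and collapse each map's
error to a single number on the backward pass. A self-contained sketch of both
steps, independent of the mlpack classes (the map count, map size, and the
assumption that a bias layer emits ones are invented for illustration):

#include <armadillo>

int main()
{
  const arma::uword maps = 3;                      // invented map count
  arma::cube activation(4, 4, maps, arma::fill::zeros);
  arma::mat weights(maps, 1, arma::fill::randn);   // one bias weight per map
  arma::mat input(maps, 1, arma::fill::ones);      // assume bias layer emits ones

  // Forward, as in BiasConnection::Forward() for the cube overload: add the
  // scalar weights(s, 0) * input(s, 0) to every element of feature map s.
  for (arma::uword s = 0; s < maps; ++s)
    activation.slice(s) += weights(s, 0) * input(s, 0);

  // Backward, as in FeedBackward(const arma::Cube&): the delta for bias s is
  // its weight times the accumulated error over map s.
  arma::cube error(4, 4, maps, arma::fill::randn);
  arma::mat delta(maps, 1);
  for (arma::uword s = 0; s < maps; ++s)
    delta(s, 0) = weights(s, 0) * arma::accu(error.slice(s));

  delta.print("bias deltas:");
  return 0;
}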
diff --git a/src/mlpack/methods/ann/connections/full_connection.hpp b/src/mlpack/methods/ann/connections/full_connection.hpp
index 0b851fb..a72c8ee 100644
--- a/src/mlpack/methods/ann/connections/full_connection.hpp
+++ b/src/mlpack/methods/ann/connections/full_connection.hpp
@@ -108,7 +108,7 @@ class FullConnection
   template<typename eT>
   void FeedForward(const arma::Mat<eT>& input)
   {
-    outputLayer.InputActivation() += (weights * input);
+    outputLayer.InputActivation() += weights * input;
   }
 
   /**
@@ -146,7 +146,7 @@ class FullConnection
   template<typename ErrorType>
   void FeedBackward(const ErrorType& error)
   {
-    delta = (weights.t() * error);
+    delta = weights.t() * error;
   }
 
   /*
@@ -195,7 +195,7 @@ class FullConnection
 
   //! Get the delta.
   MatType& Delta() const { return delta; }
- //  //! Modify the delta.
+  //! Modify the delta.
   MatType& Delta() { return delta; }
 
  private:



More information about the mlpack-git mailing list