[mlpack-git] master: add two layers (44308fa)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sat Apr 25 07:54:37 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/86f647ca937380cf2ea2569ba5735dcdcc659730...fbd6b1f878ec3b2fa365254a22daf3add743ee51

>---------------------------------------------------------------

commit 44308faa775aa5de7bca1dbb510c86833f17703c
Author: HurricaneTong <HurricaneTong at HurricaneTong.local>
Date:   Sat Apr 25 14:13:22 2015 +0800

    add two layers


>---------------------------------------------------------------

44308faa775aa5de7bca1dbb510c86833f17703c
 .../ann/connections/cnn_conv_connection.hpp        | 159 +++++++++++++++++++++
 src/mlpack/methods/ann/layer/neuron_layer.hpp      | 158 ++++++++++++++------
 ..._classification_layer.hpp => one_hot_layer.hpp} |  43 +++---
 3 files changed, 294 insertions(+), 66 deletions(-)

diff --git a/src/mlpack/methods/ann/connections/cnn_conv_connection.hpp b/src/mlpack/methods/ann/connections/cnn_conv_connection.hpp
new file mode 100644
index 0000000..835fdc5
--- /dev/null
+++ b/src/mlpack/methods/ann/connections/cnn_conv_connection.hpp
@@ -0,0 +1,159 @@
+/**
+ * @file cnn_conv_connection.hpp
+ * @author Shangtong Zhang
+ *
+ * Implementation of the convolutional connection 
+ * between input layer and output layer.
+ */
+#ifndef __MLPACK_METHODS_ANN_CONNECTIONS_CONV_CONNECTION_HPP
+#define __MLPACK_METHODS_ANN_CONNECTIONS_CONV_CONNECTION_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
+#include <mlpack/methods/ann/convolution/valid_convolution.hpp>
+#include <mlpack/methods/ann/convolution/full_convolution.hpp>
+#include <mlpack/methods/ann/connections/connection_traits.hpp>
+
+namespace mlpack{
+namespace ann  /** Artificial Neural Network. */{
+/**
+ * Implementation of the convolutional connection class.
+ * The convolutional connection performs convolution between input layer
+ * and output layer.
+ * Convolution is applied to every neuron in input layer.
+ * The kernel used for convolution is stored in the weights member.
+ *
+ * Users can supply their own convolution rule (ForwardConvolutionRule) for the
+ * forward pass. Once a user-defined forward convolution is used, matching
+ * BackwardConvolutionRule and GradientConvolutionRule types must be supplied
+ * as well, so that the kernel gradient and the propagated error stay
+ * consistent with the forward pass.
+ *
+ * @tparam InputLayerType Type of the connected input layer.
+ * @tparam OutputLayerType Type of the connected output layer.
+ * @tparam OptimizerType Type of the optimizer used to update the weights.
+ * @tparam WeightInitRule Rule used to initialize the weights matrix.
+ * @tparam ForwardConvolutionRule Convolution used in FeedForward().
+ * @tparam BackwardConvolutionRule Convolution that fills the kernel gradient.
+ * @tparam GradientConvolutionRule Convolution that fills the propagated delta.
+ * @tparam MatType Type of data (arma::mat or arma::sp_mat).
+ */
+template<
+    typename InputLayerType,
+    typename OutputLayerType,
+    typename OptimizerType,
+    class WeightInitRule = NguyenWidrowInitialization<>,
+    typename ForwardConvolutionRule = ValidConvolution,
+    typename BackwardConvolutionRule = ValidConvolution,
+    typename GradientConvolutionRule = RotatedKernelFullConvolution,
+    typename MatType = arma::mat
+>
+class ConvConnection
+{
+ public:
+  /**
+   * Create the ConvConnection object using the specified input layer, output
+   * layer, optimizer and weight initialize rule.
+   *
+   * @param inputLayer The input layer which is connected with the output
+   * layer.
+   * @param outputLayer The output layer which is connected with the input
+   * layer.
+   * @param optimizer The optimizer used to update the weights matrix.
+   * @param weightsRows The number of rows of convolutional kernel.
+   * @param weightsCols The number of cols of convolutional kernel.
+   * @param weightInitRule The weights initialize rule used to initialize the
+   * weights matrix.
+   */
+  ConvConnection(InputLayerType& inputLayer,
+                 OutputLayerType& outputLayer,
+                 OptimizerType& optimizer,
+                 size_t weightsRows,
+                 size_t weightsCols,
+                 WeightInitRule weightInitRule = WeightInitRule()) :
+      inputLayer(inputLayer), outputLayer(outputLayer), optimizer(optimizer)
+  {
+    weightInitRule.Initialize(weights, weightsRows, weightsCols);
+    gradient = arma::zeros<MatType>(weightsRows, weightsCols);
+  }
+
+  /**
+   * Ordinary feed forward pass: convolve the input with the kernel and add
+   * the result to the input activation of the output layer.
+   * @param input The data to be convolved with the kernel.
+   */
+  void FeedForward(const MatType& input)
+  {
+    MatType output(outputLayer.InputActivation().n_rows,
+                   outputLayer.InputActivation().n_cols);
+    ForwardConvolutionRule::conv(input, weights, output);
+    outputLayer.InputActivation() += output;
+  }
+
+  /**
+   * Ordinary feed backward pass: compute the kernel gradient with
+   * BackwardConvolutionRule, compute delta with GradientConvolutionRule and
+   * add delta to the delta of the input layer.
+   * @param error The backpropagated error.
+   */
+  void FeedBackward(const MatType& error)
+  {
+    BackwardConvolutionRule::conv(inputLayer.InputActivation(), error, gradient);
+    GradientConvolutionRule::conv(weights, error, delta);
+    inputLayer.Delta() += delta;
+  }
+
+  //! Get the convolution kernel.
+  const MatType& Weights() const { return weights; }
+  //! Modify the convolution kernel.
+  MatType& Weights() { return weights; }
+
+  //! Get the input layer.
+  InputLayerType& InputLayer() const { return inputLayer; }
+  //! Modify the input layer.
+  InputLayerType& InputLayer() { return inputLayer; }
+
+  //! Get the output layer.
+  OutputLayerType& OutputLayer() const { return outputLayer; }
+  //! Modify the output layer.
+  OutputLayerType& OutputLayer() { return outputLayer; }
+
+  //! Get the optimizer.
+  OptimizerType& Optimzer() const { return optimizer; }
+  //! Modify the optimizer.
+  OptimizerType& Optimzer() { return optimizer; }
+
+  //! Get the passed error in backward propagation.
+  const MatType& Delta() const { return delta; }
+  //! Modify the passed error in backward propagation.
+  MatType& Delta() { return delta; }
+
+  //! Get the gradient of kernel.
+  const MatType& Gradient() const { return gradient; }
+  //! Modify the gradient of kernel.
+  MatType& Gradient() { return gradient; }
+
+ private:
+  //! Locally-stored kernel weights.
+  MatType weights;
+
+  //! Reference to the connected input layer.
+  InputLayerType& inputLayer;
+
+  //! Reference to the connected output layer.
+  OutputLayerType& outputLayer;
+
+  //! Reference to the optimizer.
+  OptimizerType& optimizer;
+
+  //! Locally-stored passed error in backward propagation.
+  MatType delta;
+
+  //! Locally-stored gradient of kernel weights.
+  MatType gradient;
+};// class ConvConnection
+    
+}; // namespace ann
+}; // namespace mlpack
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/methods/ann/layer/neuron_layer.hpp b/src/mlpack/methods/ann/layer/neuron_layer.hpp
index 4b0105a..afba596 100644
--- a/src/mlpack/methods/ann/layer/neuron_layer.hpp
+++ b/src/mlpack/methods/ann/layer/neuron_layer.hpp
@@ -1,9 +1,10 @@
 /**
  * @file neuron_layer.hpp
  * @author Marcus Edel
+ * @author Shangtong Zhang
  *
  * Definition of the NeuronLayer class, which implements a standard network
- * layer.
+ * layer for 1-dimensional or 2-dimensional data.
  */
 #ifndef __MLPACK_METHOS_ANN_LAYER_NEURON_LAYER_HPP
 #define __MLPACK_METHOS_ANN_LAYER_NEURON_LAYER_HPP
@@ -28,29 +29,78 @@ namespace ann /** Artificial Neural Network. */ {
  *  - ReluLayer
  *
  * @tparam ActivationFunction Activation function used for the embedding layer.
- * @tparam MatType Type of data (arma::mat or arma::sp_mat).
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
+ * @tparam DataType Type of data (arma::mat or arma::colvec).
  */
 template <
     class ActivationFunction = LogisticFunction,
-    typename MatType = arma::mat,
-    typename VecType = arma::colvec
+    typename DataType = arma::colvec
 >
 class NeuronLayer
-
 {
  public:
   /**
-   * Create the NeuronLayer object using the specified number of neurons.
+   * Create 2-dimensional NeuronLayer object using the specified rows and columns.
+   * In this case, DataType must be arma::mat or other matrix type.
+   *
+   * @param layerRows The number of rows of neurons.
+   * @param layerCols The number of columns of neurons.
+   */
+  NeuronLayer(const size_t layerRows, const size_t layerCols) :
+      layerRows(layerRows), layerCols(layerCols),
+      localInputAcitvations(arma::ones<DataType>(layerRows, layerCols)),
+      inputActivations(localInputAcitvations),
+      localDelta(arma::zeros<DataType>(layerRows, layerCols)),
+      delta(localDelta)
+  {
+    // Nothing to do.
+  }
+  
+  /**
+   * Create 2-dimensional NeuronLayer object using the specified inputActivations and delta.
+   * This allows memory to be shared among layers,
+   * which makes it easier to combine layers in some special cases.
+   *
+   * @param inputActivations Outside storage for storing input activations.
+   * @param delta Outside storage for storing delta, 
+   *        the passed error in backward propagation.
+   */
+  NeuronLayer(DataType& inputActivations, DataType& delta) :
+      layerRows(inputActivations.n_rows),
+      layerCols(inputActivations.n_cols),
+      inputActivations(inputActivations),
+      delta(delta)
+  {
+    // Nothing to do.
+  }
+  
+  /**
+   * Create 1-dimensional NeuronLayer object using the specified layer size.
+   * In this case, DataType must be arma::colvec or other vector type.
    *
    * @param layerSize The number of neurons.
    */
   NeuronLayer(const size_t layerSize) :
-      inputActivations(arma::zeros<VecType>(layerSize)),
-      delta(arma::zeros<VecType>(layerSize)),
-      layerSize(layerSize)
+      layerRows(layerSize), layerCols(1),
+      localInputAcitvations(arma::ones<DataType>(layerRows)),
+      inputActivations(localInputAcitvations),
+      localDelta(arma::zeros<DataType>(layerRows)),
+      delta(localDelta)
+  {
+    // Nothing to do.
+  }
+  
+  /**
+   * Copy Constructor
+   */
+  NeuronLayer(const NeuronLayer& l) :
+      layerRows(l.layerRows), layerCols(l.layerCols),
+      localInputAcitvations(l.localInputAcitvations),
+      inputActivations(l.localInputAcitvations.n_elem == 0 ?
+                       l.inputActivations : localInputAcitvations),
+      localDelta(l.localDelta),
+      delta(l.localDelta.n_elem == 0 ? l.delta : localDelta)
   {
-    // Nothing to do here.
+    // Nothing to do.
   }
 
   /**
@@ -61,7 +111,7 @@ class NeuronLayer
    * activity function.
    * @param outputActivation Data to store the resulting output activation.
    */
-  void FeedForward(const VecType& inputActivation, VecType& outputActivation)
+  void FeedForward(const DataType& inputActivation, DataType& outputActivation)
   {
     ActivationFunction::fn(inputActivation, outputActivation);
   }
@@ -73,48 +123,65 @@ class NeuronLayer
    *
    * @param inputActivation Input data used for calculating the function f(x).
    * @param error The backpropagated error.
-   * @param delta The calculating delta using the partial derivative of the
-   * error with respect to a weight.
+   * @param delta The passed error in backward propagation.
    */
-  void FeedBackward(const VecType& inputActivation,
-                    const VecType& error,
-                    VecType& delta)
+  void FeedBackward(const DataType& inputActivation,
+                    const DataType& error,
+                    DataType& delta)
   {
-    VecType derivative;
+    DataType derivative;
     ActivationFunction::deriv(inputActivation, derivative);
-
     delta = error % derivative;
   }
 
   //! Get the input activations.
-  VecType& InputActivation() const { return inputActivations; }
-  //  //! Modify the input activations.
-  VecType& InputActivation() { return inputActivations; }
+  DataType& InputActivation() const { return inputActivations; }
+  //! Modify the input activations.
+  DataType& InputActivation() { return inputActivations; }
 
-  //! Get the detla.
-  VecType& Delta() const { return delta; }
- //  //! Modify the delta.
-  VecType& Delta() { return delta; }
+  //! Get the error passed in backward propagation.
+  DataType& Delta() const { return delta; }
+  //! Modify the error passed in backward propagation.
+  DataType& Delta() { return delta; }
 
-  //! Get input size.
-  size_t InputSize() const { return layerSize; }
-  //  //! Modify the delta.
-  size_t& InputSize() { return layerSize; }
+  //! Get the number of layer rows.
+  size_t LayerRows() const { return layerRows; }
 
-  //! Get output size.
-  size_t OutputSize() const { return layerSize; }
-  //! Modify the output size.
-  size_t& OutputSize() { return layerSize; }
+  //! Get the number of layer columns.
+  size_t LayerCols() const { return layerCols; }
+  
+  /**
+   * Get the layer size.
+   * Only for the 1-dimensional type.
+   */
+  size_t InputSize() const { return layerRows; }
+  
+  /**
+   * Get the layer size.
+   * Only for the 1-dimensional type.
+   */
+  size_t OutputSize() const { return layerRows; }
 
  private:
+  //! Locally-stored number of layer rows.
+  size_t layerRows;
+  
+  //! Locally-stored number of layer cols.
+  size_t layerCols;
+  
   //! Locally-stored input activation object.
-  VecType inputActivations;
+  DataType localInputAcitvations;
+  
+  //! Reference to locally-stored or outside input activation object.
+  DataType& inputActivations;
   
   //! Locally-stored delta object.
-  VecType delta;
+  DataType localDelta;
+  
+  //! Reference to locally-stored or outside delta object.
+  DataType& delta;
+  
 
-  //! Locally-stored number of neurons.
-  size_t layerSize;
 }; // class NeuronLayer
 
 // Convenience typedefs.
@@ -124,20 +191,18 @@ class NeuronLayer
  */
 template <
     class ActivationFunction = LogisticFunction,
-    typename MatType = arma::mat,
-    typename VecType = arma::colvec
+    typename DataType = arma::colvec
 >
-using InputLayer = NeuronLayer<ActivationFunction, MatType, VecType>;
+using InputLayer = NeuronLayer<ActivationFunction, DataType>;
 
 /**
  * Standard Hidden-Layer using the logistic activation function.
  */
 template <
     class ActivationFunction = LogisticFunction,
-    typename MatType = arma::mat,
-    typename VecType = arma::colvec
+    typename DataType = arma::colvec
 >
-using HiddenLayer = NeuronLayer<ActivationFunction, MatType, VecType>;
+using HiddenLayer = NeuronLayer<ActivationFunction, DataType>;
 
 /**
  * Layer of rectified linear units (relu) using the rectifier activation
@@ -145,10 +210,9 @@ using HiddenLayer = NeuronLayer<ActivationFunction, MatType, VecType>;
  */
 template <
     class ActivationFunction = RectifierFunction,
-    typename MatType = arma::mat,
-    typename VecType = arma::colvec
+    typename DataType = arma::colvec
 >
-using ReluLayer = NeuronLayer<ActivationFunction, MatType, VecType>;
+using ReluLayer = NeuronLayer<ActivationFunction, DataType>;
 
 
 }; // namespace ann
diff --git a/src/mlpack/methods/ann/layer/binary_classification_layer.hpp b/src/mlpack/methods/ann/layer/one_hot_layer.hpp
similarity index 58%
copy from src/mlpack/methods/ann/layer/binary_classification_layer.hpp
copy to src/mlpack/methods/ann/layer/one_hot_layer.hpp
index ecd6064..4113502 100644
--- a/src/mlpack/methods/ann/layer/binary_classification_layer.hpp
+++ b/src/mlpack/methods/ann/layer/one_hot_layer.hpp
@@ -1,12 +1,12 @@
 /**
- * @file binary_classification_layer.hpp
- * @author Marcus Edel
+ * @file one_hot_layer.hpp
+ * @author Shangtong Zhang
  *
- * Definition of the BinaryClassificationLayer class, which implements a
- * binary class classification layer that can be used as output layer.
+ * Definition of the OneHotLayer class, which implements a one-hot
+ * classification layer that can be used as output layer.
  */
-#ifndef __MLPACK_METHOS_ANN_LAYER_BINARY_CLASSIFICATION_LAYER_HPP
-#define __MLPACK_METHOS_ANN_LAYER_BINARY_CLASSIFICATION_LAYER_HPP
+#ifndef __MLPACK_METHOS_ANN_LAYER_ONE_HOT_LAYER_HPP
+#define __MLPACK_METHOS_ANN_LAYER_ONE_HOT_LAYER_HPP
 
 #include <mlpack/core.hpp>
 #include <mlpack/methods/ann/layer/layer_traits.hpp>
@@ -15,7 +15,7 @@ namespace mlpack {
 namespace ann /** Artificial Neural Network. */ {
 
 /**
- * An implementation of a binary classification layer that can be used as
+ * An implementation of a one hot classification layer that can be used as
  * output layer.
  *
  * @tparam MatType Type of data (arma::mat or arma::sp_mat).
@@ -25,13 +25,13 @@ template <
     typename MatType = arma::mat,
     typename VecType = arma::colvec
 >
-class BinaryClassificationLayer
+class OneHotLayer
 {
  public:
   /**
-   * Create the BinaryClassificationLayer object.
+   * Create the OneHotLayer object.
    */
-  BinaryClassificationLayer()
+  OneHotLayer()
   {
     // Nothing to do here.
   }
@@ -45,7 +45,7 @@ class BinaryClassificationLayer
    * @param error The calculated error with respect to the input activation and
    * the given target.
    */
-  void CalculateError(const VecType& inputActivations,
+  void calculateError(const VecType& inputActivations,
                       const VecType& target,
                       VecType& error)
   {
@@ -58,25 +58,30 @@ class BinaryClassificationLayer
    * @param inputActivations Input data used to calculate the output class.
    * @param output Output class of the input activation.
    */
-  void OutputClass(const VecType& inputActivations, VecType& output)
+  void outputClass(const VecType& inputActivations, VecType& output)
   {
-    output = inputActivations;
-    output.transform( [](double value) { return (value > 0.5 ? 1 : 0); } );
+    output = arma::zeros<VecType>(inputActivations.n_elem);
+    arma::uword maxIndex;
+    inputActivations.max(maxIndex);
+    output(maxIndex) = 1;
   }
-}; // class BinaryClassificationLayer
+}; // class OneHotLayer
 
-//! Layer traits for the binary class classification layer.
-template <typename MatType, typename VecType>
-class LayerTraits<BinaryClassificationLayer<MatType, VecType> >
+//! Layer traits for the one-hot class classification layer.
+template <
+    typename MatType,
+    typename VecType
+>
+class LayerTraits<OneHotLayer<MatType, VecType> >
 {
  public:
   static const bool IsBinary = true;
   static const bool IsOutputLayer = true;
   static const bool IsBiasLayer = false;
-  static const bool IsLSTMLayer = false;
 };
 
 }; // namespace ann
 }; // namespace mlpack
 
+
 #endif



More information about the mlpack-git mailing list