[mlpack-git] master: Adjust connections; use new optimzer. (3533b01)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Jun 16 14:50:48 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/9264f7544f7c4d93ff735f00f35b0f5287abf59d...7df836c2f5a2287cda82801ca20f4b4b410cf4e1

>---------------------------------------------------------------

commit 3533b01f1e701822ccb03c3a269ac81fe525ea1c
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Tue Jun 16 14:32:26 2015 +0200

    Adjust connections; use new optimzer.


>---------------------------------------------------------------

3533b01f1e701822ccb03c3a269ac81fe525ea1c
 .../methods/ann/connections/bias_connection.hpp    | 60 ++++++++++++---
 .../methods/ann/connections/conv_connection.hpp    | 69 +++++++++++++++--
 .../methods/ann/connections/full_connection.hpp    | 71 +++++++++++++++---
 .../ann/connections/fullself_connection.hpp        | 86 +++++++++++++++++++---
 .../methods/ann/connections/pooling_connection.hpp | 32 ++++++--
 .../methods/ann/connections/self_connection.hpp    | 82 ++++++++++++++++++---
 6 files changed, 343 insertions(+), 57 deletions(-)

diff --git a/src/mlpack/methods/ann/connections/bias_connection.hpp b/src/mlpack/methods/ann/connections/bias_connection.hpp
index eb0b3df..af33e9c 100644
--- a/src/mlpack/methods/ann/connections/bias_connection.hpp
+++ b/src/mlpack/methods/ann/connections/bias_connection.hpp
@@ -12,7 +12,7 @@
 #include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
 #include <mlpack/methods/ann/layer/layer_traits.hpp>
 #include <mlpack/methods/ann/connections/connection_traits.hpp>
-#include <mlpack/methods/ann/optimizer/steepest_descent.hpp>
+#include <mlpack/methods/ann/optimizer/rmsprop.hpp>
 
 namespace mlpack {
 namespace ann /** Artificial Neural Network. */ {
@@ -31,7 +31,7 @@ namespace ann /** Artificial Neural Network. */ {
 template<
     typename InputLayerType,
     typename OutputLayerType,
-    typename OptimizerType = SteepestDescent<>,
+    template<typename, typename> class OptimizerType = mlpack::ann::RMSPROP,
     class WeightInitRule = NguyenWidrowInitialization,
     typename MatType = arma::mat
 >
@@ -52,7 +52,11 @@ class BiasConnection
    */
   BiasConnection(InputLayerType& inputLayer,
                  OutputLayerType& outputLayer,
-                 OptimizerType& optimizer,
+                 OptimizerType<BiasConnection<InputLayerType,
+                                               OutputLayerType,
+                                               OptimizerType,
+                                               WeightInitRule,
+                                               MatType>, MatType>& optimizer,
                  WeightInitRule weightInitRule = WeightInitRule()) :
       inputLayer(inputLayer),
       outputLayer(outputLayer),
@@ -74,19 +78,30 @@ class BiasConnection
    * weight matrix.
    */
   BiasConnection(InputLayerType& inputLayer,
-               OutputLayerType& outputLayer,
-               WeightInitRule weightInitRule = WeightInitRule()) :
+                 OutputLayerType& outputLayer,
+                 WeightInitRule weightInitRule = WeightInitRule()) :
     inputLayer(inputLayer),
     outputLayer(outputLayer),
-    optimizer(new OptimizerType(outputLayer.InputSize(), inputLayer.LayerRows()
-        * inputLayer.LayerCols() * inputLayer.LayerSlices() *
-        inputLayer.OutputMaps() / outputLayer.LayerCols())),
+    optimizer(new OptimizerType<BiasConnection<InputLayerType,
+                                               OutputLayerType,
+                                               OptimizerType,
+                                               WeightInitRule,
+                                               MatType>, MatType>(*this)),
     ownsOptimizer(true)
   {
     weightInitRule.Initialize(weights, outputLayer.OutputMaps(), 1);
   }
 
   /**
+   * Delete the bias connection object and its optimizer.
+   */
+  ~BiasConnection()
+  {
+    if (ownsOptimizer)
+      delete optimizer;
+  }
+
+  /**
    * Ordinary feed forward pass of a neural network, evaluating the function
    * f(x) by propagating the activity forward through f using a dense matrix as
    * input.
@@ -131,7 +146,9 @@ class BiasConnection
   template<typename eT>
   void Gradient(arma::Mat<eT>& gradient)
   {
-    gradient = outputLayer.Delta() * inputLayer.InputActivation().t();
+    arma::Cube<eT> grad;
+    Gradient(grad);
+    gradient = grad.slice(0);
   }
 
   /*
@@ -162,9 +179,24 @@ class BiasConnection
   OutputLayerType& OutputLayer() { return outputLayer; }
 
   //! Get the optimzer.
-  OptimizerType& Optimzer() const { return *optimizer; }
+  OptimizerType<BiasConnection<InputLayerType,
+                               OutputLayerType,
+                               OptimizerType,
+                               WeightInitRule,
+                               MatType>, MatType>& Optimzer() const
+  {
+    return *optimizer;
+  }
+
   //! Modify the optimzer.
-  OptimizerType& Optimzer() { return *optimizer; }
+  OptimizerType<BiasConnection<InputLayerType,
+                               OutputLayerType,
+                               OptimizerType,
+                               WeightInitRule,
+                               MatType>, MatType>& Optimzer()
+  {
+    return *optimizer;
+  }
 
   //! Get the detla.
   MatType& Delta() const { return delta; }
@@ -239,7 +271,11 @@ class BiasConnection
   OutputLayerType& outputLayer;
 
   //! Locally-stored pointer to the optimzer object.
-  OptimizerType* optimizer;
+  OptimizerType<BiasConnection<InputLayerType,
+                               OutputLayerType,
+                               OptimizerType,
+                               WeightInitRule,
+                               MatType>, MatType>* optimizer;
 
   //! Parameter that indicates if the class owns a optimizer object.
   bool ownsOptimizer;
diff --git a/src/mlpack/methods/ann/connections/conv_connection.hpp b/src/mlpack/methods/ann/connections/conv_connection.hpp
index 0bd6be2..748e219 100644
--- a/src/mlpack/methods/ann/connections/conv_connection.hpp
+++ b/src/mlpack/methods/ann/connections/conv_connection.hpp
@@ -44,7 +44,7 @@ namespace ann  /** Artificial Neural Network. */{
 template<
     typename InputLayerType,
     typename OutputLayerType,
-    typename OptimizerType = SteepestDescent<>,
+    template<typename, typename> class OptimizerType = mlpack::ann::RMSPROP,
     class WeightInitRule = RandomInitialization,
     typename ForwardConvolutionRule = NaiveConvolution<ValidConvolution>,
     typename BackwardConvolutionRule = FFTConvolution<FullConvolution>,
@@ -70,7 +70,14 @@ class ConvConnection
   ConvConnection(InputLayerType& inputLayer,
                  OutputLayerType& outputLayer,
                  const size_t filterSize,
-                 OptimizerType& optimizer,
+                 OptimizerType<ConvConnection<InputLayerType,
+                                              OutputLayerType,
+                                              OptimizerType,
+                                              WeightInitRule,
+                                              ForwardConvolutionRule,
+                                              BackwardConvolutionRule,
+                                              GradientConvolutionRule,
+                                              DataType>, DataType>& optimizer,
                  WeightInitRule weightInitRule = WeightInitRule()) :
       inputLayer(inputLayer),
       outputLayer(outputLayer),
@@ -99,7 +106,14 @@ class ConvConnection
                  WeightInitRule weightInitRule = WeightInitRule()) :
       inputLayer(inputLayer),
       outputLayer(outputLayer),
-      optimizer(new OptimizerType(filterSize, filterSize)),
+      optimizer(new OptimizerType<ConvConnection<InputLayerType,
+                                                 OutputLayerType,
+                                                 OptimizerType,
+                                                 WeightInitRule,
+                                                 ForwardConvolutionRule,
+                                                 BackwardConvolutionRule,
+                                                 GradientConvolutionRule,
+                                                 DataType>, DataType>(*this)),
       ownsOptimizer(true)
   {
     weightInitRule.Initialize(weights, filterSize, filterSize,
@@ -213,6 +227,22 @@ class ConvConnection
         gradient.slice(s) /= inputLayer.LayerSlices();
       }
     }
+
+    if (InputLayer().OutputMaps() != 1)
+    {
+      arma::Cube<eT> temp = arma::zeros<arma::Cube<eT> >(weights.n_rows, weights.n_cols,
+        weights.n_slices);
+
+      for (size_t i = 0, g = 0; i < OutputLayer().OutputMaps(); i++)
+      {
+        for (size_t j = i; j < weights.n_slices; j+= OutputLayer().OutputMaps(), g++)
+        {
+          temp.slice(j) = gradient.slice(g);
+        }
+      }
+
+      gradient = temp;
+    }
   }
 
   //! Get the convolution kernel.
@@ -231,9 +261,29 @@ class ConvConnection
   OutputLayerType& OutputLayer() { return outputLayer; }
 
   //! Get the optimzer.
-  OptimizerType& Optimzer() const { return *optimizer; }
+  OptimizerType<ConvConnection<InputLayerType,
+                               OutputLayerType,
+                               OptimizerType,
+                               WeightInitRule,
+                               ForwardConvolutionRule,
+                               BackwardConvolutionRule,
+                               GradientConvolutionRule,
+                               DataType>, DataType>& Optimzer() const
+  {
+    return *optimizer;
+  }
   //! Modify the optimzer.
-  OptimizerType& Optimzer() { return *optimizer; }
+  OptimizerType<ConvConnection<InputLayerType,
+                               OutputLayerType,
+                               OptimizerType,
+                               WeightInitRule,
+                               ForwardConvolutionRule,
+                               BackwardConvolutionRule,
+                               GradientConvolutionRule,
+                               DataType>, DataType>& Optimzer()
+  {
+    return *optimizer;
+  }
 
   //! Get the passed error in backward propagation.
   DataType& Delta() const { return delta; }
@@ -280,7 +330,14 @@ class ConvConnection
   OutputLayerType& outputLayer;
 
   //! Locally-stored optimizer.
-  OptimizerType* optimizer;
+  OptimizerType<ConvConnection<InputLayerType,
+                               OutputLayerType,
+                               OptimizerType,
+                               WeightInitRule,
+                               ForwardConvolutionRule,
+                               BackwardConvolutionRule,
+                               GradientConvolutionRule,
+                               DataType>, DataType>* optimizer;
 
   //! Parameter that indicates if the class owns a optimizer object.
   bool ownsOptimizer;
diff --git a/src/mlpack/methods/ann/connections/full_connection.hpp b/src/mlpack/methods/ann/connections/full_connection.hpp
index a72c8ee..d489e3a 100644
--- a/src/mlpack/methods/ann/connections/full_connection.hpp
+++ b/src/mlpack/methods/ann/connections/full_connection.hpp
@@ -9,7 +9,7 @@
 
 #include <mlpack/core.hpp>
 #include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
-#include <mlpack/methods/ann/optimizer/steepest_descent.hpp>
+#include <mlpack/methods/ann/optimizer/rmsprop.hpp>
 
 namespace mlpack {
 namespace ann /** Artificial Neural Network. */ {
@@ -28,7 +28,7 @@ namespace ann /** Artificial Neural Network. */ {
 template<
     typename InputLayerType,
     typename OutputLayerType,
-    typename OptimizerType = SteepestDescent<>,
+    template<typename, typename> class OptimizerType = mlpack::ann::RMSPROP,
     class WeightInitRule = NguyenWidrowInitialization,
     typename MatType = arma::mat
 >
@@ -49,7 +49,11 @@ class FullConnection
    */
   FullConnection(InputLayerType& inputLayer,
                  OutputLayerType& outputLayer,
-                 OptimizerType& optimizer,
+                 OptimizerType<FullConnection<InputLayerType,
+                                              OutputLayerType,
+                                              OptimizerType,
+                                              WeightInitRule,
+                                              MatType>, MatType>& optimizer,
                  WeightInitRule weightInitRule = WeightInitRule()) :
       inputLayer(inputLayer),
       outputLayer(outputLayer),
@@ -78,9 +82,11 @@ class FullConnection
                WeightInitRule weightInitRule = WeightInitRule()) :
     inputLayer(inputLayer),
     outputLayer(outputLayer),
-    optimizer(new OptimizerType(outputLayer.InputSize(), inputLayer.LayerRows()
-        * inputLayer.LayerCols() * inputLayer.LayerSlices() *
-        inputLayer.OutputMaps() / outputLayer.LayerCols())),
+    optimizer(new OptimizerType<FullConnection<InputLayerType,
+                                              OutputLayerType,
+                                              OptimizerType,
+                                              WeightInitRule,
+                                              MatType>, MatType>(*this)),
     ownsOptimizer(true)
   {
     weightInitRule.Initialize(weights, outputLayer.InputSize(),
@@ -158,7 +164,7 @@ class FullConnection
   template<typename eT>
   void Gradient(arma::Mat<eT>& gradient)
   {
-    gradient = outputLayer.Delta() * inputLayer.InputActivation().t();
+    GradientDelta(inputLayer.InputActivation(), gradient);
   }
 
   /*
@@ -189,9 +195,23 @@ class FullConnection
   OutputLayerType& OutputLayer() { return outputLayer; }
 
   //! Get the optimzer.
-  OptimizerType& Optimzer() const { return *optimizer; }
+  OptimizerType<FullConnection<InputLayerType,
+                               OutputLayerType,
+                               OptimizerType,
+                               WeightInitRule,
+                               MatType>, MatType>& Optimzer() const
+  {
+    return *optimizer;
+  }
   //! Modify the optimzer.
-  OptimizerType& Optimzer() { return *optimizer; }
+  OptimizerType<FullConnection<InputLayerType,
+                               OutputLayerType,
+                               OptimizerType,
+                               WeightInitRule,
+                               MatType>, MatType>& Optimzer()
+  {
+    return *optimizer;
+  }
 
   //! Get the detla.
   MatType& Delta() const { return delta; }
@@ -241,6 +261,33 @@ class FullConnection
     Gradient(gradient.slice(0));
   }
 
+  /*
+   * Calculate the gradient (dense matrix) using the output delta
+   * (dense matrix) and the input activation (3rd order tensor).
+   *
+   * @param gradient The calculated gradient.
+   */
+  template<typename eT>
+  void GradientDelta(arma::Cube<eT>& /* unused */, arma::Mat<eT>& gradient)
+  {
+    arma::Cube<eT> grad = arma::Cube<eT>(weights.n_rows, weights.n_cols, 1);
+    Gradient(grad);
+    gradient = grad.slice(0);
+
+  }
+
+  /*
+   * Calculate the gradient (dense matrix) using the output delta
+   * (dense matrix) and the input activation (dense matrix).
+   *
+   * @param gradient The calculated gradient.
+   */
+  template<typename eT>
+  void GradientDelta(arma::Mat<eT>& /* unused */, arma::Mat<eT>& gradient)
+  {
+    gradient = outputLayer.Delta() * inputLayer.InputActivation().t();
+  }
+
   //! Locally-stored weight object.
   MatType weights;
 
@@ -251,7 +298,11 @@ class FullConnection
   OutputLayerType& outputLayer;
 
   //! Locally-stored pointer to the optimzer object.
-  OptimizerType* optimizer;
+  OptimizerType<FullConnection<InputLayerType,
+                               OutputLayerType,
+                               OptimizerType,
+                               WeightInitRule,
+                               MatType>, MatType>* optimizer;
 
   //! Parameter that indicates if the class owns a optimizer object.
   bool ownsOptimizer;
diff --git a/src/mlpack/methods/ann/connections/fullself_connection.hpp b/src/mlpack/methods/ann/connections/fullself_connection.hpp
index 03070ec..fc8da6e 100644
--- a/src/mlpack/methods/ann/connections/fullself_connection.hpp
+++ b/src/mlpack/methods/ann/connections/fullself_connection.hpp
@@ -10,6 +10,7 @@
 
 #include <mlpack/core.hpp>
 #include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
+#include <mlpack/methods/ann/optimizer/rmsprop.hpp>
 
 namespace mlpack {
 namespace ann /** Artificial Neural Network. */ {
@@ -30,7 +31,7 @@ namespace ann /** Artificial Neural Network. */ {
 template<
     typename InputLayerType,
     typename OutputLayerType,
-    typename OptimizerType,
+    template<typename, typename> class OptimizerType = mlpack::ann::RMSPROP,
     class WeightInitRule = NguyenWidrowInitialization,
     typename MatType = arma::mat,
     typename VecType = arma::colvec
@@ -39,8 +40,8 @@ class FullselfConnection
 {
  public:
   /**
-   * Create the FullConnection object using the specified input layer, output
-   * layer, optimizer and weight initialize rule.
+   * Create the FullselfConnection object using the specified input layer,
+   * output layer, optimizer and weight initialize rule.
    *
    * @param InputLayerType The input layer which is connected with the output
    * layer.
@@ -51,10 +52,47 @@ class FullselfConnection
    * weight matrix.
    */
   FullselfConnection(InputLayerType& inputLayer,
-                  OutputLayerType& outputLayer,
-                  OptimizerType& optimizer,
-                  WeightInitRule weightInitRule = WeightInitRule()) :
-      inputLayer(inputLayer), outputLayer(outputLayer), optimizer(optimizer)
+                     OutputLayerType& outputLayer,
+                     OptimizerType<FullselfConnection<InputLayerType,
+                                                      OutputLayerType,
+                                                      OptimizerType,
+                                                      WeightInitRule,
+                                                      MatType,
+                                                      VecType>, MatType>& optimizer,
+                     WeightInitRule weightInitRule = WeightInitRule()) :
+      inputLayer(inputLayer),
+      outputLayer(outputLayer),
+      optimizer(&optimizer),
+      ownsOptimizer(false)
+  {
+    weightInitRule.Initialize(weights, outputLayer.InputSize(),
+        inputLayer.OutputSize());
+  }
+
+  /**
+   * Create the FullselfConnection object using the specified input layer,
+   * output layer and weight initialize rule.
+   *
+   * @param InputLayerType The input layer which is connected with the output
+   * layer.
+   * @param OutputLayerType The output layer which is connected with the input
+   * layer.
+   * @param OptimizerType The optimizer used to update the weight matrix.
+   * @param WeightInitRule The weight initialize rule used to initialize the
+   * weight matrix.
+   */
+  FullselfConnection(InputLayerType& inputLayer,
+                     OutputLayerType& outputLayer,
+                     WeightInitRule weightInitRule = WeightInitRule()) :
+      inputLayer(inputLayer),
+      outputLayer(outputLayer),
+      optimizer(new OptimizerType<FullselfConnection<InputLayerType,
+                                                     OutputLayerType,
+                                                     OptimizerType,
+                                                     WeightInitRule,
+                                                     MatType,
+                                                     VecType>, MatType>(*this)),
+      ownsOptimizer(true)
   {
     weightInitRule.Initialize(weights, outputLayer.InputSize(),
         inputLayer.OutputSize());
@@ -111,13 +149,29 @@ class FullselfConnection
   OutputLayerType& OutputLayer() { return outputLayer; }
 
   //! Get the optimzer.
-  OptimizerType& Optimzer() const { return optimizer; }
+  OptimizerType<FullselfConnection<InputLayerType,
+                                   OutputLayerType,
+                                   OptimizerType,
+                                   WeightInitRule,
+                                   MatType,
+                                   VecType>, MatType>& Optimzer() const
+  {
+    return *optimizer;
+  }
   //! Modify the optimzer.
-  OptimizerType& Optimzer() { return optimizer; }
+  OptimizerType<FullselfConnection<InputLayerType,
+                                   OutputLayerType,
+                                   OptimizerType,
+                                   WeightInitRule,
+                                   MatType,
+                                   VecType>, MatType>& Optimzer()
+  {
+    return *optimizer;
+  }
 
   //! Get the detla.
   VecType& Delta() const { return delta; }
- //  //! Modify the delta.
+  //! Modify the delta.
   VecType& Delta() { return delta; }
 
  private:
@@ -131,7 +185,15 @@ class FullselfConnection
   OutputLayerType& outputLayer;
 
   //! Locally-stored optimzer object.
-  OptimizerType& optimizer;
+  OptimizerType<FullselfConnection<InputLayerType,
+                               OutputLayerType,
+                               OptimizerType,
+                               WeightInitRule,
+                               MatType,
+                               VecType>, MatType>* optimizer;
+
+  //! Parameter that indicates if the class owns a optimizer object.
+  bool ownsOptimizer;
 
   //! Locally-stored detla object that holds the calculated delta.
   VecType delta;
@@ -142,7 +204,7 @@ class FullselfConnection
 template<
     typename InputLayerType,
     typename OutputLayerType,
-    typename OptimizerType,
+    template<typename, typename> class OptimizerType,
     class WeightInitRule,
     typename MatType,
     typename VecType
diff --git a/src/mlpack/methods/ann/connections/pooling_connection.hpp b/src/mlpack/methods/ann/connections/pooling_connection.hpp
index debfc2b..d12d603 100644
--- a/src/mlpack/methods/ann/connections/pooling_connection.hpp
+++ b/src/mlpack/methods/ann/connections/pooling_connection.hpp
@@ -12,7 +12,7 @@
 #include <mlpack/core.hpp>
 #include <mlpack/methods/ann/optimizer/steepest_descent.hpp>
 #include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
-#include <mlpack/methods/ann/pooling_rules/max_pooling.hpp>
+#include <mlpack/methods/ann/pooling_rules/mean_pooling.hpp>
 #include <mlpack/methods/ann/connections/connection_traits.hpp>
 
 namespace mlpack{
@@ -32,8 +32,8 @@ namespace ann /** Artificial Neural Network. */ {
 template<
     typename InputLayerType,
     typename OutputLayerType,
-    typename PoolingRule = MaxPooling,
-    typename OptimizerType = SteepestDescent<>,
+    typename PoolingRule = MeanPooling,
+    template<typename, typename> class OptimizerType = mlpack::ann::RMSPROP,
     typename DataType = arma::cube
 >
 class PoolingConnection
@@ -145,9 +145,23 @@ class PoolingConnection
   OutputLayerType& OutputLayer() { return outputLayer; }
 
   //! Get the optimizer.
-  OptimizerType& Optimzer() const { return *optimizer; }
+  OptimizerType<PoolingConnection<InputLayerType,
+                                  OutputLayerType,
+                                  PoolingRule,
+                                  OptimizerType,
+                                  DataType>, DataType>& Optimzer() const
+  {
+    return *optimizer;
+  }
   //! Modify the optimzer.
-  OptimizerType& Optimzer() { return *optimizer; }
+  OptimizerType<PoolingConnection<InputLayerType,
+                                  OutputLayerType,
+                                  PoolingRule,
+                                  OptimizerType,
+                                  DataType>, DataType>& Optimzer()
+  {
+    return *optimizer;
+  }
 
   //! Get the passed error in backward propagation.
   DataType& Delta() const { return delta; }
@@ -220,7 +234,11 @@ class PoolingConnection
   OutputLayerType& outputLayer;
 
   //! Locally-stored optimizer.
-  OptimizerType* optimizer;
+  OptimizerType<PoolingConnection<InputLayerType,
+                                  OutputLayerType,
+                                  PoolingRule,
+                                  OptimizerType,
+                                  DataType>, DataType>* optimizer;
 
   //! Locally-stored weight object.
   DataType* weights;
@@ -237,7 +255,7 @@ template<
     typename InputLayerType,
     typename OutputLayerType,
     typename PoolingRule,
-    typename OptimizerType,
+    template<typename, typename> class OptimizerType,
     typename DataType
 >
 class ConnectionTraits<
diff --git a/src/mlpack/methods/ann/connections/self_connection.hpp b/src/mlpack/methods/ann/connections/self_connection.hpp
index 0bbcbaf..0300eb0 100644
--- a/src/mlpack/methods/ann/connections/self_connection.hpp
+++ b/src/mlpack/methods/ann/connections/self_connection.hpp
@@ -11,6 +11,7 @@
 #include <mlpack/core.hpp>
 #include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
 #include <mlpack/methods/ann/connections/connection_traits.hpp>
+#include <mlpack/methods/ann/optimizer/rmsprop.hpp>
 
 namespace mlpack {
 namespace ann /** Artificial Neural Network. */ {
@@ -30,7 +31,7 @@ namespace ann /** Artificial Neural Network. */ {
 template<
     typename InputLayerType,
     typename OutputLayerType,
-    typename OptimizerType,
+    template<typename, typename> class OptimizerType = mlpack::ann::RMSPROP,
     class WeightInitRule = NguyenWidrowInitialization,
     typename MatType = arma::mat,
     typename VecType = arma::colvec
@@ -52,11 +53,48 @@ class SelfConnection
    */
   SelfConnection(InputLayerType& inputLayer,
                  OutputLayerType& outputLayer,
-                 OptimizerType& optimizer,
+                 OptimizerType<SelfConnection<InputLayerType,
+                                              OutputLayerType,
+                                              OptimizerType,
+                                              WeightInitRule,
+                                              MatType,
+                                              VecType>, MatType>& optimizer,
                  WeightInitRule weightInitRule = WeightInitRule()) :
       inputLayer(inputLayer),
       outputLayer(outputLayer),
-      optimizer(optimizer),
+      optimizer(&optimizer),
+      ownsOptimizer(false),
+      connection(1 - arma::eye<MatType>(inputLayer.OutputSize(),
+          inputLayer.OutputSize()))
+  {
+    weightInitRule.Initialize(weights, outputLayer.InputSize(),
+        inputLayer.OutputSize());
+  }
+
+  /**
+   * Create the SelfConnection object using the specified input layer, output
+   * layer and weight initialize rule.
+   *
+   * @param InputLayerType The input layer which is connected with the output
+   * layer.
+   * @param OutputLayerType The output layer which is connected with the input
+   * layer.
+   * @param OptimizerType The optimizer used to update the weight matrix.
+   * @param WeightInitRule The weight initialize rule used to initialize the
+   * weight matrix.
+   */
+  SelfConnection(InputLayerType& inputLayer,
+                 OutputLayerType& outputLayer,
+                 WeightInitRule weightInitRule = WeightInitRule()) :
+      inputLayer(inputLayer),
+      outputLayer(outputLayer),
+      optimizer(new OptimizerType<SelfConnection<InputLayerType,
+                                                 OutputLayerType,
+                                                 OptimizerType,
+                                                 WeightInitRule,
+                                                 MatType,
+                                                 VecType>, MatType>(*this)),
+      ownsOptimizer(true),
       connection(1 - arma::eye<MatType>(inputLayer.OutputSize(),
           inputLayer.OutputSize()))
   {
@@ -113,14 +151,30 @@ class SelfConnection
   //! Modify the output layer.
   OutputLayerType& OutputLayer() { return outputLayer; }
 
-  //! Get the optimzer.
-  const OptimizerType& Optimzer() const { return optimizer; }
+  //! Get the optimzer.
+  OptimizerType<SelfConnection<InputLayerType,
+                               OutputLayerType,
+                               OptimizerType,
+                               WeightInitRule,
+                               MatType,
+                               VecType>, MatType>& Optimzer() const
+  {
+    return *optimizer;
+  }
   //! Modify the optimzer.
-  OptimizerType& Optimzer() { return optimizer; }
+  OptimizerType<SelfConnection<InputLayerType,
+                               OutputLayerType,
+                               OptimizerType,
+                               WeightInitRule,
+                               MatType,
+                               VecType>, MatType>& Optimzer()
+  {
+    return *optimizer;
+  }
 
   //! Get the detla.
   const VecType& Delta() const { return delta; }
- //  //! Modify the delta.
+  //! Modify the delta.
   VecType& Delta() { return delta; }
 
  private:
@@ -133,8 +187,16 @@ class SelfConnection
   //! Locally-stored connected output layer object.
   OutputLayerType& outputLayer;
 
-  //! Locally-stored optimzer object.
-  OptimizerType& optimizer;
+  //! Locally-stored pointer to the optimzer object.
+  OptimizerType<SelfConnection<InputLayerType,
+                               OutputLayerType,
+                               OptimizerType,
+                               WeightInitRule,
+                               MatType,
+                               VecType>, MatType>* optimizer;
+
+  //! Parameter that indicates if the class owns a optimizer object.
+  bool ownsOptimizer;
 
   //! Locally-stored detla object that holds the calculated delta.
   VecType delta;
@@ -147,7 +209,7 @@ class SelfConnection
 template<
     typename InputLayerType,
     typename OutputLayerType,
-    typename OptimizerType,
+    template<typename, typename> class OptimizerType,
     class WeightInitRule,
     typename MatType,
     typename VecType



More information about the mlpack-git mailing list