[mlpack-git] master: Use the number of rows and cols to initialize the optimizer. (b5fbcaa)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sat Jun 6 11:16:40 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/7fb32130bd683cf03a853ea2bc6960e80d625955...b5fbcaa319689553f44f2d33e5303c2a28e031e1

>---------------------------------------------------------------

commit b5fbcaa319689553f44f2d33e5303c2a28e031e1
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Sat Jun 6 17:16:28 2015 +0200

    Use the number of rows and cols to initialize the optimizer.


>---------------------------------------------------------------

b5fbcaa319689553f44f2d33e5303c2a28e031e1
 src/mlpack/methods/ann/connections/conv_connection.hpp    |  2 +-
 src/mlpack/methods/ann/connections/full_connection.hpp    |  4 +++-
 src/mlpack/methods/ann/connections/pooling_connection.hpp | 10 +++++-----
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/mlpack/methods/ann/connections/conv_connection.hpp b/src/mlpack/methods/ann/connections/conv_connection.hpp
index b7cbd2c..0bd6be2 100644
--- a/src/mlpack/methods/ann/connections/conv_connection.hpp
+++ b/src/mlpack/methods/ann/connections/conv_connection.hpp
@@ -99,7 +99,7 @@ class ConvConnection
                  WeightInitRule weightInitRule = WeightInitRule()) :
       inputLayer(inputLayer),
       outputLayer(outputLayer),
-      optimizer(new OptimizerType()),
+      optimizer(new OptimizerType(filterSize, filterSize)),
       ownsOptimizer(true)
   {
     weightInitRule.Initialize(weights, filterSize, filterSize,
diff --git a/src/mlpack/methods/ann/connections/full_connection.hpp b/src/mlpack/methods/ann/connections/full_connection.hpp
index 0937b71..9242b76 100644
--- a/src/mlpack/methods/ann/connections/full_connection.hpp
+++ b/src/mlpack/methods/ann/connections/full_connection.hpp
@@ -78,7 +78,9 @@ class FullConnection
                WeightInitRule weightInitRule = WeightInitRule()) :
     inputLayer(inputLayer),
     outputLayer(outputLayer),
-    optimizer(new OptimizerType()),
+    optimizer(new OptimizerType(outputLayer.InputSize(), inputLayer.LayerRows()
+        * inputLayer.LayerCols() * inputLayer.LayerSlices() *
+        inputLayer.OutputMaps() / outputLayer.LayerCols())),
     ownsOptimizer(true)
   {
     weightInitRule.Initialize(weights, outputLayer.InputSize(),
diff --git a/src/mlpack/methods/ann/connections/pooling_connection.hpp b/src/mlpack/methods/ann/connections/pooling_connection.hpp
index 8f86c90..debfc2b 100644
--- a/src/mlpack/methods/ann/connections/pooling_connection.hpp
+++ b/src/mlpack/methods/ann/connections/pooling_connection.hpp
@@ -33,7 +33,7 @@ template<
     typename InputLayerType,
     typename OutputLayerType,
     typename PoolingRule = MaxPooling,
-    template<typename> class OptimizerType = SteepestDescent,
+    typename OptimizerType = SteepestDescent<>,
     typename DataType = arma::cube
 >
 class PoolingConnection
@@ -145,9 +145,9 @@ class PoolingConnection
   OutputLayerType& OutputLayer() { return outputLayer; }
 
   //! Get the optimizer.
-  OptimizerType<DataType>& Optimzer() const { return *optimizer; }
+  OptimizerType& Optimzer() const { return *optimizer; }
   //! Modify the optimzer.
-  OptimizerType<DataType>& Optimzer() { return *optimizer; }
+  OptimizerType& Optimzer() { return *optimizer; }
 
   //! Get the passed error in backward propagation.
   DataType& Delta() const { return delta; }
@@ -220,7 +220,7 @@ class PoolingConnection
   OutputLayerType& outputLayer;
 
   //! Locally-stored optimizer.
-  OptimizerType<DataType>* optimizer;
+  OptimizerType* optimizer;
 
   //! Locally-stored weight object.
   DataType* weights;
@@ -237,7 +237,7 @@ template<
     typename InputLayerType,
     typename OutputLayerType,
     typename PoolingRule,
-    template<typename> class OptimizerType,
+    typename OptimizerType,
     typename DataType
 >
 class ConnectionTraits<



More information about the mlpack-git mailing list