[mlpack-git] master: Use the number of rows and cols to initialize the optimizer. (b5fbcaa)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sat Jun 6 11:16:40 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/7fb32130bd683cf03a853ea2bc6960e80d625955...b5fbcaa319689553f44f2d33e5303c2a28e031e1
>---------------------------------------------------------------
commit b5fbcaa319689553f44f2d33e5303c2a28e031e1
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Sat Jun 6 17:16:28 2015 +0200
Use the number of rows and cols to initialize the optimizer.
>---------------------------------------------------------------
b5fbcaa319689553f44f2d33e5303c2a28e031e1
src/mlpack/methods/ann/connections/conv_connection.hpp | 2 +-
src/mlpack/methods/ann/connections/full_connection.hpp | 4 +++-
src/mlpack/methods/ann/connections/pooling_connection.hpp | 10 +++++-----
3 files changed, 9 insertions(+), 7 deletions(-)
diff --git a/src/mlpack/methods/ann/connections/conv_connection.hpp b/src/mlpack/methods/ann/connections/conv_connection.hpp
index b7cbd2c..0bd6be2 100644
--- a/src/mlpack/methods/ann/connections/conv_connection.hpp
+++ b/src/mlpack/methods/ann/connections/conv_connection.hpp
@@ -99,7 +99,7 @@ class ConvConnection
WeightInitRule weightInitRule = WeightInitRule()) :
inputLayer(inputLayer),
outputLayer(outputLayer),
- optimizer(new OptimizerType()),
+ optimizer(new OptimizerType(filterSize, filterSize)),
ownsOptimizer(true)
{
weightInitRule.Initialize(weights, filterSize, filterSize,
diff --git a/src/mlpack/methods/ann/connections/full_connection.hpp b/src/mlpack/methods/ann/connections/full_connection.hpp
index 0937b71..9242b76 100644
--- a/src/mlpack/methods/ann/connections/full_connection.hpp
+++ b/src/mlpack/methods/ann/connections/full_connection.hpp
@@ -78,7 +78,9 @@ class FullConnection
WeightInitRule weightInitRule = WeightInitRule()) :
inputLayer(inputLayer),
outputLayer(outputLayer),
- optimizer(new OptimizerType()),
+ optimizer(new OptimizerType(outputLayer.InputSize(), inputLayer.LayerRows()
+ * inputLayer.LayerCols() * inputLayer.LayerSlices() *
+ inputLayer.OutputMaps() / outputLayer.LayerCols())),
ownsOptimizer(true)
{
weightInitRule.Initialize(weights, outputLayer.InputSize(),
diff --git a/src/mlpack/methods/ann/connections/pooling_connection.hpp b/src/mlpack/methods/ann/connections/pooling_connection.hpp
index 8f86c90..debfc2b 100644
--- a/src/mlpack/methods/ann/connections/pooling_connection.hpp
+++ b/src/mlpack/methods/ann/connections/pooling_connection.hpp
@@ -33,7 +33,7 @@ template<
typename InputLayerType,
typename OutputLayerType,
typename PoolingRule = MaxPooling,
- template<typename> class OptimizerType = SteepestDescent,
+ typename OptimizerType = SteepestDescent<>,
typename DataType = arma::cube
>
class PoolingConnection
@@ -145,9 +145,9 @@ class PoolingConnection
OutputLayerType& OutputLayer() { return outputLayer; }
//! Get the optimizer.
- OptimizerType<DataType>& Optimzer() const { return *optimizer; }
+ OptimizerType& Optimzer() const { return *optimizer; }
//! Modify the optimzer.
- OptimizerType<DataType>& Optimzer() { return *optimizer; }
+ OptimizerType& Optimzer() { return *optimizer; }
//! Get the passed error in backward propagation.
DataType& Delta() const { return delta; }
@@ -220,7 +220,7 @@ class PoolingConnection
OutputLayerType& outputLayer;
//! Locally-stored optimizer.
- OptimizerType<DataType>* optimizer;
+ OptimizerType* optimizer;
//! Locally-stored weight object.
DataType* weights;
@@ -237,7 +237,7 @@ template<
typename InputLayerType,
typename OutputLayerType,
typename PoolingRule,
- template<typename> class OptimizerType,
+ typename OptimizerType,
typename DataType
>
class ConnectionTraits<
More information about the mlpack-git mailing list