[mlpack-git] master: Refactor optimizer for new network API. (98b6773)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Sep 3 08:35:37 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/91ae1062772a0f2bbca9a072769629c2d775ae64...42d61dfdbc9b0cbce59398e67ea58544b0fa1919

>---------------------------------------------------------------

commit 98b67739902b2b8a6b547b8c598389ac452ad294
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Thu Sep 3 14:22:19 2015 +0200

    Refactor optimizer for new network API.


>---------------------------------------------------------------

98b67739902b2b8a6b547b8c598389ac452ad294
 src/mlpack/methods/ann/optimizer/ada_delta.hpp | 18 +++---------------
 1 file changed, 3 insertions(+), 15 deletions(-)

diff --git a/src/mlpack/methods/ann/optimizer/ada_delta.hpp b/src/mlpack/methods/ann/optimizer/ada_delta.hpp
index 9e8b8e5..0e74b6a 100644
--- a/src/mlpack/methods/ann/optimizer/ada_delta.hpp
+++ b/src/mlpack/methods/ann/optimizer/ada_delta.hpp
@@ -54,8 +54,7 @@ class AdaDelta
           const double eps = 1e-6) :
       function(function),
       rho(rho),
-      eps(eps),
-      iteration(0)
+      eps(eps)
   {
     // Nothing to do here.
   }
@@ -73,9 +72,6 @@ class AdaDelta
       meanSquaredGradientDx = meanSquaredGradient;
     }
 
-    if (iteration > 1)
-      gradient /= iteration;
-
     Optimize(function.Weights(), gradient, meanSquaredGradient,
         meanSquaredGradientDx);
   }
@@ -85,17 +81,13 @@ class AdaDelta
    */
   void Update()
   {
-    iteration++;
-
     if (gradient.n_elem != 0)
     {
-      DataType outputGradient;
-      function.Gradient(outputGradient);
-      gradient += outputGradient;
+      gradient += function.Gradient();
     }
     else
     {
-      function.Gradient(gradient);
+      gradient = function.Gradient();
     }
   }
 
@@ -104,7 +96,6 @@ class AdaDelta
    */
   void Reset()
   {
-    iteration = 0;
     gradient.zeros();
   }
 
@@ -180,9 +171,6 @@ class AdaDelta
 
   //! The current mean squared gradient Dx
   DataType meanSquaredGradientDx;
-
-  //! The locally stored number of iterations.
-  size_t iteration;
 }; // class AdaDelta
 
 }; // namespace ann



More information about the mlpack-git mailing list