[mlpack-git] master: Add function to get the current gradient. (c49adea)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sat Aug 1 16:51:05 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/dd188581a86e64a0e0dc7854e1c7075d6c8bfe90...c49adeab929c536f8ad7497567fd4603e9ff5905

>---------------------------------------------------------------

commit c49adeab929c536f8ad7497567fd4603e9ff5905
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Fri Jul 31 18:02:04 2015 +0200

    Add function to get the current gradient.


>---------------------------------------------------------------

c49adeab929c536f8ad7497567fd4603e9ff5905
 src/mlpack/methods/ann/optimizer/ada_delta.hpp | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/src/mlpack/methods/ann/optimizer/ada_delta.hpp b/src/mlpack/methods/ann/optimizer/ada_delta.hpp
index 0d021fc..791253d 100644
--- a/src/mlpack/methods/ann/optimizer/ada_delta.hpp
+++ b/src/mlpack/methods/ann/optimizer/ada_delta.hpp
@@ -54,7 +54,8 @@ class AdaDelta
           const double eps = 1e-6) :
       function(function),
       rho(rho),
-      eps(eps)
+      eps(eps),
+      iteration(0)
   {
     // Nothing to do here.
   }
@@ -72,6 +73,9 @@ class AdaDelta
       meanSquaredGradientDx = meanSquaredGradient;
     }
 
+    if (iteration > 1)
+      gradient /= iteration;
+
     Optimize(function.Weights(), gradient, meanSquaredGradient,
         meanSquaredGradientDx);
   }
@@ -81,6 +85,8 @@ class AdaDelta
    */
   void Update()
   {
+    iteration++;
+
     if (gradient.n_elem != 0)
     {
       DataType outputGradient;
@@ -98,9 +104,15 @@ class AdaDelta
    */
   void Reset()
   {
+    iteration = 0;
     gradient.zeros();
   }
 
+  //! Get the gradient.
+  DataType& Gradient() const { return gradient; }
+  //! Modify the gradient.
+  DataType& Gradient() { return gradient; }
+
  private:
   /**
    * Optimize the given function using AdaDelta.
@@ -168,6 +180,9 @@ class AdaDelta
 
   //! The current mean squared gradient Dx
   DataType meanSquaredGradientDx;
+
+  //! The locally stored number of iterations.
+  size_t iteration;
 }; // class AdaDelta
 
 }; // namespace ann



More information about the mlpack-git mailing list