[mlpack-git] master: Add function to get the current gradient. (c49adea)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sat Aug 1 16:51:05 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/dd188581a86e64a0e0dc7854e1c7075d6c8bfe90...c49adeab929c536f8ad7497567fd4603e9ff5905
>---------------------------------------------------------------
commit c49adeab929c536f8ad7497567fd4603e9ff5905
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Fri Jul 31 18:02:04 2015 +0200
Add functions to get and modify the current gradient, and divide the gradient by the number of Update() calls before optimizing.
>---------------------------------------------------------------
c49adeab929c536f8ad7497567fd4603e9ff5905
src/mlpack/methods/ann/optimizer/ada_delta.hpp | 17 ++++++++++++++++-
1 file changed, 16 insertions(+), 1 deletion(-)
diff --git a/src/mlpack/methods/ann/optimizer/ada_delta.hpp b/src/mlpack/methods/ann/optimizer/ada_delta.hpp
index 0d021fc..791253d 100644
--- a/src/mlpack/methods/ann/optimizer/ada_delta.hpp
+++ b/src/mlpack/methods/ann/optimizer/ada_delta.hpp
@@ -54,7 +54,8 @@ class AdaDelta
const double eps = 1e-6) :
function(function),
rho(rho),
- eps(eps)
+ eps(eps),
+ iteration(0)
{
// Nothing to do here.
}
@@ -72,6 +73,9 @@ class AdaDelta
meanSquaredGradientDx = meanSquaredGradient;
}
+ if (iteration > 1)
+ gradient /= iteration;
+
Optimize(function.Weights(), gradient, meanSquaredGradient,
meanSquaredGradientDx);
}
@@ -81,6 +85,8 @@ class AdaDelta
*/
void Update()
{
+ iteration++;
+
if (gradient.n_elem != 0)
{
DataType outputGradient;
@@ -98,9 +104,15 @@ class AdaDelta
*/
void Reset()
{
+ iteration = 0;  // Restart the count of Update() calls that Optimize() divides the gradient by.
gradient.zeros();
}
+ //! Get the gradient (read-only; a const member must not expose a mutable reference).
+ const DataType& Gradient() const { return gradient; }
+ //! Modify the gradient.
+ DataType& Gradient() { return gradient; }
+
private:
/**
* Optimize the given function using AdaDelta.
@@ -168,6 +180,9 @@ class AdaDelta
//! The current mean squared gradient Dx
DataType meanSquaredGradientDx;
+
+ //! The locally stored number of iterations.
+ size_t iteration;
}; // class AdaDelta
}; // namespace ann
More information about the mlpack-git
mailing list