[mlpack-git] master: Add function to update/reset the optimizer object. (e565147)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Jun 16 14:50:38 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/9264f7544f7c4d93ff735f00f35b0f5287abf59d...7df836c2f5a2287cda82801ca20f4b4b410cf4e1
>---------------------------------------------------------------
commit e5651474c62cd0bcb44a7937dae6266e67a8f3d2
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Tue Jun 16 14:09:30 2015 +0200
Add function to update/reset the optimizer object.
>---------------------------------------------------------------
e5651474c62cd0bcb44a7937dae6266e67a8f3d2
src/mlpack/methods/ann/optimizer/rmsprop.hpp | 57 ++++++++++++++++++++--------
1 file changed, 41 insertions(+), 16 deletions(-)
diff --git a/src/mlpack/methods/ann/optimizer/rmsprop.hpp b/src/mlpack/methods/ann/optimizer/rmsprop.hpp
index e873187..1a25b60 100644
--- a/src/mlpack/methods/ann/optimizer/rmsprop.hpp
+++ b/src/mlpack/methods/ann/optimizer/rmsprop.hpp
@@ -20,7 +20,7 @@ namespace ann /** Artificial Neural Network. */ {
* For more information, see the following.
*
* @code
- * @misc{[tieleman2012,
+ * @misc{tieleman2012,
* title={Lecture 6.5 - rmsprop, COURSERA: Neural Networks for Machine
* Learning},
* year={2012}
@@ -47,7 +47,7 @@ class RMSPROP
lr(lr),
alpha(alpha),
eps(eps),
- meanSquareGad(function.Weights())
+ meanSquaredGad(function.Weights())
{
// Nothing to do here.
}
@@ -57,16 +57,38 @@ class RMSPROP
*/
void Optimize()
{
- if (meanSquareGad.n_elem == 0)
+ if (meanSquaredGad.n_elem == 0)
{
- meanSquareGad = function.Weights();
- meanSquareGad.zeros();
+ meanSquaredGad = function.Weights();
+ meanSquaredGad.zeros();
}
- DataType gradient;
- function.Gradient(gradient);
+ Optimize(function.Weights(), gradient, meanSquaredGad);
+ }
+
+ /*
+ * Accumulate the current gradient into the gradient storage (the storage is initialized on the first call).
+ */
+ void Update()
+ {
+ if (gradient.n_elem != 0)
+ {
+ DataType outputGradient;
+ function.Gradient(outputGradient);
+ gradient += outputGradient;
+ }
+ else
+ {
+ function.Gradient(gradient);
+ }
+ }
- Optimize(function.Weights(), gradient, meanSquareGad);
+ /*
+ * Reset the gradient storage.
+ */
+ void Reset()
+ {
+ gradient.zeros();
}
private:
@@ -81,10 +103,10 @@ class RMSPROP
template<typename eT>
void Optimize(arma::Cube<eT>& weights,
arma::Cube<eT>& gradient,
- arma::Cube<eT>& meanSquareGradient)
+ arma::Cube<eT>& meanSquaredGradient)
{
for (size_t s = 0; s < weights.n_slices; s++)
- Optimize(weights.slice(s), gradient.slice(s), meanSquareGradient.slice(s));
+ Optimize(weights.slice(s), gradient.slice(s), meanSquaredGradient.slice(s));
}
/**
@@ -98,11 +120,11 @@ class RMSPROP
template<typename eT>
void Optimize(arma::Mat<eT>& weights,
arma::Mat<eT>& gradient,
- arma::Mat<eT>& meanSquareGradient)
+ arma::Mat<eT>& meanSquaredGradient)
{
- meanSquareGradient *= alpha;
- meanSquareGradient += (1 - alpha) * (gradient % gradient);
- weights -= lr * gradient / (arma::sqrt(meanSquareGradient) + eps);
+ meanSquaredGradient *= alpha;
+ meanSquaredGradient += (1 - alpha) * (gradient % gradient);
+ weights -= lr * gradient / (arma::sqrt(meanSquaredGradient) + eps);
}
//! The instantiated function.
@@ -117,8 +139,11 @@ class RMSPROP
//! The value used as eps.
const double eps;
- //! The current mean squared error.
- DataType meanSquareGad;
+ //! Exponentially decaying average of the squared gradients.
+ DataType meanSquaredGad;
+
+ //! The current gradient.
+ DataType gradient;
}; // class RMSPROP
}; // namespace ann
More information about the mlpack-git
mailing list