[mlpack-git] master: Add RmsProp optimizer. (cd9fb08)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Jun 16 14:50:40 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/9264f7544f7c4d93ff735f00f35b0f5287abf59d...7df836c2f5a2287cda82801ca20f4b4b410cf4e1
>---------------------------------------------------------------
commit cd9fb08842957fdcbe308d3931bb29597e4028fe
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Sun Jun 14 19:12:06 2015 +0200
Add RmsProp optimizer.
>---------------------------------------------------------------
cd9fb08842957fdcbe308d3931bb29597e4028fe
src/mlpack/methods/ann/optimizer/rmsprop.hpp | 127 +++++++++++++++++++++++++++
1 file changed, 127 insertions(+)
diff --git a/src/mlpack/methods/ann/optimizer/rmsprop.hpp b/src/mlpack/methods/ann/optimizer/rmsprop.hpp
new file mode 100644
index 0000000..e873187
--- /dev/null
+++ b/src/mlpack/methods/ann/optimizer/rmsprop.hpp
@@ -0,0 +1,127 @@
+/**
+ * @file rmsprop.hpp
+ * @author Marcus Edel
+ *
+ * Implementation of the RmsProp optimizer. RmsProp is an optimizer that utilizes
+ * the magnitude of recent gradients to normalize the gradients.
+ */
+#ifndef __MLPACK_METHODS_ANN_OPTIMIZER_RMSPROP_HPP
+#define __MLPACK_METHODS_ANN_OPTIMIZER_RMSPROP_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+
+/**
+ * RmsProp is an optimizer that utilizes the magnitude of recent gradients to
+ * normalize the gradients.
+ *
+ * For more information, see the following.
+ *
+ * @code
+ * @misc{tieleman2012,
+ *   author = {Tieleman, T. and Hinton, G.},
+ *   title  = {Lecture 6.5 - rmsprop, COURSERA: Neural Networks for Machine
+ *             Learning},
+ *   year   = {2012}
+ * }
+ * @endcode
+ */
+template<typename DecomposableFunctionType, typename DataType>
+class RMSPROP
+{
+ public:
+  /**
+   * Construct the RMSPROP optimizer with the given function and parameters.
+   *
+   * @param function Function to be optimized (minimized).
+   * @param lr The learning rate coefficient.
+   * @param alpha Constant similar to that used in AdaDelta and Momentum
+   *        methods (decay factor of the squared-gradient moving average).
+   * @param eps The eps coefficient to avoid division by zero.
+   */
+  RMSPROP(DecomposableFunctionType& function,
+          const double lr = 0.01,
+          const double alpha = 0.99,
+          const double eps = 1e-8) :
+      function(function),
+      lr(lr),
+      alpha(alpha),
+      eps(eps),
+      meanSquareGrad(function.Weights())
+  {
+    // The moving average of squared gradients must start at zero; the copy of
+    // function.Weights() above only provides the correct dimensions, so clear
+    // the copied values here. If the weights are still empty at construction
+    // time, this is a no-op and Optimize() initializes the average lazily.
+    meanSquareGrad.zeros();
+  }
+
+  /**
+   * Optimize the given function using RmsProp.
+   */
+  void Optimize()
+  {
+    // Lazy initialization: the weights may not have been allocated yet when
+    // the optimizer was constructed.
+    if (meanSquareGrad.n_elem == 0)
+    {
+      meanSquareGrad = function.Weights();
+      meanSquareGrad.zeros();
+    }
+
+    DataType gradient;
+    function.Gradient(gradient);
+
+    Optimize(function.Weights(), gradient, meanSquareGrad);
+  }
+
+ private:
+  /**
+   * Update the given weights using RmsProp, one cube slice at a time.
+   *
+   * @param weights The weights that should be updated.
+   * @param gradient The gradient used to update the weights.
+   * @param meanSquareGradient The moving average over the root mean squared
+   *        gradient used to update the weights.
+   */
+  template<typename eT>
+  void Optimize(arma::Cube<eT>& weights,
+                arma::Cube<eT>& gradient,
+                arma::Cube<eT>& meanSquareGradient)
+  {
+    for (size_t s = 0; s < weights.n_slices; s++)
+    {
+      Optimize(weights.slice(s), gradient.slice(s),
+          meanSquareGradient.slice(s));
+    }
+  }
+
+  /**
+   * Update the given weights using RmsProp.
+   *
+   * @param weights The weights that should be updated.
+   * @param gradient The gradient used to update the weights.
+   * @param meanSquareGradient The moving average over the root mean squared
+   *        gradient used to update the weights.
+   */
+  template<typename eT>
+  void Optimize(arma::Mat<eT>& weights,
+                arma::Mat<eT>& gradient,
+                arma::Mat<eT>& meanSquareGradient)
+  {
+    // Fold the squared gradient into the exponentially decaying average,
+    // then scale the update element-wise by its root.
+    meanSquareGradient *= alpha;
+    meanSquareGradient += (1 - alpha) * (gradient % gradient);
+    weights -= lr * gradient / (arma::sqrt(meanSquareGradient) + eps);
+  }
+
+  //! The instantiated function.
+  DecomposableFunctionType& function;
+
+  //! The value used as learning rate.
+  const double lr;
+
+  //! The decay factor of the squared-gradient moving average (alpha).
+  const double alpha;
+
+  //! The value used as eps.
+  const double eps;
+
+  //! The exponentially decaying average of squared gradients.
+  DataType meanSquareGrad;
+}; // class RMSPROP
+
+}; // namespace ann
+}; // namespace mlpack
+
+#endif
More information about the mlpack-git
mailing list