[mlpack-git] master: Add function to update/reset the optimizer object. (e565147)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Jun 16 14:50:38 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/9264f7544f7c4d93ff735f00f35b0f5287abf59d...7df836c2f5a2287cda82801ca20f4b4b410cf4e1

>---------------------------------------------------------------

commit e5651474c62cd0bcb44a7937dae6266e67a8f3d2
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Tue Jun 16 14:09:30 2015 +0200

    Add function to update/reset the optimizer object.


>---------------------------------------------------------------

e5651474c62cd0bcb44a7937dae6266e67a8f3d2
 src/mlpack/methods/ann/optimizer/rmsprop.hpp | 57 ++++++++++++++++++++--------
 1 file changed, 41 insertions(+), 16 deletions(-)

diff --git a/src/mlpack/methods/ann/optimizer/rmsprop.hpp b/src/mlpack/methods/ann/optimizer/rmsprop.hpp
index e873187..1a25b60 100644
--- a/src/mlpack/methods/ann/optimizer/rmsprop.hpp
+++ b/src/mlpack/methods/ann/optimizer/rmsprop.hpp
@@ -20,7 +20,7 @@ namespace ann /** Artificial Neural Network. */ {
  * For more information, see the following.
  *
  * @code
- * @misc{[tieleman2012,
+ * @misc{tieleman2012,
  *   title={Lecture 6.5 - rmsprop, COURSERA: Neural Networks for Machine
  *   Learning},
  *   year={2012}
@@ -47,7 +47,7 @@ class RMSPROP
       lr(lr),
       alpha(alpha),
       eps(eps),
-      meanSquareGad(function.Weights())
+      meanSquaredGad(function.Weights())
   {
     // Nothing to do here.
   }
@@ -57,16 +57,38 @@ class RMSPROP
    */
   void Optimize()
   {
-    if (meanSquareGad.n_elem == 0)
+    if (meanSquaredGad.n_elem == 0)
     {
-      meanSquareGad = function.Weights();
-      meanSquareGad.zeros();
+      meanSquaredGad = function.Weights();
+      meanSquaredGad.zeros();
     }
 
-    DataType gradient;
-    function.Gradient(gradient);
+    Optimize(function.Weights(), gradient, meanSquaredGad);
+  }
+
+  /**
+   * Compute the current gradient and add it to the accumulated gradient.
+   */
+  void Update()
+  {
+    if (gradient.n_elem != 0)
+    {
+      DataType outputGradient;
+      function.Gradient(outputGradient);
+      gradient += outputGradient;
+    }
+    else
+    {
+      function.Gradient(gradient);
+    }
+  }
 
-    Optimize(function.Weights(), gradient, meanSquareGad);
+  /**
+   * Reset the accumulated gradient to zero.
+   */
+  void Reset()
+  {
+    gradient.zeros();
   }
 
  private:
@@ -81,10 +103,10 @@ class RMSPROP
   template<typename eT>
   void Optimize(arma::Cube<eT>& weights,
                 arma::Cube<eT>& gradient,
-                arma::Cube<eT>& meanSquareGradient)
+                arma::Cube<eT>& meanSquaredGradient)
   {
     for (size_t s = 0; s < weights.n_slices; s++)
-      Optimize(weights.slice(s), gradient.slice(s), meanSquareGradient.slice(s));
+      Optimize(weights.slice(s), gradient.slice(s), meanSquaredGradient.slice(s));
   }
 
   /**
@@ -98,11 +120,11 @@ class RMSPROP
   template<typename eT>
   void Optimize(arma::Mat<eT>& weights,
                 arma::Mat<eT>& gradient,
-                arma::Mat<eT>& meanSquareGradient)
+                arma::Mat<eT>& meanSquaredGradient)
   {
-    meanSquareGradient *= alpha;
-    meanSquareGradient += (1 - alpha) * (gradient % gradient);
-    weights -= lr * gradient / (arma::sqrt(meanSquareGradient) + eps);
+    meanSquaredGradient *= alpha;
+    meanSquaredGradient += (1 - alpha) * (gradient % gradient);
+    weights -= lr * gradient / (arma::sqrt(meanSquaredGradient) + eps);
   }
 
   //! The instantiated function.
@@ -117,8 +139,11 @@ class RMSPROP
   //! The value used as eps.
   const double eps;
 
-  //! The current mean squared error.
-  DataType meanSquareGad;
+  //! Exponentially decaying average of the squared gradients.
+  DataType meanSquaredGad;
+
+  //! The gradient accumulated by Update() and cleared by Reset().
+  DataType gradient;
 }; // class RMSPROP
 
 }; // namespace ann



More information about the mlpack-git mailing list