[mlpack-git] master: Refactor for new optimizer API. (ebede26)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jun 24 13:50:11 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6e98f6d5e61ac0ca861f0a7c3ec966076eccc50e...7de290f191972dd41856b647249e2d24d2bf029d

>---------------------------------------------------------------

commit ebede269f643e9cb39666a13e8224d14205fa2de
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Fri Jun 19 21:20:26 2015 +0200

    Refactor for new optimizer API.


>---------------------------------------------------------------

ebede269f643e9cb39666a13e8224d14205fa2de
 .../methods/ann/optimizer/steepest_descent.hpp     | 119 ++++++++++++++-------
 1 file changed, 78 insertions(+), 41 deletions(-)

diff --git a/src/mlpack/methods/ann/optimizer/steepest_descent.hpp b/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
index f2ddd81..b2b18d1 100644
--- a/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
+++ b/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
@@ -21,80 +21,114 @@ namespace ann /** Artificial Neural Network. */ {
  * @tparam DataType Type of input data (should be arma::mat,
  * arma::spmat or arma::cube).
  */
-template<typename DataType = arma::mat>
+template<typename DecomposableFunctionType, typename DataType>
 class SteepestDescent
 {
  public:
-  /*
-   * Construct the optimizer object, which will be used to update the weights.
+  /**
+   * Construct the SteepestDescent optimizer with the given function and
+   * parameters.
    *
-   * @param lr The value used as learning rate (Default: 1).
+   * @param function Function to be optimized (minimized); held by reference.
+   * @param lr The learning rate coefficient (Default: 1).
+   * @param mom The momentum coefficient; 0 disables momentum (Default: 0).
    */
-  SteepestDescent(const double lr = 1) : lr(lr), mom(0)
+  SteepestDescent(DecomposableFunctionType& function,
+                  const double lr = 1,
+                  const double mom = 0) :
+      function(function),
+      lr(lr),
+      mom(mom)
+      // momWeights stays empty here; Optimize() allocates and zeroes it.
+
   {
     // Nothing to do here.
   }
 
   /**
-   * Construct the optimizer object, which will be used to update the weights.
-   *
-   * @param cols The number of cols to initilize the momentum matrix.
-   * @param rows The number of rows to initilize the momentum matrix.
-   * @param lr The value used as learning rate (Default: 1).
-   * @param mom The value used as momentum (Default: 0.1).
+   * Take one steepest descent step on the weights using the stored gradient.
    */
-  SteepestDescent(const size_t cols,
-                  const size_t rows,
-                  const double lr = 1,
-                  const double mom = 0.1) :
-      lr(lr), mom(mom)
+  void Optimize()
   {
-    if (mom > 0)
-      momWeights = arma::zeros<DataType>(rows, cols);
+    if (momWeights.n_elem == 0)
+    {
+      momWeights = function.Weights(); // Copied only to get the right shape.
+      momWeights.zeros();
+    }
+
+    Optimize(function.Weights(), gradient, momWeights);
   }
 
-  /**
-   * Construct the optimizer object, which will be used to update the weights.
-   *
-   * @param cols The number of cols used to initilize the momentum matrix.
-   * @param rows The number of rows used to initilize the momentum matrix.
-   * @param slices The number of slices used to initilize the momentum matrix.
-   * @param lr The value used as learning rate (Default: 1).
-   * @param mom The value used as momentum (Default: 0.1).
+  /**
+   * Compute the function's current gradient and add it to the gradient storage.
    */
-  SteepestDescent(const size_t cols,
-                  const size_t rows,
-                  const size_t slices,
-                  const double lr,
-                  const double mom) :
-      lr(lr), mom(mom)
+  void Update()
   {
-    if (mom > 0)
-      momWeights = arma::zeros<DataType>(rows, cols, slices);
+    if (gradient.n_elem != 0)
+    {
+      DataType outputGradient;
+      function.Gradient(outputGradient);
+      gradient += outputGradient;
+    }
+    else
+    {
+      function.Gradient(gradient); // Empty storage: write the gradient in place.
+    }
   }
 
   /*
-   * Update the specified weights using steepest descent.
+   * Zero the gradient storage; its size is kept, so Update() accumulates again.
+   */
+  void Reset()
+  {
+    gradient.zeros();
+  }
+
+ private:
+  /** Apply the steepest descent step to every slice of the given cube.
+   *
+   * @param weights The weights that should be updated.
+   * @param gradient The gradient used to update the weights.
+   * @param momWeights The momentum matrix (accumulated update) used to update
+   *    the weights.
+   */
+  template<typename eT>
+  void Optimize(arma::Cube<eT>& weights,
+                arma::Cube<eT>& gradient,
+                arma::Cube<eT>& momWeights)
+  {
+    for (size_t s = 0; s < weights.n_slices; s++)
+      Optimize(weights.slice(s), gradient.slice(s), momWeights.slice(s));
+  }
+
+  /**
+   * Apply one steepest descent step, with momentum when mom > 0.
    *
    * @param weights The weights that should be updated.
    * @param gradient The gradient used to update the weights.
+   * @param momWeights The momentum matrix (accumulated update); only read and
+   *    modified when mom > 0.
    */
-  template<typename WeightType, typename GradientType>
-  void UpdateWeights(WeightType& weights,
-                     const GradientType& gradient,
-                     const double /* unused */)
+  template<typename eT>
+  void Optimize(arma::Mat<eT>& weights,
+                arma::Mat<eT>& gradient,
+                arma::Mat<eT>& momWeights)
   {
     if (mom > 0)
     {
       momWeights *= mom;
-      momWeights += lr * gradient;
+      momWeights += (lr * gradient); // Now mom * previous + lr * gradient.
       weights -= momWeights;
     }
     else
+    {
       weights -= lr * gradient;
+    }
   }
 
- private:
+  //! The function to optimize; held by reference, so it must outlive this.
+  DecomposableFunctionType& function;
+
   //! The value used as learning rate.
   const double lr;
 
@@ -103,6 +137,9 @@ class SteepestDescent
 
   //! Momentum matrix.
   DataType momWeights;
+
+  //! Gradient accumulated by Update() and zeroed by Reset().
+  DataType gradient;
 }; // class SteepestDescent
 
 }; // namespace ann



More information about the mlpack-git mailing list