[mlpack-git] master: Clean the steepest descent optimizer class and add tensor support. (f84079f)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sun May 3 16:15:39 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f31abbdebcd34e2113d8acf47c1d0b087377921...174d2de995a3fe343cd92d158730f3afa03e622d

>---------------------------------------------------------------

commit f84079f191adef76a60e66ee8a27b2846571b9a2
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Sat May 2 21:53:37 2015 +0200

    Clean the steepest descent optimizer class and add tensor support.


>---------------------------------------------------------------

f84079f191adef76a60e66ee8a27b2846571b9a2
 .../methods/ann/optimizer/steepest_descent.hpp     | 74 +++++++++++++++++-----
 1 file changed, 57 insertions(+), 17 deletions(-)

diff --git a/src/mlpack/methods/ann/optimizer/steepest_descent.hpp b/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
index 8b43897..6a18417 100644
--- a/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
+++ b/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
@@ -2,8 +2,10 @@
  * @file steepest_descent.hpp
  * @author Marcus Edel
  *
- * Intialization rule for the neural networks. This simple initialization is
- * performed by assigning a random matrix to the weight matrix.
+ * Implementation of the steepest descent optimizer. The method of steepest
+ * descent, also called the gradient descent method, is used to find the
+ * nearest local minimum of a function with the assumption that the gradient of
+ * the function can be computed.
  */
 #ifndef __MLPACK_METHOS_ANN_OPTIMIZER_STEEPEST_DESCENT_HPP
 #define __MLPACK_METHOS_ANN_OPTIMIZER_STEEPEST_DESCENT_HPP
@@ -14,33 +16,72 @@ namespace mlpack {
 namespace ann /** Artificial Neural Network. */ {
 
 /**
- * This class is used to initialize randomly the weight matrix.
+ * This class is used to update the weights using steepest descent.
  *
- * @tparam MatType Type of matrix (should be arma::mat or arma::spmat).
+ * @tparam DataType Type of input data (should be arma::mat,
+ * arma::spmat or arma::cube).
  */
-template<typename MatType = arma::mat, typename VecType = arma::colvec>
+template<typename DataType = arma::mat>
 class SteepestDescent
 {
  public:
+  /**
+   * Construct the optimizer object, which will be used to update the weights.
+   *
+   * @param lr The value used as learning rate (Default: 1).
+   */
+  SteepestDescent(const double lr = 1) : lr(lr), mom(0)
+  {
+    // Nothing to do here.
+  }
+
+  /**
+   * Construct the optimizer object, which will be used to update the weights.
+   *
+   * @param cols The number of cols used to initialize the momentum matrix.
+   * @param rows The number of rows used to initialize the momentum matrix.
+   * @param lr The value used as learning rate (Default: 1).
+   * @param mom The value used as momentum (Default: 0.1).
+   */
+  SteepestDescent(const size_t cols,
+                  const size_t rows,
+                  const double lr = 1,
+                  const double mom = 0.1) :
+      lr(lr), mom(mom)
+  {
+    if (mom > 0)
+      momWeights = arma::zeros<DataType>(rows, cols);
+  }
+
   /**
-   * Initialize the random initialization rule with the given lower bound and
-   * upper bound.
+   * Construct the optimizer object, which will be used to update the weights.
    *
-   * @param lowerBound The number used as lower bound.
-   * @param upperBound The number used as upper bound.
+   * @param cols The number of cols used to initialize the momentum cube.
+   * @param rows The number of rows used to initialize the momentum cube.
+   * @param slices The number of slices used to initialize the momentum cube.
+   * @param lr The value used as learning rate (Default: 1).
+   * @param mom The value used as momentum (Default: 0.1).
    */
   SteepestDescent(const size_t cols,
                   const size_t rows,
+                  const size_t slices,
                   const double lr = 1,
                   const double mom = 0.1) :
       lr(lr), mom(mom)
   {
     if (mom > 0)
-      momWeights = arma::zeros<MatType>(rows, cols);
+      momWeights = arma::zeros<DataType>(rows, cols, slices);
   }
 
-  void UpdateWeights(MatType& weights,
-                     const MatType& gradient,
+  /**
+   * Update the specified weights using steepest descent.
+   *
+   * @param weights The weights that should be updated.
+   * @param gradient The gradient used to update the weights.
+   */
+  template<typename WeightType, typename GradientType>
+  void UpdateWeights(WeightType& weights,
+                     const GradientType& gradient,
                      const double /* unused */)
   {
     if (mom > 0)
@@ -53,16 +94,15 @@ class SteepestDescent
       weights -= lr * gradient;
   }
 
-
  private:
-  //! The number used as learning rate.
+  //! The value used as learning rate.
   const double lr;
 
-  //! The number used as momentum.
+  //! The value used as momentum.
   const double mom;
 
-  //! weight momentum
-  MatType momWeights;
+  //! Momentum matrix.
+  DataType momWeights;
 }; // class SteepestDescent
 
 }; // namespace ann



More information about the mlpack-git mailing list