[mlpack-git] master: Clean the steepest descent optimizer class and add tensor support. (f84079f)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sun May 3 16:15:39 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f31abbdebcd34e2113d8acf47c1d0b087377921...174d2de995a3fe343cd92d158730f3afa03e622d
>---------------------------------------------------------------
commit f84079f191adef76a60e66ee8a27b2846571b9a2
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Sat May 2 21:53:37 2015 +0200
Clean the steepest descent optimizer class and add tensor support.
>---------------------------------------------------------------
f84079f191adef76a60e66ee8a27b2846571b9a2
.../methods/ann/optimizer/steepest_descent.hpp | 74 +++++++++++++++++-----
1 file changed, 57 insertions(+), 17 deletions(-)
diff --git a/src/mlpack/methods/ann/optimizer/steepest_descent.hpp b/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
index 8b43897..6a18417 100644
--- a/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
+++ b/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
@@ -2,8 +2,10 @@
* @file steepest_descent.hpp
* @author Marcus Edel
*
- * Intialization rule for the neural networks. This simple initialization is
- * performed by assigning a random matrix to the weight matrix.
+ * Implementation of the steepest descent optimizer. The method of steepest
+ * descent, also called the gradient descent method, is used to find the
+ * nearest local minimum of a function with the assumption that the gradient of
+ * the function can be computed.
*/
#ifndef __MLPACK_METHOS_ANN_OPTIMIZER_STEEPEST_DESCENT_HPP
#define __MLPACK_METHOS_ANN_OPTIMIZER_STEEPEST_DESCENT_HPP
@@ -14,33 +16,72 @@ namespace mlpack {
namespace ann /** Artificial Neural Network. */ {
/**
- * This class is used to initialize randomly the weight matrix.
+ * This class is used to update the weights using steepest descent.
*
- * @tparam MatType Type of matrix (should be arma::mat or arma::spmat).
+ * @tparam DataType Type of input data (should be arma::mat,
+ * arma::spmat or arma::cube).
*/
-template<typename MatType = arma::mat, typename VecType = arma::colvec>
+template<typename DataType = arma::mat>
class SteepestDescent
{
public:
+ /**
+ * Construct the optimizer object, which will be used to update the weights.
+ *
+ * @param lr The value used as learning rate (Default: 1).
+ */
+ SteepestDescent(const double lr = 1) : lr(lr), mom(0)
+ {
+ // Nothing to do here.
+ }
+
+ /**
+ * Construct the optimizer object, which will be used to update the weights.
+ *
+ * @param cols The number of cols used to initialize the momentum matrix.
+ * @param rows The number of rows used to initialize the momentum matrix.
+ * @param lr The value used as learning rate (Default: 1).
+ * @param mom The value used as momentum (Default: 0.1).
+ */
+ SteepestDescent(const size_t cols,
+ const size_t rows,
+ const double lr = 1,
+ const double mom = 0.1) :
+ lr(lr), mom(mom)
+ {
+ if (mom > 0)
+ momWeights = arma::zeros<DataType>(rows, cols);
+ }
+
/**
- * Initialize the random initialization rule with the given lower bound and
- * upper bound.
+ * Construct the optimizer object, which will be used to update the weights.
*
- * @param lowerBound The number used as lower bound.
- * @param upperBound The number used as upper bound.
+ * @param cols The number of cols used to initialize the momentum matrix.
+ * @param rows The number of rows used to initialize the momentum matrix.
+ * @param slices The number of slices used to initialize the momentum matrix.
+ * @param lr The value used as learning rate (Default: 1).
+ * @param mom The value used as momentum (Default: 0.1).
*/
SteepestDescent(const size_t cols,
const size_t rows,
+ const size_t slices,
const double lr = 1,
const double mom = 0.1) :
lr(lr), mom(mom)
{
if (mom > 0)
- momWeights = arma::zeros<MatType>(rows, cols);
+ momWeights = arma::zeros<DataType>(rows, cols, slices);
}
- void UpdateWeights(MatType& weights,
- const MatType& gradient,
+ /**
+ * Update the specified weights using steepest descent.
+ *
+ * @param weights The weights that should be updated.
+ * @param gradient The gradient used to update the weights.
+ */
+ template<typename WeightType, typename GradientType>
+ void UpdateWeights(WeightType& weights,
+ const GradientType& gradient,
const double /* unused */)
{
if (mom > 0)
@@ -53,16 +94,15 @@ class SteepestDescent
weights -= lr * gradient;
}
-
private:
- //! The number used as learning rate.
+ //! The value used as learning rate.
const double lr;
- //! The number used as momentum.
+ //! The value used as momentum.
const double mom;
- //! weight momentum
- MatType momWeights;
+ //! Momentum matrix.
+ DataType momWeights;
}; // class SteepestDescent
}; // namespace ann
More information about the mlpack-git
mailing list