[mlpack-git] master: Refactor for new optimizer API. (ebede26)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jun 24 13:50:11 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/6e98f6d5e61ac0ca861f0a7c3ec966076eccc50e...7de290f191972dd41856b647249e2d24d2bf029d
>---------------------------------------------------------------
commit ebede269f643e9cb39666a13e8224d14205fa2de
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Fri Jun 19 21:20:26 2015 +0200
Refactor for new optimizer API.
>---------------------------------------------------------------
ebede269f643e9cb39666a13e8224d14205fa2de
.../methods/ann/optimizer/steepest_descent.hpp | 119 ++++++++++++++-------
1 file changed, 78 insertions(+), 41 deletions(-)
diff --git a/src/mlpack/methods/ann/optimizer/steepest_descent.hpp b/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
index f2ddd81..b2b18d1 100644
--- a/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
+++ b/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
@@ -21,80 +21,114 @@ namespace ann /** Artificial Neural Network. */ {
* @tparam DataType Type of input data (should be arma::mat,
* arma::spmat or arma::cube).
*/
-template<typename DataType = arma::mat>
+template<typename DecomposableFunctionType, typename DataType>
class SteepestDescent
{
public:
- /*
- * Construct the optimizer object, which will be used to update the weights.
+ /**
+ * Construct the SteepestDescent optimizer with the given function and
+ * parameters.
*
- * @param lr The value used as learning rate (Default: 1).
+ * @param function Function to be optimized (minimized).
+ * @param lr The learning rate coefficient.
+ * @param mom The momentum coefficient.
*/
- SteepestDescent(const double lr = 1) : lr(lr), mom(0)
+ SteepestDescent(DecomposableFunctionType& function,
+ const double lr = 1,
+ const double mom = 0) :
+ function(function),
+ lr(lr),
+ mom(mom),
+ momWeights(function.Weights())
+
{
// Nothing to do here.
}
/**
- * Construct the optimizer object, which will be used to update the weights.
- *
- * @param cols The number of cols to initilize the momentum matrix.
- * @param rows The number of rows to initilize the momentum matrix.
- * @param lr The value used as learning rate (Default: 1).
- * @param mom The value used as momentum (Default: 0.1).
+ * Optimize the given function using steepest descent.
*/
- SteepestDescent(const size_t cols,
- const size_t rows,
- const double lr = 1,
- const double mom = 0.1) :
- lr(lr), mom(mom)
+ void Optimize()
{
- if (mom > 0)
- momWeights = arma::zeros<DataType>(rows, cols);
+ if (momWeights.n_elem == 0)
+ {
+ momWeights = function.Weights();
+ momWeights.zeros();
+ }
+
+ Optimize(function.Weights(), gradient, momWeights);
}
- /**
- * Construct the optimizer object, which will be used to update the weights.
- *
- * @param cols The number of cols used to initilize the momentum matrix.
- * @param rows The number of rows used to initilize the momentum matrix.
- * @param slices The number of slices used to initilize the momentum matrix.
- * @param lr The value used as learning rate (Default: 1).
- * @param mom The value used as momentum (Default: 0.1).
+ /*
+ * Sum up all gradients and store the results in the gradients storage.
*/
- SteepestDescent(const size_t cols,
- const size_t rows,
- const size_t slices,
- const double lr,
- const double mom) :
- lr(lr), mom(mom)
+ void Update()
{
- if (mom > 0)
- momWeights = arma::zeros<DataType>(rows, cols, slices);
+ if (gradient.n_elem != 0)
+ {
+ DataType outputGradient;
+ function.Gradient(outputGradient);
+ gradient += outputGradient;
+ }
+ else
+ {
+ function.Gradient(gradient);
+ }
}
/*
- * Update the specified weights using steepest descent.
+ * Reset the gradient storage.
+ */
+ void Reset()
+ {
+ gradient.zeros();
+ }
+
+ private:
+ /** Optimize the given function using steepest descent.
+ *
+ * @param weights The weights that should be updated.
+ * @param gradient The gradient used to update the weights.
+   * @param momWeights The momentum matrix used to update the weights.
+ */
+ template<typename eT>
+ void Optimize(arma::Cube<eT>& weights,
+ arma::Cube<eT>& gradient,
+ arma::Cube<eT>& momWeights)
+ {
+ for (size_t s = 0; s < weights.n_slices; s++)
+ Optimize(weights.slice(s), gradient.slice(s), momWeights.slice(s));
+ }
+
+ /**
+ * Optimize the given function using steepest descent.
*
* @param weights The weights that should be updated.
* @param gradient The gradient used to update the weights.
+   * @param momWeights The momentum matrix used to update the weights.
*/
- template<typename WeightType, typename GradientType>
- void UpdateWeights(WeightType& weights,
- const GradientType& gradient,
- const double /* unused */)
+ template<typename eT>
+ void Optimize(arma::Mat<eT>& weights,
+ arma::Mat<eT>& gradient,
+ arma::Mat<eT>& momWeights)
{
if (mom > 0)
{
momWeights *= mom;
- momWeights += lr * gradient;
+ momWeights += (lr * gradient);
weights -= momWeights;
}
else
+ {
weights -= lr * gradient;
+ }
}
- private:
+ //! The instantiated function.
+ DecomposableFunctionType& function;
+
//! The value used as learning rate.
const double lr;
@@ -103,6 +137,9 @@ class SteepestDescent
//! Momentum matrix.
DataType momWeights;
+
+ //! The current gradient.
+ DataType gradient;
}; // class SteepestDescent
}; // namespace ann
More information about the mlpack-git
mailing list