[mlpack-svn] master: Add implementation of the steepest descent method to update the weights. (23b9001)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Jan 1 07:14:48 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/e5279f82fd462f76ed757b66420a68494e7329b9...23b900168ef50d2ad1b247c450645de621e2043e
>---------------------------------------------------------------
commit 23b900168ef50d2ad1b247c450645de621e2043e
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Thu Jan 1 13:14:34 2015 +0100
Add implementation of the steepest descent method to update the weights.
>---------------------------------------------------------------
23b900168ef50d2ad1b247c450645de621e2043e
.../methods/ann/optimizer/steepest_descent.hpp | 73 ++++++++++++++++++++++
1 file changed, 73 insertions(+)
diff --git a/src/mlpack/methods/ann/optimizer/steepest_descent.hpp b/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
new file mode 100644
index 0000000..581eec3
--- /dev/null
+++ b/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
@@ -0,0 +1,73 @@
+/**
+ * @file steepest_descent.hpp
+ * @author Marcus Edel
+ *
+ * Implementation of the steepest descent (gradient descent) update rule, with
+ * optional momentum, used to update the weights of a neural network.
+ */
+#ifndef __MLPACK_METHODS_ANN_OPTIMIZER_STEEPEST_DESCENT_HPP
+#define __MLPACK_METHODS_ANN_OPTIMIZER_STEEPEST_DESCENT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+
+/**
+ * Steepest descent weight update rule with optional momentum. Each call to
+ * UpdateWeights() performs, when mom > 0,
+ *
+ *   v = mom * v + lr * gradient
+ *   weights = weights - v
+ *
+ * and the plain update weights = weights - lr * gradient otherwise.
+ *
+ * @tparam MatType Type of matrix (should be arma::mat or arma::sp_mat).
+ * @tparam VecType Type of vector; currently unused by this class.
+ */
+template<typename MatType = arma::mat, typename VecType = arma::colvec>
+class SteepestDescent
+{
+ public:
+  /**
+   * Create the steepest descent update rule with the given learning rate and
+   * momentum. When momentum is enabled (mom > 0) the momentum buffer is
+   * pre-allocated as a rows x cols zero matrix; note that the column count is
+   * passed first. UpdateWeights() re-sizes the buffer to the gradient's shape
+   * if the two do not agree.
+   *
+   * @param cols Number of columns of the weight matrix to be updated.
+   * @param rows Number of rows of the weight matrix to be updated.
+   * @param lr The learning rate (step size).
+   * @param mom The momentum coefficient; values <= 0 disable momentum.
+   */
+  SteepestDescent(const size_t cols,
+                  const size_t rows,
+                  const double lr = 1,
+                  const double mom = 0.1) :
+      lr(lr), mom(mom)
+  {
+    if (mom > 0)
+      momWeights = arma::zeros<MatType>(rows, cols);
+  }
+
+  /**
+   * Update the given weights in place using the given gradient.
+   *
+   * @param weights The weight matrix to be updated.
+   * @param gradient The gradient of the error with respect to the weights;
+   *     must have the same dimensions as weights.
+   * @param (unnamed) Unused by this update rule.
+   */
+  void UpdateWeights(MatType& weights,
+                     const MatType& gradient,
+                     const double /* unused */)
+  {
+    if (mom > 0)
+    {
+      // The momentum buffer accumulates gradients, so its shape must match
+      // the gradient. If it was sized differently at construction (e.g. with
+      // swapped cols/rows), restart it from zero with the correct shape
+      // instead of failing on a dimension mismatch.
+      if (momWeights.n_rows != gradient.n_rows ||
+          momWeights.n_cols != gradient.n_cols)
+        momWeights.zeros(gradient.n_rows, gradient.n_cols);
+
+      momWeights *= mom;
+      momWeights += lr * gradient;
+      weights -= momWeights;
+    }
+    else
+    {
+      weights -= lr * gradient;
+    }
+  }
+
+ private:
+  //! The learning rate (step size).
+  const double lr;
+
+  //! The momentum coefficient; values <= 0 disable momentum.
+  const double mom;
+
+  //! Momentum buffer holding the accumulated (decayed) weight updates.
+  MatType momWeights;
+}; // class SteepestDescent
+
+} // namespace ann
+} // namespace mlpack
+
+#endif // __MLPACK_METHODS_ANN_OPTIMIZER_STEEPEST_DESCENT_HPP
+
+
More information about the mlpack-svn mailing list