[mlpack-git] master: Add implementation of the steepest descent method to update the weights. (5193f33)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:09:22 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 5193f33c14348a593616c828f58e4c76dc199290
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Thu Jan 1 13:14:34 2015 +0100

    Add implementation of the steepest descent method to update the weights.


>---------------------------------------------------------------

5193f33c14348a593616c828f58e4c76dc199290
 .../methods/ann/optimizer/steepest_descent.hpp     | 73 ++++++++++++++++++++++
 1 file changed, 73 insertions(+)

diff --git a/src/mlpack/methods/ann/optimizer/steepest_descent.hpp b/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
new file mode 100644
index 0000000..581eec3
--- /dev/null
+++ b/src/mlpack/methods/ann/optimizer/steepest_descent.hpp
@@ -0,0 +1,73 @@
+/**
+ * @file steepest_descent.hpp
+ * @author Marcus Edel
+ *
+ * Implementation of the steepest descent optimizer, which updates the weights
+ * of a neural network by stepping against the gradient, optionally with a
+ * momentum term.
+ */
+#ifndef __MLPACK_METHODS_ANN_OPTIMIZER_STEEPEST_DESCENT_HPP
+#define __MLPACK_METHODS_ANN_OPTIMIZER_STEEPEST_DESCENT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+
+/**
+ * Steepest descent optimizer used to update the weights of a neural network,
+ * with an optional momentum term.
+ *
+ * Without momentum (mom <= 0) every call performs
+ *
+ *   weights = weights - lr * gradient
+ *
+ * With momentum (mom > 0) a velocity matrix v, initialized to zero in the
+ * constructor, is kept between calls and every call performs
+ *
+ *   v = mom * v + lr * gradient
+ *   weights = weights - v
+ *
+ * @tparam MatType Type of matrix (e.g. arma::mat or arma::sp_mat).
+ * @tparam VecType Type of vector; currently not used by this class.
+ */
+template<typename MatType = arma::mat, typename VecType = arma::colvec>
+class SteepestDescent
+{
+ public:
+  /**
+   * Construct the optimizer with the given learning rate and momentum.
+   *
+   * Note the argument order: the number of columns comes before the number of
+   * rows. The momentum matrix is only allocated when mom > 0; in that case it
+   * is rows x cols, so it should match the shape of the weight matrix later
+   * passed to UpdateWeights().
+   *
+   * @param cols Number of columns of the weight matrix.
+   * @param rows Number of rows of the weight matrix.
+   * @param lr The learning rate (step size).
+   * @param mom The momentum coefficient; values <= 0 disable momentum.
+   */
+  SteepestDescent(const size_t cols,
+                  const size_t rows,
+                  const double lr = 1,
+                  const double mom = 0.1) :
+      lr(lr), mom(mom)
+  {
+    // Only the momentum variant needs state that persists across updates.
+    if (mom > 0)
+      momWeights = arma::zeros<MatType>(rows, cols);
+  }
+
+  /**
+   * Perform one steepest descent step, updating the weights in place.
+   * The third argument is ignored by this optimizer.
+   *
+   * @param weights The weight matrix to update (modified in place).
+   * @param gradient The gradient used for the step; expected to have the same
+   *     shape as weights.
+   */
+  void UpdateWeights(MatType& weights,
+                     const MatType& gradient,
+                     const double /* unused */)
+  {
+    if (mom > 0)
+    {
+      // Decay the accumulated velocity, add the scaled gradient, then step
+      // the weights against the resulting velocity.
+      momWeights *= mom;
+      momWeights += lr * gradient;
+      weights -= momWeights;
+    }
+    else
+    {
+      weights -= lr * gradient;
+    }
+  }
+
+ private:
+  //! The learning rate (step size).
+  const double lr;
+
+  //! The momentum coefficient; momentum is disabled when <= 0.
+  const double mom;
+
+  //! Accumulated momentum (velocity); only allocated when mom > 0.
+  MatType momWeights;
+}; // class SteepestDescent
+
+} // namespace ann
+} // namespace mlpack
+
+#endif // __MLPACK_METHODS_ANN_OPTIMIZER_STEEPEST_DESCENT_HPP
+
+



More information about the mlpack-git mailing list