[mlpack-svn] master: Add implementation of the iRPROP+ method to update the weights. (db63858)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Jan 1 07:14:42 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/e5279f82fd462f76ed757b66420a68494e7329b9...23b900168ef50d2ad1b247c450645de621e2043e
>---------------------------------------------------------------
commit db63858a25fac8828a55db342d8a4d990f8ef089
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Thu Jan 1 13:13:32 2015 +0100
Add implementation of the iRPROP+ method to update the weights.
>---------------------------------------------------------------
db63858a25fac8828a55db342d8a4d990f8ef089
.../ann/optimizer/{irpropm.hpp => irpropp.hpp} | 52 ++++++++++++++--------
1 file changed, 34 insertions(+), 18 deletions(-)
diff --git a/src/mlpack/methods/ann/optimizer/irpropm.hpp b/src/mlpack/methods/ann/optimizer/irpropp.hpp
similarity index 59%
copy from src/mlpack/methods/ann/optimizer/irpropm.hpp
copy to src/mlpack/methods/ann/optimizer/irpropp.hpp
index 4bb077b..71c95e0 100644
--- a/src/mlpack/methods/ann/optimizer/irpropm.hpp
+++ b/src/mlpack/methods/ann/optimizer/irpropp.hpp
@@ -1,12 +1,12 @@
/**
- * @file irpropm.hpp
+ * @file irpropp.hpp
* @author Marcus Edel
*
* Intialization rule for the neural networks. This simple initialization is
* performed by assigning a random matrix to the weight matrix.
*/
-#ifndef __MLPACK_METHOS_ANN_OPTIMIZER_IRPROPM_HPP
-#define __MLPACK_METHOS_ANN_OPTIMIZER_IRPROPM_HPP
+#ifndef __MLPACK_METHOS_ANN_OPTIMIZER_IRPROPP_HPP
+#define __MLPACK_METHOS_ANN_OPTIMIZER_IRPROPP_HPP
#include <mlpack/core.hpp>
#include <boost/math/special_functions/sign.hpp>
@@ -20,7 +20,7 @@ namespace ann /** Artificial Neural Network. */ {
* @tparam MatType Type of matrix (should be arma::mat or arma::spmat).
*/
template<typename MatType = arma::mat, typename VecType = arma::rowvec>
-class iRPROPm
+class iRPROPp
{
public:
/**
@@ -30,23 +30,25 @@ class iRPROPm
* @param lowerBound The number used as lower bound.
* @param upperBound The number used as upper bound.
*/
- iRPROPm(const size_t cols,
+ iRPROPp(const size_t cols,
const size_t rows,
const double etaMin = 0.5,
const double etaPlus = 1.2,
const double minDelta = 1e-9,
- const double maxDelta = 50) :
- etaMin(etaMin), etaPlus(etaPlus), minDelta(minDelta), maxDelta(maxDelta)
+ const double maxDelta = 50,
+ const double initialUpdate = 0.1) :
+ etaMin(etaMin), etaPlus(etaPlus), minDelta(minDelta), maxDelta(maxDelta), prevError(arma::datum::inf)
{
prevDerivs = arma::zeros<MatType>(rows, cols);
- prevDelta = arma::zeros<MatType>(rows, cols);
+ prevWeightChange = arma::zeros<MatType>(rows, cols);
- prevError = arma::datum::inf;
+ updateValues = arma::ones<MatType>(rows, cols);
+ updateValues.fill(initialUpdate);
}
void UpdateWeights(MatType& weights,
const MatType& gradient,
- const double /* unused */)
+ const double error)
{
MatType derivs = gradient % prevDerivs;
@@ -54,22 +56,32 @@ class iRPROPm
{
for (size_t j(0); j < derivs.n_rows; j++)
{
- if (derivs(j, i) >= 0)
+ if (derivs(j, i) > 0)
{
- prevDelta(j, i) = std::min(prevDelta(j, i) * etaPlus, maxDelta);
+ updateValues(j, i) = std::min(updateValues(j, i) * etaPlus, maxDelta);
+ prevWeightChange(j, i) = boost::math::sign(gradient(j, i)) * updateValues(j, i);
prevDerivs(j, i) = gradient(j, i);
}
- else
+ else if (derivs(j, i) < 0)
{
- prevDelta(j, i) = std::max(prevDelta(j, i) * etaMin, minDelta);
+ updateValues(j, i) = std::max(updateValues(j, i) * etaMin, minDelta);
prevDerivs(j, i) = 0;
+
+ if (error < prevError)
+ prevWeightChange(j, i) = 0;
+ }
+ else
+ {
+ prevWeightChange(j, i) = boost::math::sign(gradient(j, i)) * updateValues(j, i);
+ prevDerivs(j, i) = gradient(j, i);
}
+
+ weights(j, i) -= prevWeightChange(j, i);
}
}
-
- weights -= arma::sign(gradient) % prevDelta;
}
+
private:
//! The number used as learning rate.
const double etaMin;
@@ -82,13 +94,17 @@ class iRPROPm
double prevError;
- MatType prevDelta;
+ MatType updateValues;
+
+ MatType prevWeightChange;
//! weight momentum
MatType prevDerivs;
-}; // class iRPROPm
+}; // class iRPROPp
}; // namespace ann
}; // namespace mlpack
#endif
+
+
More information about the mlpack-svn mailing list