[mlpack-git] master: Add the VRClassRewardLayer class which implements the REINFORCE algorithm for classification models. To be precise, this is a variance-reduced classification reinforcement learning rule. (b69c6dc)

gitdub at mlpack.org gitdub at mlpack.org
Fri May 20 15:38:02 EDT 2016

Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/986620375ce84cdc75fdfd99f63f17b5c8ee507a...989dd35359ee0c2258616ea57675f639ff47bfaa


commit b69c6dce9d62f819433d1bdc5ed233b2dd941422
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Tue Apr 19 00:00:32 2016 +0200

    Add the VRClassRewardLayer class which implements the REINFORCE algorithm for classification models. To be precise, this is a variance-reduced classification reinforcement learning rule.


 .../methods/ann/layer/vr_class_reward_layer.hpp    | 167 +++++++++++++++++++++
 1 file changed, 167 insertions(+)

diff --git a/src/mlpack/methods/ann/layer/vr_class_reward_layer.hpp b/src/mlpack/methods/ann/layer/vr_class_reward_layer.hpp
new file mode 100644
index 0000000..9b1451c
--- /dev/null
+++ b/src/mlpack/methods/ann/layer/vr_class_reward_layer.hpp
@@ -0,0 +1,167 @@
+/**
+ * @file vr_class_reward_layer.hpp
+ * @author Marcus Edel
+ *
+ * Definition of the VRClassRewardLayer class, which implements the variance
+ * reduced classification reinforcement layer.
+ */
+#include <mlpack/core.hpp>
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+/**
+ * Implementation of the variance reduced classification reinforcement layer.
+ * This layer is meant to be used in combination with the reinforce normal layer
+ * (ReinforceNormalLayer), which expects a reward of 1 for a correct
+ * classification and 0 otherwise (the reward is multiplied by the scale
+ * parameter).
+ *
+ * @tparam InputDataType Type of the stored input parameter
+ *         (default: arma::field<arma::mat>).
+ * @tparam OutputDataType Type of the stored output parameter and delta
+ *         (default: arma::field<arma::mat>).
+ */
+template <
+    typename InputDataType = arma::field<arma::mat>,
+    typename OutputDataType = arma::field<arma::mat>
+>
+class VRClassRewardLayer
+{
+ public:
+  /**
+   * Create the VRClassRewardLayer object.
+   *
+   * @param scale Parameter used to scale the reward.
+   * @param sizeAverage Take the average over all batches.
+   */
+  VRClassRewardLayer(const double scale = 1, const bool sizeAverage = true) :
+      scale(scale),
+      sizeAverage(sizeAverage),
+      reward(0),
+      deterministic(false)
+  {
+    // Nothing to do here.
+  }
+
+  /**
+   * Ordinary feed forward pass of a neural network, evaluating the function
+   * f(x) by propagating the activity forward through f.
+   *
+   * @param input Input data; element (0, 0) contains the log-probabilities for
+   *        each class (one column per sample). Other elements are not used by
+   *        the forward pass.
+   * @param target The target vector, that contains the class index in the range
+   *        between 1 and the number of classes.
+   * @return The negative reward (averaged over the batch if sizeAverage).
+   */
+  template<typename eT>
+  double Forward(const arma::field<arma::Mat<eT> >& input,
+                 const arma::Mat<eT>& target)
+  {
+    return Forward(input(0, 0), target);
+  }
+
+  /**
+   * Ordinary feed forward pass of a neural network, evaluating the function
+   * f(x) by propagating the activity forward through f. The accumulated
+   * reward is stored and reused by Backward().
+   *
+   * @param input Input data that contains the log-probabilities for each class
+   *        (one column per sample).
+   * @param target The target vector, that contains the class index in the range
+   *        between 1 and the number of classes.
+   * @return The negative reward (averaged over the batch if sizeAverage); 0 for
+   *         an empty batch.
+   */
+  template<typename eT>
+  double Forward(const arma::Mat<eT>& input, const arma::Mat<eT>& target)
+  {
+    reward = 0;
+    if (input.n_cols == 0)
+    {
+      // Nothing to reward; also avoids 0 / 0 when averaging.
+      return 0;
+    }
+
+    arma::uword index = 0;
+    for (size_t i = 0; i < input.n_cols; i++)
+    {
+      // The predicted class is the row with the highest log-probability;
+      // target holds 1-based class indices, hence the + 1.
+      input.unsafe_col(i).max(index);
+
+      // Accumulate (rather than overwrite) so that every sample of the batch
+      // contributes, consistent with averaging over input.n_cols below.
+      reward += ((index + 1) == target(i)) * scale;
+    }
+
+    if (sizeAverage)
+    {
+      return -reward / input.n_cols;
+    }
+
+    return -reward;
+  }
+
+  /**
+   * Ordinary feed backward pass of a neural network, calculating the function
+   * f(x) by propagating x backwards through f. Using the results from the feed
+   * forward pass (the reward stored by the last Forward() call).
+   *
+   * @param input The propagated input activation; element (0, 0) holds the
+   *        class log-probabilities, element (1, 0) a 1x1 reward baseline.
+   * @param gy The backpropagated error (unused).
+   * @param g The calculated gradient; a 2x1 field whose first element is a zero
+   *        matrix shaped like input(0, 0) and whose second element is the
+   *        gradient with respect to the baseline.
+   * @return The variance reduced reward (reward minus baseline, averaged over
+   *         the batch if sizeAverage).
+   */
+  template<typename eT>
+  double Backward(const arma::field<arma::Mat<eT> >& input,
+                  const arma::Mat<eT>& /* gy */,
+                  arma::field<arma::Mat<eT> >& g)
+  {
+    const arma::Mat<eT>& classInput = input(0, 0);
+
+    // No gradient is propagated back through the class log-probabilities.
+    g = arma::field<arma::Mat<eT> >(2, 1);
+    g(0, 0) = arma::zeros<arma::Mat<eT> >(classInput.n_rows, classInput.n_cols);
+
+    // Normalize by the batch size (the number of columns of the class input);
+    // input.n_cols would be the column count of the field itself. Guard
+    // against an empty batch to avoid a division by zero.
+    const double batchSize = (sizeAverage && classInput.n_cols > 0) ?
+        static_cast<double>(classInput.n_cols) : 1.0;
+
+    // NOTE(review): reward is a single scalar summed over the batch and the
+    // baseline is a single 1x1 value, so for batches with more than one sample
+    // this is not a per-sample baseline -- confirm this is intended.
+    const double vrReward = (reward - arma::as_scalar(input(1, 0))) / batchSize;
+
+    // 2 * (baseline - reward) is the derivative of (baseline - reward)^2 with
+    // respect to the baseline, normalized like the reward.
+    g(1, 0) = (2.0 / batchSize) * (input(1, 0) - reward);
+
+    return vrReward;
+  }
+
+  //! Get the input parameter.
+  const InputDataType& InputParameter() const { return inputParameter; }
+  //! Modify the input parameter.
+  InputDataType& InputParameter() { return inputParameter; }
+
+  //! Get the output parameter.
+  const OutputDataType& OutputParameter() const { return outputParameter; }
+  //! Modify the output parameter.
+  OutputDataType& OutputParameter() { return outputParameter; }
+
+  //! Get the delta.
+  const OutputDataType& Delta() const { return delta; }
+  //! Modify the delta.
+  OutputDataType& Delta() { return delta; }
+
+  //! Get the value of the deterministic parameter.
+  bool Deterministic() const { return deterministic; }
+  //! Modify the value of the deterministic parameter.
+  bool& Deterministic() { return deterministic; }
+
+ private:
+  //! Locally-stored value to scale the reward.
+  const double scale;
+
+  //! If true take the average over all batches.
+  const bool sizeAverage;
+
+  //! Locally stored reward (accumulated over the batch by Forward()).
+  double reward;
+
+  //! Locally-stored delta object.
+  OutputDataType delta;
+
+  //! Locally-stored input parameter object.
+  InputDataType inputParameter;
+
+  //! Locally-stored output parameter object.
+  OutputDataType outputParameter;
+
+  //! Deterministic flag; stored for the layer interface but not read by
+  //! Forward() or Backward() of this class.
+  bool deterministic;
+}; // class VRClassRewardLayer
+}; // namespace ann
+}; // namespace mlpack

More information about the mlpack-git mailing list