[mlpack-git] master: The gradient to update the weight depends on the number of output maps. (3ade929)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Jun 9 14:14:06 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1d5349c98666e8d592117e23ad9d6e9d0dae36f8...3ade9299e8f3c1e73ba30bff276b51813ede87b5

>---------------------------------------------------------------

commit 3ade9299e8f3c1e73ba30bff276b51813ede87b5
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Tue Jun 9 17:46:23 2015 +0200

    The gradient to update the weight depends on the number of output maps.


>---------------------------------------------------------------

3ade9299e8f3c1e73ba30bff276b51813ede87b5
 src/mlpack/methods/ann/cnn.hpp | 83 +++++++++++++++++++++++++++++++++++++-----
 1 file changed, 74 insertions(+), 9 deletions(-)

diff --git a/src/mlpack/methods/ann/cnn.hpp b/src/mlpack/methods/ann/cnn.hpp
index c2a6caa..e80ae2f 100644
--- a/src/mlpack/methods/ann/cnn.hpp
+++ b/src/mlpack/methods/ann/cnn.hpp
@@ -109,8 +109,8 @@ class CNN
      * @param output Output data used to store the output activation
      * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
      */
-    template <typename VecType>
-    void Predict(const VecType& input, VecType& output)
+    template <typename InputType, typename OutputType>
+    void Predict(const InputType& input, OutputType& output)
     {
       ResetActivations(network);
 
@@ -181,6 +181,7 @@ class CNN
     Reset(std::tuple<Tp...>& t)
     {
       std::get<I>(t).OutputLayer().InputActivation().zeros();
+      std::get<I>(t).Delta().zeros();
       Reset<I + 1, Tp...>(t);
     }
 
@@ -204,6 +205,7 @@ class CNN
     {
       ConnectionForward(std::get<I>(t));
 
+
       // Use the first connection to perform the feed forward algorithm.
       std::get<0>(std::get<I>(t)).OutputLayer().FeedForward(
           std::get<0>(std::get<I>(t)).OutputLayer().InputActivation(),
@@ -254,8 +256,8 @@ class CNN
     /*
      * Calculate and store the output activation.
      */
-    template<typename VecType, typename... Tp>
-    void OutputPrediction(std::tuple<Tp...>& t, VecType& output)
+    template<typename OutputType, typename... Tp>
+    void OutputPrediction(std::tuple<Tp...>& t, OutputType& output)
     {
        // Calculate and store the output prediction.
       outputLayer.OutputClass(std::get<0>(
@@ -402,6 +404,48 @@ class CNN
     }
 
     /**
+     * Helper function to update the weights of connection I using the
+     * current entry of the gradient store (gradients[gradientNum]). It is
+     * overloaded on the weight type; this overload handles matrix weights,
+     * which are updated with the first slice of the stored gradient.
+     *
+     * @param weights The weights of connection I.
+     * @param t The tuple of connections; connection I supplies the optimizer.
+     */
+    template<size_t I = 0, typename eT, typename... Tp>
+    void UpdateWeights(arma::Mat<eT>& weights, std::tuple<Tp...>& t)
+    {
+      std::get<I>(t).Optimzer().UpdateWeights(weights,
+          gradients[gradientNum].slice(0), trainError);
+    }
+
+    template<size_t I = 0, typename eT, typename... Tp>
+    void UpdateWeights(arma::Cube<eT>& weights, std::tuple<Tp...>& t)
+    {
+      const size_t outputMaps = std::get<I>(t).OutputLayer().OutputMaps();
+      // Several input maps: walk the weight slices with stride outputMaps.
+      if (std::get<I>(t).InputLayer().OutputMaps() != 1)
+      {
+        for (size_t i = 0, g = 0; i < outputMaps; i++)
+        {
+          for (size_t j = i; j < weights.n_slices; j += outputMaps, g++)
+          {
+            std::get<I>(t).Optimzer().UpdateWeights(weights.slice(j),
+                gradients[gradientNum].slice(g), trainError);
+          }
+        }
+      }
+      else
+      {
+        for (size_t i = 0; i < weights.n_slices; i++)
+        {
+          std::get<I>(t).Optimzer().UpdateWeights(weights.slice(i),
+              gradients[gradientNum].slice(i), trainError);
+        }
+      }
+    }
+
+    /**
      * Update the weights using the gradients from the gradient store.
      *
      * enable_if (SFINAE) is used to iterate through the network connections.
@@ -423,8 +467,7 @@ class CNN
       if (!ConnectionTraits<typename std::remove_reference<decltype(
           std::get<I>(t))>::type>::IsPoolingConnection)
       {
-        std::get<I>(t).Optimzer().UpdateWeights(std::get<I>(t).Weights(),
-            gradients[gradientNum], trainError);
+        UpdateWeights<I>(std::get<I>(t).Weights(), t);
 
         // Reset the gradient storage.
         gradients[gradientNum++].zeros();
@@ -473,14 +516,36 @@ class CNN
       if (!ConnectionTraits<typename std::remove_reference<decltype(
           std::get<I>(t))>::type>::IsPoolingConnection)
       {
-        gradients.push_back(new DataType(std::get<I>(t).Weights().n_rows,
-            std::get<I>(t).Weights().n_cols,
-            std::get<I>(t).OutputLayer().LayerSlices(), arma::fill::zeros));
+          gradients.push_back(new DataType(std::get<I>(t).Weights().n_rows,
+              std::get<I>(t).Weights().n_cols,
+              ElementCount(std::get<I>(t).Weights()), arma::fill::zeros));
       }
 
       Layer<I + 1, Tp...>(t);
     }
 
+    /*
+     * Get the number of slices of the given weights: always 1 for a matrix.
+     *
+     * The (unused) argument only serves to select this overload.
+     */
+    template<typename eT>
+    size_t ElementCount(const arma::Mat<eT>& /* unused */) const
+    {
+      return 1;
+    }
+
+    /*
+     * Get the number of slices of the given weights (not the element count).
+     *
+     * @param data The cube whose slices are counted.
+     */
+    template<typename eT>
+    size_t ElementCount(const arma::Cube<eT>& data) const
+    {
+      return data.n_slices;
+    }
+
     //! The connection modules used to build the network.
     ConnectionTypes network;
 



More information about the mlpack-git mailing list