[mlpack-git] master: The gradient to update the weight depends on the number of output maps. (3ade929)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Jun 9 14:14:06 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1d5349c98666e8d592117e23ad9d6e9d0dae36f8...3ade9299e8f3c1e73ba30bff276b51813ede87b5
>---------------------------------------------------------------
commit 3ade9299e8f3c1e73ba30bff276b51813ede87b5
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Tue Jun 9 17:46:23 2015 +0200
The gradient to update the weight depends on the number of output maps.
>---------------------------------------------------------------
3ade9299e8f3c1e73ba30bff276b51813ede87b5
src/mlpack/methods/ann/cnn.hpp | 83 +++++++++++++++++++++++++++++++++++++-----
1 file changed, 74 insertions(+), 9 deletions(-)
diff --git a/src/mlpack/methods/ann/cnn.hpp b/src/mlpack/methods/ann/cnn.hpp
index c2a6caa..e80ae2f 100644
--- a/src/mlpack/methods/ann/cnn.hpp
+++ b/src/mlpack/methods/ann/cnn.hpp
@@ -109,8 +109,8 @@ class CNN
* @param output Output data used to store the output activation
* @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
*/
- template <typename VecType>
- void Predict(const VecType& input, VecType& output)
+ template <typename InputType, typename OutputType>
+ void Predict(const InputType& input, OutputType& output)
{
ResetActivations(network);
@@ -181,6 +181,7 @@ class CNN
Reset(std::tuple<Tp...>& t)
{
std::get<I>(t).OutputLayer().InputActivation().zeros();
+ std::get<I>(t).Delta().zeros();
Reset<I + 1, Tp...>(t);
}
@@ -204,6 +205,7 @@ class CNN
{
ConnectionForward(std::get<I>(t));
+
// Use the first connection to perform the feed forward algorithm.
std::get<0>(std::get<I>(t)).OutputLayer().FeedForward(
std::get<0>(std::get<I>(t)).OutputLayer().InputActivation(),
@@ -254,8 +256,8 @@ class CNN
/*
* Calculate and store the output activation.
*/
- template<typename VecType, typename... Tp>
- void OutputPrediction(std::tuple<Tp...>& t, VecType& output)
+ template<typename OutputType, typename... Tp>
+ void OutputPrediction(std::tuple<Tp...>& t, OutputType& output)
{
// Calculate and store the output prediction.
outputLayer.OutputClass(std::get<0>(
@@ -402,6 +404,48 @@ class CNN
}
/**
+ * Helper function to update the weights using the gradients from the
+ * gradient store.
+ *
+ * Overload resolution on the type of the weights selects between the two
+ * overloads - this one for matrix weights (arma::Mat), which uses the first
+ * gradient slice only, and one for cube weights (arma::Cube), which updates
+ * the weights slice by slice.
+ */
+ template<size_t I = 0, typename eT, typename... Tp>
+ void UpdateWeights(arma::Mat<eT>& weights, std::tuple<Tp...>& t)
+ {
+ std::get<I>(t).Optimzer().UpdateWeights(weights,
+ gradients[gradientNum].slice(0), trainError);
+ }
+
+ template<size_t I = 0, typename eT, typename... Tp>
+ void UpdateWeights(arma::Cube<eT>& weights, std::tuple<Tp...>& t)
+ {
+ if (gradientNum == std::get<I>(t).InputLayer().OutputMaps() != 1)
+ {
+ for (size_t i = 0, g = 0;
+ i < std::get<I>(t).OutputLayer().OutputMaps(); i++)
+ {
+ for (size_t j = i; j < weights.n_slices;
+ j+= std::get<I>(t).OutputLayer().OutputMaps(), g++)
+ {
+ std::get<I>(t).Optimzer().UpdateWeights(weights.slice(j),
+ gradients[gradientNum].slice(g), trainError);
+ }
+ }
+ }
+ else
+ {
+ for (size_t i = 0; i < weights.n_slices; i++)
+ {
+ std::get<I>(t).Optimzer().UpdateWeights(weights.slice(i),
+ gradients[gradientNum].slice(i), trainError);
+ }
+ }
+ }
+
+ /**
* Update the weights using the gradients from the gradient store.
*
* enable_if (SFINAE) is used to iterate through the network connections.
@@ -423,8 +467,7 @@ class CNN
if (!ConnectionTraits<typename std::remove_reference<decltype(
std::get<I>(t))>::type>::IsPoolingConnection)
{
- std::get<I>(t).Optimzer().UpdateWeights(std::get<I>(t).Weights(),
- gradients[gradientNum], trainError);
+ UpdateWeights<I>(std::get<I>(t).Weights(), t);
// Reset the gradient storage.
gradients[gradientNum++].zeros();
@@ -473,14 +516,36 @@ class CNN
if (!ConnectionTraits<typename std::remove_reference<decltype(
std::get<I>(t))>::type>::IsPoolingConnection)
{
- gradients.push_back(new DataType(std::get<I>(t).Weights().n_rows,
- std::get<I>(t).Weights().n_cols,
- std::get<I>(t).OutputLayer().LayerSlices(), arma::fill::zeros));
+ gradients.push_back(new DataType(std::get<I>(t).Weights().n_rows,
+ std::get<I>(t).Weights().n_cols,
+ ElementCount(std::get<I>(t).Weights()), arma::fill::zeros));
}
Layer<I + 1, Tp...>(t);
}
+ /*
+ * Get the number of elements (slices) of a matrix, which is always one.
+ *
+ * @param data The reference data (unused).
+ */
+ template<typename eT>
+ size_t ElementCount(const arma::Mat<eT>& /* unused */) const
+ {
+ return 1;
+ }
+
+ /*
+ * Get the number of elements (slices) of the given cube.
+ *
+ * @param data The reference data.
+ */
+ template<typename eT>
+ size_t ElementCount(const arma::Cube<eT>& data) const
+ {
+ return data.n_slices;
+ }
+
//! The connection modules used to build the network.
ConnectionTypes network;
More information about the mlpack-git
mailing list