[mlpack-git] master: Refactor to support 3rd order tensors. (e958e66)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sat Jun 6 11:16:38 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/7fb32130bd683cf03a853ea2bc6960e80d625955...b5fbcaa319689553f44f2d33e5303c2a28e031e1
>---------------------------------------------------------------
commit e958e664a2753fc62554797d9f04787185929339
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Thu Jun 4 21:36:08 2015 +0200
Refactor to support 3rd order tensors.
>---------------------------------------------------------------
e958e664a2753fc62554797d9f04787185929339
.../methods/ann/connections/full_connection.hpp | 94 ++++++++++++++++------
1 file changed, 69 insertions(+), 25 deletions(-)
diff --git a/src/mlpack/methods/ann/connections/full_connection.hpp b/src/mlpack/methods/ann/connections/full_connection.hpp
index c45bed4..0937b71 100644
--- a/src/mlpack/methods/ann/connections/full_connection.hpp
+++ b/src/mlpack/methods/ann/connections/full_connection.hpp
@@ -4,8 +4,8 @@
*
* Implementation of the full connection class.
*/
-#ifndef __MLPACK_METHOS_ANN_CONNECTIONS_FULL_CONNECTION_HPP
-#define __MLPACK_METHOS_ANN_CONNECTIONS_FULL_CONNECTION_HPP
+#ifndef __MLPACK_METHODS_ANN_CONNECTIONS_FULL_CONNECTION_HPP
+#define __MLPACK_METHODS_ANN_CONNECTIONS_FULL_CONNECTION_HPP
#include <mlpack/core.hpp>
#include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
@@ -24,15 +24,13 @@ namespace ann /** Artificial Neural Network. */ {
* @tparam OptimizerType Type of the optimizer used to update the weights.
* @tparam WeightInitRule Rule used to initialize the weight matrix.
* @tparam MatType Type of data (arma::mat or arma::sp_mat).
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
*/
template<
typename InputLayerType,
typename OutputLayerType,
typename OptimizerType = SteepestDescent<>,
class WeightInitRule = NguyenWidrowInitialization,
- typename MatType = arma::mat,
- typename VecType = arma::colvec
+ typename MatType = arma::mat
>
class FullConnection
{
@@ -59,7 +57,9 @@ class FullConnection
ownsOptimizer(false)
{
weightInitRule.Initialize(weights, outputLayer.InputSize(),
- inputLayer.OutputSize() * inputLayer.LayerSlices());
+ inputLayer.LayerRows() * inputLayer.LayerCols() *
+ inputLayer.LayerSlices() * inputLayer.OutputMaps() /
+ outputLayer.LayerCols());
}
/**
@@ -82,7 +82,9 @@ class FullConnection
ownsOptimizer(true)
{
weightInitRule.Initialize(weights, outputLayer.InputSize(),
- inputLayer.OutputSize() * inputLayer.LayerSlices());
+ inputLayer.LayerRows() * inputLayer.LayerCols() *
+ inputLayer.LayerSlices() * inputLayer.OutputMaps() /
+ outputLayer.LayerCols());
}
/**
@@ -117,10 +119,19 @@ class FullConnection
template<typename eT>
void FeedForward(const arma::Cube<eT>& input)
{
- // Vectorise the input (cube of n slices with a 1x1 dense matrix) and
- // perform the feed forward pass.
- outputLayer.InputActivation() += (weights *
- arma::vec(input.memptr(), input.n_slices));
+ MatType data(input.n_elem / outputLayer.LayerCols(),
+ outputLayer.LayerCols());
+
+ for (size_t s = 0, c = 0; s < input.n_slices / data.n_cols; s++)
+ {
+ for (size_t i = 0; i < data.n_cols; i++, c++)
+ {
+ data.col(i).subvec(s * input.n_rows * input.n_cols, (s + 1) *
+ input.n_rows * input.n_cols - 1) = arma::vectorise(input.slice(c));
+ }
+ }
+
+ outputLayer.InputActivation() += (weights * data);
}
/**
@@ -130,11 +141,9 @@ class FullConnection
*
* @param error The backpropagated error.
*/
- template<typename eT>
- void FeedBackward(const arma::Col<eT>& error)
+ template<typename ErrorType>
+ void FeedBackward(const ErrorType& error)
{
- // Calculating the delta using the partial derivative of the error with
- // respect to a weight.
delta = (weights.t() * error);
}
@@ -159,13 +168,7 @@ class FullConnection
template<typename eT>
void Gradient(arma::Cube<eT>& gradient)
{
- gradient = arma::Cube<eT>(weights.n_rows, weights.n_cols, 1);
-
- // Vectorise the input (cube of n slices with a 1x1 dense matrix) and
- // calculate the gradient.
- gradient.slice(0) = outputLayer.Delta() *
- arma::rowvec(inputLayer.InputActivation().memptr(),
- inputLayer.InputActivation().n_elem);
+ GradientDelta(outputLayer.Delta(), gradient);
}
//! Get the weights.
@@ -189,11 +192,52 @@ class FullConnection
OptimizerType& Optimzer() { return *optimizer; }
//! Get the delta (read-only; a const member function can only hand out a const reference to a member).
- VecType& Delta() const { return delta; }
+ const MatType& Delta() const { return delta; }
//! Modify the delta.
- VecType& Delta() { return delta; }
+ MatType& Delta() { return delta; }
private:
+ /**
+ * Calculate the gradient using the output delta (a matrix) and the input
+ * activation (a 3rd order tensor).
+ *
+ * Slice c = s * outputDelta.n_cols + i of the input activation is vectorised
+ * with arma::vectorise(..., 1) and stored in segment s of row i of a
+ * temporary matrix. That matrix has one row per column of the output delta.
+ * The gradient is the product of the output delta and that matrix, divided
+ * by the number of columns of the output delta.
+ *
+ * @param outputDelta The delta of the output layer; callers pass
+ *     outputLayer.Delta().
+ * @param gradient The calculated gradient, resized to a single slice of
+ *     size weights.n_rows x weights.n_cols.
+ */
+ template<typename eT>
+ void GradientDelta(const arma::Mat<eT>& outputDelta,
+ arma::Cube<eT>& gradient)
+ {
+ gradient = arma::Cube<eT>(weights.n_rows, weights.n_cols, 1);
+ arma::Mat<eT> data = arma::Mat<eT>(outputDelta.n_cols,
+ inputLayer.InputActivation().n_elem / outputDelta.n_cols);
+
+ // Number of elements in a single slice of the input activation; each
+ // vectorised slice occupies exactly this many entries of a row of data.
+ const size_t sliceElems = inputLayer.InputActivation().n_rows *
+ inputLayer.InputActivation().n_cols;
+
+ for (size_t s = 0, c = 0; s < inputLayer.InputActivation().n_slices /
+ data.n_rows; s++)
+ {
+ for (size_t i = 0; i < data.n_rows; i++, c++)
+ {
+ data.row(i).subvec(s * sliceElems, (s + 1) * sliceElems - 1) =
+ arma::vectorise(inputLayer.InputActivation().slice(c), 1);
+ }
+ }
+
+ gradient.slice(0) = outputDelta * data / outputDelta.n_cols;
+ }
+
+ /**
+ * Calculate the gradient using the output delta (3rd order tensor) and the
+ * input activation (3rd order tensor).
+ *
+ * NOTE(review): this overload is currently a stub. It only allocates a
+ * single-slice gradient of size weights.n_rows x weights.n_cols. It reads
+ * neither delta nor the input activation, so the slice contents are not
+ * computed here.
+ *
+ * @param delta The delta of the output layer (unused by this stub).
+ * @param gradient The calculated gradient.
+ */
+ template<typename eT>
+ void GradientDelta(arma::Cube<eT>& delta, arma::Cube<eT>& gradient)
+ {
+ gradient = arma::Cube<eT>(weights.n_rows, weights.n_cols, 1);
+ }
+
//! Locally-stored weight object.
MatType weights;
@@ -210,7 +254,7 @@ class FullConnection
bool ownsOptimizer;
//! Locally-stored delta object that holds the calculated delta.
- VecType delta;
+ MatType delta;
}; // class FullConnection
}; // namespace ann
More information about the mlpack-git
mailing list