[mlpack-git] master: Add 3rd order tensors support (convolution). (066d860)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Apr 22 08:04:50 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/e2beb217f3e17729b73f9edc05601195e92f775d...8f85309ae9be40e819b301b39c9a940aa28f3bb2
>---------------------------------------------------------------
commit 066d860d928e9bad42b44ab5a790c625c99f028a
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Wed Apr 22 11:01:29 2015 +0200
Add 3rd order tensors support (convolution).
>---------------------------------------------------------------
066d860d928e9bad42b44ab5a790c625c99f028a
.../ann/convolution_rules/fft_convolution.hpp | 42 +++++++++++++++++++---
.../ann/convolution_rules/naive_convolution.hpp | 27 ++++++++++++++
.../ann/convolution_rules/svd_convolution.hpp | 32 +++++++++++++++--
3 files changed, 94 insertions(+), 7 deletions(-)
diff --git a/src/mlpack/methods/ann/convolution_rules/fft_convolution.hpp b/src/mlpack/methods/ann/convolution_rules/fft_convolution.hpp
index a9b0a76..a7cd366 100644
--- a/src/mlpack/methods/ann/convolution_rules/fft_convolution.hpp
+++ b/src/mlpack/methods/ann/convolution_rules/fft_convolution.hpp
@@ -34,10 +34,10 @@ class FFTConvolution
public:
/*
* Perform a convolution through fft (valid mode). This method only supports
- * input which is even on the last dimension. In case of an odd input with, a
+ * input which is even on the last dimension. In case of an odd input width, a
* user can manually pad the imput or specify the padLastDim parameter which
- * takes care of padding. The filter instead can have any size. When using the
- * valid mode the filters has to be smaller than the input.
+ * takes care of the padding. The filter instead can have any size. When using
+ * the valid mode the filter has to be smaller than the input.
*
* @param input Input used to perform the convolution.
* @param filter Filter used to perform the conolution.
@@ -70,9 +70,9 @@ class FFTConvolution
/*
* Perform a convolution through fft (full mode). This method only supports
- * input which is even on the last dimension. In case of an odd input with, a
+ * input which is even on the last dimension. In case of an odd input width, a
* user can manually pad the imput or specify the padLastDim parameter which
- * takes care of padding. The filter instead can have any size.
+ * takes care of the padding. The filter instead can have any size.
*
* @param input Input used to perform the convolution.
* @param filter Filter used to perform the conolution.
@@ -114,6 +114,38 @@ class FFTConvolution
2 * (filter.n_rows - 1) + input.n_rows - 1,
2 * (filter.n_cols - 1) + input.n_cols - 1);
}
+
+ /*
+ * Perform a convolution through fft using 3rd order tensors. This method only
+ * supports input which is even on the last dimension. In case of an odd input
+ * width, a user can manually pad the input or specify the padLastDim
+ * parameter which takes care of the padding. The filter instead can have any
+ * size.
+ *
+ * @param input Input used to perform the convolution.
+ * @param filter Filter used to perform the convolution.
+ * @param output Output data that contains the results of the convolution.
+ */
+ template<typename eT>
+ static void Convolution(const arma::Cube<eT>& input,
+ const arma::Cube<eT>& filter,
+ arma::Cube<eT>& output)
+ {
+ arma::Mat<eT> convOutput;
+ FFTConvolution<BorderMode>::Convolution(input.slice(0), filter.slice(0),
+ convOutput);
+
+ output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
+ input.n_slices);
+ output.slice(0) = convOutput;
+
+ for (size_t i = 1; i < input.n_slices; i++)
+ {
+ FFTConvolution<BorderMode>::Convolution(input.slice(i), filter.slice(i),
+ convOutput);
+ output.slice(i) = convOutput;
+ }
+ }
}; // class FFTConvolution
}; // namespace ann
diff --git a/src/mlpack/methods/ann/convolution_rules/naive_convolution.hpp b/src/mlpack/methods/ann/convolution_rules/naive_convolution.hpp
index 2275ef0..5934131 100644
--- a/src/mlpack/methods/ann/convolution_rules/naive_convolution.hpp
+++ b/src/mlpack/methods/ann/convolution_rules/naive_convolution.hpp
@@ -93,6 +93,33 @@ class NaiveConvolution
NaiveConvolution<ValidConvolution>::Convolution(inputPadded, filter,
output);
}
+
+ /*
+ * Perform a convolution using 3rd order tensors.
+ *
+ * @param input Input used to perform the convolution.
+ * @param filter Filter used to perform the convolution.
+ * @param output Output data that contains the results of the convolution.
+ */
+ template<typename eT>
+ static void Convolution(const arma::Cube<eT>& input,
+ const arma::Cube<eT>& filter,
+ arma::Cube<eT>& output)
+ {
+ arma::Mat<eT> convOutput;
+ NaiveConvolution<BorderMode>::Convolution(input.slice(0), filter.slice(0),
+ convOutput);
+
+ output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
+ input.n_slices);
+ output.slice(0) = convOutput;
+
+ for (size_t i = 1; i < input.n_slices; i++)
+ {
+ NaiveConvolution<BorderMode>::Convolution(input.slice(i), filter.slice(i),
+ output.slice(i));
+ }
+ }
}; // class NaiveConvolution
}; // namespace ann
diff --git a/src/mlpack/methods/ann/convolution_rules/svd_convolution.hpp b/src/mlpack/methods/ann/convolution_rules/svd_convolution.hpp
index d2352de..3a49684 100644
--- a/src/mlpack/methods/ann/convolution_rules/svd_convolution.hpp
+++ b/src/mlpack/methods/ann/convolution_rules/svd_convolution.hpp
@@ -38,8 +38,8 @@ class SVDConvolution
* decomposition. By using singular value decomposition of the filter matrix
* the convolution can be expressed as a sum of outer products. Each product
* can be computed efficiently as convolution with a row and a column vector.
- * The individual convolutions are computed with the naive implementation wich
- * is fast if the filter is low-dimensional.
+ * The individual convolutions are computed with the naive implementation
+ * which is fast if the filter is low-dimensional.
*
* @param input Input used to perform the convolution.
* @param filter Filter used to perform the conolution.
@@ -101,6 +101,34 @@ class SVDConvolution
}
}
}
+
+ /*
+ * Perform a convolution using 3rd order tensors.
+ *
+ * @param input Input used to perform the convolution.
+ * @param filter Filter used to perform the convolution.
+ * @param output Output data that contains the results of the convolution.
+ */
+ template<typename eT>
+ static void Convolution(const arma::Cube<eT>& input,
+ const arma::Cube<eT>& filter,
+ arma::Cube<eT>& output)
+ {
+ arma::Mat<eT> convOutput;
+ SVDConvolution<BorderMode>::Convolution(input.slice(0), filter.slice(0),
+ convOutput);
+
+ output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
+ input.n_slices);
+ output.slice(0) = convOutput;
+
+ for (size_t i = 1; i < input.n_slices; i++)
+ {
+ SVDConvolution<BorderMode>::Convolution(input.slice(i), filter.slice(i),
+ convOutput);
+ output.slice(i) = convOutput;
+ }
+ }
}; // class SVDConvolution
}; // namespace ann
More information about the mlpack-git
mailing list