[mlpack-git] master: Add batch support. In batch mode, the convolution runs on a batch of images. (296c951)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Jun 4 04:47:16 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/2f479f388ee3d34e4a20535c3662b1921a4c6c06...7fb32130bd683cf03a853ea2bc6960e80d625955
>---------------------------------------------------------------
commit 296c951e2337367108ebdbcaaef56d56f9c14de1
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Wed Jun 3 22:10:14 2015 +0200
Add batch support. In batch mode, the convolution runs on a batch of images.
>---------------------------------------------------------------
296c951e2337367108ebdbcaaef56d56f9c14de1
.../ann/convolution_rules/fft_convolution.hpp | 30 ++++++++++++++++++++
.../ann/convolution_rules/naive_convolution.hpp | 29 ++++++++++++++++++++
.../ann/convolution_rules/svd_convolution.hpp | 32 ++++++++++++++++++++--
3 files changed, 89 insertions(+), 2 deletions(-)
diff --git a/src/mlpack/methods/ann/convolution_rules/fft_convolution.hpp b/src/mlpack/methods/ann/convolution_rules/fft_convolution.hpp
index d247e33..3003bbe 100644
--- a/src/mlpack/methods/ann/convolution_rules/fft_convolution.hpp
+++ b/src/mlpack/methods/ann/convolution_rules/fft_convolution.hpp
@@ -178,6 +178,36 @@ class FFTConvolution
output.slice(i) = convOutput;
}
}
+
+  /**
+   * Perform a convolution on a batch of images: every slice of the 3rd order
+   * input tensor is convolved with the same dense filter matrix and stored
+   * in the corresponding slice of the output. An empty batch (no slices)
+   * yields an empty output cube.
+   *
+   * @param input Input used to perform the convolution (one image per slice).
+   * @param filter Filter used to perform the convolution.
+   * @param output Output data that contains the results of the convolution.
+   */
+  template<typename eT>
+  static void Convolution(const arma::Cube<eT>& input,
+                          const arma::Mat<eT>& filter,
+                          arma::Cube<eT>& output)
+  {
+    if (input.n_slices == 0)
+      output.reset();  // An empty batch yields an empty output.
+    arma::Mat<eT> convOutput;
+    for (size_t i = 0; i < input.n_slices; i++)
+    {
+      FFTConvolution<BorderMode>::Convolution(input.slice(i), filter,
+          convOutput);
+      // The output size is only known once the first slice is convolved.
+      if (i == 0)
+        output.set_size(convOutput.n_rows, convOutput.n_cols, input.n_slices);
+      output.slice(i) = convOutput;
+    }
+  }
+
}; // class FFTConvolution
}; // namespace ann
diff --git a/src/mlpack/methods/ann/convolution_rules/naive_convolution.hpp b/src/mlpack/methods/ann/convolution_rules/naive_convolution.hpp
index 45c75fd..cfb47a9 100644
--- a/src/mlpack/methods/ann/convolution_rules/naive_convolution.hpp
+++ b/src/mlpack/methods/ann/convolution_rules/naive_convolution.hpp
@@ -148,6 +148,35 @@ class NaiveConvolution
output.slice(i));
}
}
+
+  /**
+   * Perform a convolution on a batch of images: every slice of the 3rd order
+   * input tensor is convolved with the same dense filter matrix and stored
+   * in the corresponding slice of the output. An empty batch (no slices)
+   * yields an empty output cube.
+   *
+   * @param input Input used to perform the convolution (one image per slice).
+   * @param filter Filter used to perform the convolution.
+   * @param output Output data that contains the results of the convolution.
+   */
+  template<typename eT>
+  static void Convolution(const arma::Cube<eT>& input,
+                          const arma::Mat<eT>& filter,
+                          arma::Cube<eT>& output)
+  {
+    if (input.n_slices == 0)
+      output.reset();  // An empty batch yields an empty output.
+    arma::Mat<eT> convOutput;
+    for (size_t i = 0; i < input.n_slices; i++)
+    {
+      NaiveConvolution<BorderMode>::Convolution(input.slice(i), filter,
+          convOutput);
+      if (i == 0)  // Output size is known once the first slice is convolved.
+        output.set_size(convOutput.n_rows, convOutput.n_cols, input.n_slices);
+      output.slice(i) = convOutput;
+    }
+  }
+
}; // class NaiveConvolution
}; // namespace ann
diff --git a/src/mlpack/methods/ann/convolution_rules/svd_convolution.hpp b/src/mlpack/methods/ann/convolution_rules/svd_convolution.hpp
index 95d7997..cedd2ae 100644
--- a/src/mlpack/methods/ann/convolution_rules/svd_convolution.hpp
+++ b/src/mlpack/methods/ann/convolution_rules/svd_convolution.hpp
@@ -144,8 +144,7 @@ class SVDConvolution
arma::Cube<eT>& output)
{
arma::Mat<eT> convOutput;
- SVDConvolution<BorderMode>::Convolution(input, filter.slice(0),
- convOutput);
+ SVDConvolution<BorderMode>::Convolution(input, filter.slice(0), convOutput);
output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
filter.n_slices);
@@ -158,6 +157,35 @@ class SVDConvolution
output.slice(i) = convOutput;
}
}
+
+  /**
+   * Perform a convolution on a batch of images: every slice of the 3rd order
+   * input tensor is convolved with the same dense filter matrix and stored
+   * in the corresponding slice of the output. An empty batch (no slices)
+   * yields an empty output cube.
+   *
+   * @param input Input used to perform the convolution (one image per slice).
+   * @param filter Filter used to perform the convolution.
+   * @param output Output data that contains the results of the convolution.
+   */
+  template<typename eT>
+  static void Convolution(const arma::Cube<eT>& input,
+                          const arma::Mat<eT>& filter,
+                          arma::Cube<eT>& output)
+  {
+    if (input.n_slices == 0)
+      output.reset();  // An empty batch yields an empty output.
+    arma::Mat<eT> convOutput;
+    for (size_t i = 0; i < input.n_slices; i++)
+    {
+      SVDConvolution<BorderMode>::Convolution(input.slice(i), filter,
+          convOutput);
+      if (i == 0)  // Output size is known once the first slice is convolved.
+        output.set_size(convOutput.n_rows, convOutput.n_cols, input.n_slices);
+      output.slice(i) = convOutput;
+    }
+  }
+
}; // class SVDConvolution
}; // namespace ann
More information about the mlpack-git
mailing list