[mlpack-git] master: Add batch support. In batch mode, the convolution runs on a batch of images. (296c951)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Jun 4 04:47:16 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/2f479f388ee3d34e4a20535c3662b1921a4c6c06...7fb32130bd683cf03a853ea2bc6960e80d625955

>---------------------------------------------------------------

commit 296c951e2337367108ebdbcaaef56d56f9c14de1
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Wed Jun 3 22:10:14 2015 +0200

    Add batch support. In batch mode, the convolution runs on a batch of images.


>---------------------------------------------------------------

296c951e2337367108ebdbcaaef56d56f9c14de1
 .../ann/convolution_rules/fft_convolution.hpp      | 30 ++++++++++++++++++++
 .../ann/convolution_rules/naive_convolution.hpp    | 29 ++++++++++++++++++++
 .../ann/convolution_rules/svd_convolution.hpp      | 32 ++++++++++++++++++++--
 3 files changed, 89 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/methods/ann/convolution_rules/fft_convolution.hpp b/src/mlpack/methods/ann/convolution_rules/fft_convolution.hpp
index d247e33..3003bbe 100644
--- a/src/mlpack/methods/ann/convolution_rules/fft_convolution.hpp
+++ b/src/mlpack/methods/ann/convolution_rules/fft_convolution.hpp
@@ -178,6 +178,36 @@ class FFTConvolution
       output.slice(i) = convOutput;
     }
   }
+
+  /**
+   * Perform a batch convolution: each slice of the 3rd order input tensor is
+   * convolved with the same dense filter matrix.
+   *
+   * @param input Input slices used to perform the convolution.
+   * @param filter Filter used to perform the convolution.
+   * @param output Output cube; slice i holds the result for input slice i.
+   */
+  template<typename eT>
+  static void Convolution(const arma::Cube<eT>& input,
+                          const arma::Mat<eT>& filter,
+                          arma::Cube<eT>& output)
+  {
+    arma::Mat<eT> convOutput;
+    for (size_t i = 0; i < input.n_slices; i++)
+    {
+      FFTConvolution<BorderMode>::Convolution(input.slice(i), filter,
+          convOutput);
+
+      // The output size is only known once the first slice is convolved.
+      if (i == 0)
+        output.set_size(convOutput.n_rows, convOutput.n_cols, input.n_slices);
+      output.slice(i) = convOutput;
+    }
+
+    if (input.n_slices == 0)  // Empty batch: return an empty cube.
+      output.reset();
+  }
+
 };  // class FFTConvolution
 
 }; // namespace ann
diff --git a/src/mlpack/methods/ann/convolution_rules/naive_convolution.hpp b/src/mlpack/methods/ann/convolution_rules/naive_convolution.hpp
index 45c75fd..cfb47a9 100644
--- a/src/mlpack/methods/ann/convolution_rules/naive_convolution.hpp
+++ b/src/mlpack/methods/ann/convolution_rules/naive_convolution.hpp
@@ -148,6 +148,35 @@ class NaiveConvolution
           output.slice(i));
     }
   }
+
+  /**
+   * Perform a batch convolution: each slice of the 3rd order input tensor is
+   * convolved with the same dense filter matrix.
+   *
+   * @param input Input slices used to perform the convolution.
+   * @param filter Filter used to perform the convolution.
+   * @param output Output cube; slice i holds the result for input slice i.
+   */
+  template<typename eT>
+  static void Convolution(const arma::Cube<eT>& input,
+                          const arma::Mat<eT>& filter,
+                          arma::Cube<eT>& output)
+  {
+    // An empty batch has no slice 0 and yields an empty (0x0x0) output cube.
+    arma::Mat<eT> convOutput;
+    if (input.n_slices > 0)
+      NaiveConvolution<BorderMode>::Convolution(input.slice(0), filter,
+          convOutput);
+    output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
+        input.n_slices);
+    if (input.n_slices > 0)
+      output.slice(0) = convOutput;
+
+    for (size_t i = 1; i < input.n_slices; i++)
+      NaiveConvolution<BorderMode>::Convolution(input.slice(i), filter,
+          output.slice(i));
+  }
+
 };  // class NaiveConvolution
 
 }; // namespace ann
diff --git a/src/mlpack/methods/ann/convolution_rules/svd_convolution.hpp b/src/mlpack/methods/ann/convolution_rules/svd_convolution.hpp
index 95d7997..cedd2ae 100644
--- a/src/mlpack/methods/ann/convolution_rules/svd_convolution.hpp
+++ b/src/mlpack/methods/ann/convolution_rules/svd_convolution.hpp
@@ -144,8 +144,7 @@ class SVDConvolution
                           arma::Cube<eT>& output)
   {
     arma::Mat<eT> convOutput;
-    SVDConvolution<BorderMode>::Convolution(input, filter.slice(0),
-        convOutput);
+    SVDConvolution<BorderMode>::Convolution(input, filter.slice(0), convOutput);
 
     output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
         filter.n_slices);
@@ -158,6 +157,35 @@ class SVDConvolution
       output.slice(i) = convOutput;
     }
   }
+
+  /**
+   * Perform a batch convolution: each slice of the 3rd order input tensor is
+   * convolved with the same dense filter matrix.
+   *
+   * @param input Input slices used to perform the convolution.
+   * @param filter Filter used to perform the convolution.
+   * @param output Output cube; slice i holds the result for input slice i.
+   */
+  template<typename eT>
+  static void Convolution(const arma::Cube<eT>& input,
+                          const arma::Mat<eT>& filter,
+                          arma::Cube<eT>& output)
+  {
+    arma::Mat<eT> convOutput;
+    for (size_t i = 0; i < input.n_slices; i++)
+    {
+      SVDConvolution<BorderMode>::Convolution(input.slice(i), filter,
+          convOutput);
+      // The output size is only known once the first slice is convolved.
+      if (i == 0)
+        output.set_size(convOutput.n_rows, convOutput.n_cols, input.n_slices);
+      output.slice(i) = convOutput;
+    }
+
+    if (input.n_slices == 0)  // Empty batch: return an empty cube.
+      output.reset();
+  }
+
 };  // class SVDConvolution
 
 }; // namespace ann



More information about the mlpack-git mailing list