[mlpack-git] master: Add function to get the input size of a given network. (d0708a5)

gitdub at mlpack.org gitdub at mlpack.org
Sat Apr 9 07:31:20 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ba826b1959a3f83532e91765b2bba0705e588d39...f4b3464fce6bdc7c61d94f6b22bc71fe61276328

>---------------------------------------------------------------

commit d0708a5f6d437edbbb574cdac599b58eccc259ba
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Fri Apr 8 15:43:32 2016 +0200

    Add function to get the input size of a given network.


>---------------------------------------------------------------

d0708a5f6d437edbbb574cdac599b58eccc259ba
 src/mlpack/methods/ann/network_util.hpp      | 32 +++++++++++++++++++++++
 src/mlpack/methods/ann/network_util_impl.hpp | 38 ++++++++++++++++++++++++++++
 2 files changed, 70 insertions(+)

diff --git a/src/mlpack/methods/ann/network_util.hpp b/src/mlpack/methods/ann/network_util.hpp
index c0b2227..193d5a8 100644
--- a/src/mlpack/methods/ann/network_util.hpp
+++ b/src/mlpack/methods/ann/network_util.hpp
@@ -140,6 +140,38 @@ template<typename T, typename P>
 typename std::enable_if<
     !HasGradientCheck<T, P&(T::*)()>::value, size_t>::type
 LayerGradients(T& layer, arma::mat& gradients, size_t offset, P& output);
+
+/**
+ * Get the input size of the specified network: the first non-zero input
+ * size reported by LayerInputSize, scanning layers in tuple order from I.
+ * @param network The network (a tuple of layers) to inspect.
+ * @return The input size, or 0 if no layer from index I on reports one.
+ */
+template<size_t I = 0, typename... Tp>
+typename std::enable_if<I < sizeof...(Tp), size_t>::type
+NetworkInputSize(std::tuple<Tp...>& network);
+// Terminating overload; the default I = 0 lets an empty network resolve.
+template<size_t I = 0, typename... Tp>
+typename std::enable_if<I == sizeof...(Tp), size_t>::type
+NetworkInputSize(std::tuple<Tp...>& network);
+
+/**
+ * Auxiliary function to get the input size of the specified layer: the
+ * number of columns of layer.Weights() if HasWeightsCheck holds, else 0.
+ * @param layer The layer used for specifying the input size.
+ * @param output The layer output parameter; only its type P is used.
+ * @return The input size, or 0 when HasWeightsCheck does not hold for T.
+ */
+template<typename T, typename P>
+typename std::enable_if<
+    !HasWeightsCheck<T, P&(T::*)()>::value, size_t>::type
+LayerInputSize(T& layer, P& output);
+// Overload selected when HasWeightsCheck<T, P&(T::*)()> holds.
+template<typename T, typename P>
+typename std::enable_if<
+    HasWeightsCheck<T, P&(T::*)()>::value, size_t>::type
+LayerInputSize(T& layer, P& output);
+
 } // namespace ann
 } // namespace mlpack
 
diff --git a/src/mlpack/methods/ann/network_util_impl.hpp b/src/mlpack/methods/ann/network_util_impl.hpp
index fbb4041..34cc84d 100644
--- a/src/mlpack/methods/ann/network_util_impl.hpp
+++ b/src/mlpack/methods/ann/network_util_impl.hpp
@@ -166,6 +166,44 @@ LayerGradients(T& /* unused */,
   return 0;
 }
 
+template<size_t I, typename... Tp> // End of recursion: I == number of layers.
+typename std::enable_if<I == sizeof...(Tp), size_t>::type
+NetworkInputSize(std::tuple<Tp...>& /* unused */)
+{
+  return 0; // No layer reported a non-zero input size.
+}
+// Return layer I's input size if non-zero, else recurse to layer I + 1.
+template<size_t I, typename... Tp>
+typename std::enable_if<I < sizeof...(Tp), size_t>::type
+NetworkInputSize(std::tuple<Tp...>& network)
+{
+  const size_t inputSize = LayerInputSize(std::get<I>(network), std::get<I>(
+      network).OutputParameter());
+
+  if (inputSize)
+  {
+    return inputSize;
+  }
+
+  return NetworkInputSize<I + 1, Tp...>(network);
+}
+// HasWeightsCheck holds: report the column count of layer.Weights().
+template<typename T, typename P>
+typename std::enable_if<
+  HasWeightsCheck<T, P&(T::*)()>::value, size_t>::type
+LayerInputSize(T& layer, P& /* unused */)
+{
+  return layer.Weights().n_cols;
+}
+// HasWeightsCheck does not hold: no input size available, so return 0.
+template<typename T, typename P>
+typename std::enable_if<
+  !HasWeightsCheck<T, P&(T::*)()>::value, size_t>::type
+LayerInputSize(T& /* unused */, P& /* unused */)
+{
+  return 0;
+}
+
 } // namespace ann
 } // namespace mlpack
 




More information about the mlpack-git mailing list