[mlpack-git] master: Add function to get the input size of a given network. (d0708a5)
gitdub at mlpack.org
gitdub at mlpack.org
Sat Apr 9 07:31:20 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ba826b1959a3f83532e91765b2bba0705e588d39...f4b3464fce6bdc7c61d94f6b22bc71fe61276328
>---------------------------------------------------------------
commit d0708a5f6d437edbbb574cdac599b58eccc259ba
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Fri Apr 8 15:43:32 2016 +0200
Add function to get the input size of a given network.
>---------------------------------------------------------------
d0708a5f6d437edbbb574cdac599b58eccc259ba
src/mlpack/methods/ann/network_util.hpp | 32 +++++++++++++++++++++++
src/mlpack/methods/ann/network_util_impl.hpp | 38 ++++++++++++++++++++++++++++
2 files changed, 70 insertions(+)
diff --git a/src/mlpack/methods/ann/network_util.hpp b/src/mlpack/methods/ann/network_util.hpp
index c0b2227..193d5a8 100644
--- a/src/mlpack/methods/ann/network_util.hpp
+++ b/src/mlpack/methods/ann/network_util.hpp
@@ -140,6 +140,38 @@ template<typename T, typename P>
typename std::enable_if<
!HasGradientCheck<T, P&(T::*)()>::value, size_t>::type
LayerGradients(T& layer, arma::mat& gradients, size_t offset, P& output);
+
+/**
+ * Auxiliary function to get the input size of the specified network. The
+ * layers of the tuple are visited in order, starting at index I, and the
+ * first non-zero value returned by LayerInputSize() for a layer is reported
+ * as the input size of the whole network.
+ *
+ * @param network The network whose input size should be determined.
+ * @return The input size, or 0 if no layer reports one.
+ */
+template<size_t I = 0, typename... Tp>
+typename std::enable_if<I < sizeof...(Tp), size_t>::type
+NetworkInputSize(std::tuple<Tp...>& network);
+
+/**
+ * Terminating overload of NetworkInputSize(), selected once the index I has
+ * moved past the last layer of the tuple; it always returns 0.
+ *
+ * @param network The network (not inspected).
+ * @return 0.
+ */
+template<size_t I, typename... Tp>
+typename std::enable_if<I == sizeof...(Tp), size_t>::type
+NetworkInputSize(std::tuple<Tp...>& network);
+
+/**
+ * Auxiliary function to get the input size of the specified layer. The
+ * overload is chosen by HasWeightsCheck<T, P&(T::*)()>, where P is deduced
+ * from the output parameter. When the check holds, the result is
+ * layer.Weights().n_cols. Otherwise the layer cannot report an input size
+ * and 0 is returned.
+ *
+ * @param layer The layer whose input size should be determined.
+ * @param output The layer output parameter (only used to deduce P for the
+ *     HasWeightsCheck dispatch; its value is not read).
+ * @return The input size, or 0 if the layer has no weights.
+ */
+template<typename T, typename P>
+typename std::enable_if<
+ !HasWeightsCheck<T, P&(T::*)()>::value, size_t>::type
+LayerInputSize(T& layer, P& output);
+
+template<typename T, typename P>
+typename std::enable_if<
+ HasWeightsCheck<T, P&(T::*)()>::value, size_t>::type
+LayerInputSize(T& layer, P& output);
+
} // namespace ann
} // namespace mlpack
diff --git a/src/mlpack/methods/ann/network_util_impl.hpp b/src/mlpack/methods/ann/network_util_impl.hpp
index fbb4041..34cc84d 100644
--- a/src/mlpack/methods/ann/network_util_impl.hpp
+++ b/src/mlpack/methods/ann/network_util_impl.hpp
@@ -166,6 +166,44 @@ LayerGradients(T& /* unused */,
return 0;
}
+/**
+ * Terminating case of the NetworkInputSize() recursion: the index has moved
+ * past the last layer of the tuple without any layer reporting an input size,
+ * so 0 is returned.
+ */
+template<size_t I, typename... Tp>
+typename std::enable_if<I == sizeof...(Tp), size_t>::type
+NetworkInputSize(std::tuple<Tp...>& /* unused */)
+{
+  return 0;
+}
+
+/**
+ * Ask layer I for its input size and, if it does not report one (a result
+ * of 0), continue the search with layer I + 1.
+ */
+template<size_t I, typename... Tp>
+typename std::enable_if<I < sizeof...(Tp), size_t>::type
+NetworkInputSize(std::tuple<Tp...>& network)
+{
+  auto& layer = std::get<I>(network);
+  const size_t layerSize = LayerInputSize(layer, layer.OutputParameter());
+
+  // Zero means "this layer has no weights"; keep looking further down.
+  return (layerSize != 0) ? layerSize
+                          : NetworkInputSize<I + 1, Tp...>(network);
+}
+
+/**
+ * Overload for layers that provide Weights(): the input size is the n_cols
+ * of the object returned by Weights().
+ */
+template<typename T, typename P>
+typename std::enable_if<
+ HasWeightsCheck<T, P&(T::*)()>::value, size_t>::type
+LayerInputSize(T& layer, P& /* unused */)
+{
+  const auto& weights = layer.Weights();
+  return weights.n_cols;
+}
+
+/**
+ * Overload for layers without Weights(): no input size can be derived from
+ * the layer, so 0 is returned.
+ */
+template<typename T, typename P>
+typename std::enable_if<
+ !HasWeightsCheck<T, P&(T::*)()>::value, size_t>::type
+LayerInputSize(T& /* unused */, P& /* unused */)
+{
+  return 0;
+}
+
} // namespace ann
} // namespace mlpack
More information about the mlpack-git mailing list