[mlpack-git] master: Add auxiliary function to set the network weights using a given initialize rule. (b33964b)
gitdub at mlpack.org
gitdub at mlpack.org
Sat Apr 9 07:31:20 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ba826b1959a3f83532e91765b2bba0705e588d39...f4b3464fce6bdc7c61d94f6b22bc71fe61276328
>---------------------------------------------------------------
commit b33964b9310ad54b1035694c737b00627d31758f
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Fri Apr 8 16:53:55 2016 +0200
Add auxiliary function to set the network weights using a given initialize rule.
>---------------------------------------------------------------
b33964b9310ad54b1035694c737b00627d31758f
src/mlpack/methods/ann/network_util.hpp | 61 ++++++++++++++++++++++++
src/mlpack/methods/ann/network_util_impl.hpp | 71 ++++++++++++++++++++++++++++
2 files changed, 132 insertions(+)
diff --git a/src/mlpack/methods/ann/network_util.hpp b/src/mlpack/methods/ann/network_util.hpp
index 193d5a8..ed268a0 100644
--- a/src/mlpack/methods/ann/network_util.hpp
+++ b/src/mlpack/methods/ann/network_util.hpp
@@ -172,6 +172,67 @@ typename std::enable_if<
HasWeightsCheck<T, P&(T::*)()>::value, size_t>::type
LayerInputSize(T& layer, P& output);
+/**
+ * Auxiliary function to set the weights of the specified network using a given
+ * initialize rule.
+ *
+ * The layers stored in the network tuple are visited in order, starting at
+ * index I, and LayerWeights() is called on each of them. After every layer the
+ * offset is advanced by the number of weights that layer reported, so that
+ * consecutive layers are handed consecutive regions of the weights matrix.
+ *
+ * @tparam I Index of the layer to process; defaults to 0 (the first layer).
+ * @param initializeRule The rule used to initialize the network weights.
+ * @param weights The matrix whose memory holds the weights of all layers.
+ * @param network The tuple of layers whose weights are set.
+ * @param offset The element offset into weights at which layer I starts.
+ */
+template<size_t I = 0, typename InitializationRuleType, typename... Tp>
+typename std::enable_if<I < sizeof...(Tp), void>::type
+NetworkWeights(InitializationRuleType& initializeRule,
+ arma::mat& weights,
+ std::tuple<Tp...>& network,
+ size_t offset = 0);
+
+/**
+ * Terminating overload of NetworkWeights(): selected when I equals the number
+ * of layers in the tuple, i.e. when every layer has already been processed.
+ * It performs no work.
+ */
+template<size_t I, typename InitializationRuleType, typename... Tp>
+typename std::enable_if<I == sizeof...(Tp), void>::type
+NetworkWeights(InitializationRuleType& initializeRule,
+ arma::mat& weights,
+ std::tuple<Tp...>& network,
+ size_t offset = 0);
+
+/**
+ * Auxiliary function to set the weights of the specified layer using the given
+ * initialize rule.
+ *
+ * This overload handles layers whose Weights() member is an arma::mat. The
+ * layer's weights are replaced by a matrix of the same dimensions built on the
+ * memory at weights.memptr() + offset, and the initialize rule is then applied
+ * to them.
+ *
+ * @param initializeRule The rule used to initialize the layer weights.
+ * @param layer The layer used to set the weights.
+ * @param weights The matrix whose memory provides the storage for the weights.
+ * @param offset The element offset into weights at which this layer starts.
+ * @param output The output parameter of the layer; its value is not read, only
+ *     its type takes part in selecting the overload.
+ * @return The number of weights (elements) of the layer.
+ */
+template<typename InitializationRuleType, typename T>
+typename std::enable_if<
+ HasWeightsCheck<T, arma::mat&(T::*)()>::value, size_t>::type
+LayerWeights(InitializationRuleType& initializeRule,
+ T& layer,
+ arma::mat& weights,
+ size_t offset,
+ arma::mat& output);
+
+/**
+ * Overload of LayerWeights() for layers whose Weights() member is an
+ * arma::cube. The layer's weights are replaced by a cube of the same
+ * dimensions built on the memory at weights.memptr() + offset.
+ *
+ * @return The number of weights (elements) of the layer.
+ */
+template<typename InitializationRuleType, typename T>
+typename std::enable_if<
+ HasWeightsCheck<T, arma::cube&(T::*)()>::value, size_t>::type
+LayerWeights(InitializationRuleType& initializeRule,
+ T& layer,
+ arma::mat& weights,
+ size_t offset,
+ arma::cube& output);
+
+/**
+ * Fallback overload of LayerWeights(), selected when HasWeightsCheck reports
+ * false for the layer (presumably: the layer has no matching Weights()
+ * member). Nothing is modified.
+ *
+ * @return Always 0, so the caller's offset is left unchanged.
+ */
+template<typename InitializationRuleType, typename T, typename P>
+typename std::enable_if<
+ !HasWeightsCheck<T, P&(T::*)()>::value, size_t>::type
+LayerWeights(InitializationRuleType& initializeRule,
+ T& layer,
+ arma::mat& weights,
+ size_t offset,
+ P& output);
+
} // namespace ann
} // namespace mlpack
diff --git a/src/mlpack/methods/ann/network_util_impl.hpp b/src/mlpack/methods/ann/network_util_impl.hpp
index 34cc84d..affd880 100644
--- a/src/mlpack/methods/ann/network_util_impl.hpp
+++ b/src/mlpack/methods/ann/network_util_impl.hpp
@@ -204,6 +204,77 @@ LayerInputSize(T& /* unused */, P& /* unused */)
return 0;
}
+/**
+ * Recursive step: set up the weights of layer I, then continue with layer
+ * I + 1 at the offset just past the weights claimed by layer I.
+ */
+template<size_t I, typename InitializationRuleType, typename... Tp>
+typename std::enable_if<I < sizeof...(Tp), void>::type
+NetworkWeights(InitializationRuleType& initializeRule,
+               arma::mat& weights,
+               std::tuple<Tp...>& network,
+               size_t offset)
+{
+  auto& currentLayer = std::get<I>(network);
+
+  // Number of weight elements claimed by this layer (0 if it has no weights).
+  const size_t layerSize = LayerWeights(initializeRule, currentLayer, weights,
+      offset, currentLayer.OutputParameter());
+
+  NetworkWeights<I + 1, InitializationRuleType, Tp...>(initializeRule, weights,
+      network, offset + layerSize);
+}
+
+/**
+ * Terminating case of the recursion: I equals the number of layers in the
+ * tuple, so there is no layer left to process.
+ */
+template<size_t I, typename InitializationRuleType, typename... Tp>
+typename std::enable_if<I == sizeof...(Tp), void>::type
+NetworkWeights(InitializationRuleType& /* initializeRule */,
+               arma::mat& /* weights */,
+               std::tuple<Tp...>& /* network */,
+               size_t /* offset */)
+{
+  // Every layer has already been visited; intentionally empty.
+}
+
+/**
+ * Set the weights of a layer whose Weights() member is an arma::mat.
+ *
+ * The layer's weights are replaced by a matrix of the same dimensions built on
+ * the memory at weights.memptr() + offset, and the initialize rule is then
+ * applied to them. The output argument is only used for overload selection.
+ *
+ * @return The number of elements of the layer's weights.
+ */
+template<typename InitializationRuleType, typename T>
+typename std::enable_if<
+    HasWeightsCheck<T, arma::mat&(T::*)()>::value, size_t>::type
+LayerWeights(InitializationRuleType& initializeRule,
+             T& layer,
+             arma::mat& weights,
+             size_t offset,
+             arma::mat& /* output */)
+{
+  // The replacement matrix has the same shape as the current weights, so the
+  // dimensions can be read once up front.
+  const arma::uword rows = layer.Weights().n_rows;
+  const arma::uword cols = layer.Weights().n_cols;
+
+  // copy_aux_mem = false, strict = false: build on the given memory directly.
+  layer.Weights() = arma::mat(weights.memptr() + offset, rows, cols, false,
+      false);
+
+  initializeRule.Initialize(layer.Weights(), rows, cols);
+
+  return layer.Weights().n_elem;
+}
+
+/**
+ * Set the weights of a layer whose Weights() member is an arma::cube.
+ *
+ * The layer's weights are replaced by a cube of the same dimensions built on
+ * the memory at weights.memptr() + offset, and the initialize rule is then
+ * applied to them. The output argument is only used for overload selection.
+ *
+ * @param initializeRule The rule used to initialize the layer weights.
+ * @param layer The layer used to set the weights.
+ * @param weights The matrix whose memory provides the storage for the weights.
+ * @param offset The element offset into weights at which this layer starts.
+ * @return The number of elements of the layer's weights.
+ */
+template<typename InitializationRuleType, typename T>
+typename std::enable_if<
+    HasWeightsCheck<T, arma::cube&(T::*)()>::value, size_t>::type
+LayerWeights(InitializationRuleType& initializeRule,
+             T& layer,
+             arma::mat& weights,
+             size_t offset,
+             arma::cube& /* output */)
+{
+  // copy_aux_mem = false, strict = false: build on the given memory directly.
+  layer.Weights() = arma::cube(weights.memptr() + offset,
+      layer.Weights().n_rows, layer.Weights().n_cols,
+      layer.Weights().n_slices, false, false);
+
+  // The weights are a cube, so rows and columns alone do not describe their
+  // shape; the slice count has to be handed to the initialize rule as well.
+  initializeRule.Initialize(layer.Weights(), layer.Weights().n_rows,
+      layer.Weights().n_cols, layer.Weights().n_slices);
+
+  return layer.Weights().n_elem;
+}
+
+/**
+ * Fallback for layers for which HasWeightsCheck reports false: nothing is
+ * touched and no weights are claimed.
+ *
+ * @return Always 0, which leaves the caller's offset unchanged.
+ */
+template<typename InitializationRuleType, typename T, typename P>
+typename std::enable_if<
+    !HasWeightsCheck<T, P&(T::*)()>::value, size_t>::type
+LayerWeights(InitializationRuleType& /* initializeRule */,
+             T& /* layer */,
+             arma::mat& /* weights */,
+             size_t /* offset */,
+             P& /* output */)
+{
+  // This layer contributes no weight elements.
+  return 0;
+}
+
} // namespace ann
} // namespace mlpack
More information about the mlpack-git
mailing list