[mlpack-git] master: Add auxiliary function to set the network weights using a given initialize rule. (b33964b)

gitdub at mlpack.org gitdub at mlpack.org
Sat Apr 9 07:31:20 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ba826b1959a3f83532e91765b2bba0705e588d39...f4b3464fce6bdc7c61d94f6b22bc71fe61276328

>---------------------------------------------------------------

commit b33964b9310ad54b1035694c737b00627d31758f
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Fri Apr 8 16:53:55 2016 +0200

    Add auxiliary function to set the network weights using a given initialize rule.


>---------------------------------------------------------------

b33964b9310ad54b1035694c737b00627d31758f
 src/mlpack/methods/ann/network_util.hpp      | 61 ++++++++++++++++++++++++
 src/mlpack/methods/ann/network_util_impl.hpp | 71 ++++++++++++++++++++++++++++
 2 files changed, 132 insertions(+)

diff --git a/src/mlpack/methods/ann/network_util.hpp b/src/mlpack/methods/ann/network_util.hpp
index 193d5a8..ed268a0 100644
--- a/src/mlpack/methods/ann/network_util.hpp
+++ b/src/mlpack/methods/ann/network_util.hpp
@@ -172,6 +172,67 @@ typename std::enable_if<
     HasWeightsCheck<T, P&(T::*)()>::value, size_t>::type
 LayerInputSize(T& layer, P& output);
 
+/**
+ * Auxiliary function to set the weights of the specified network using a given
+ * initialization rule.
+ *
+ * Starting at layer I, each layer of the network tuple is passed to
+ * LayerWeights() together with the current offset; the offset is then advanced
+ * by the number of elements that layer reports, so consecutive layers are
+ * assigned consecutive regions of weights.
+ *
+ * NOTE(review): weights is not size-checked here; it presumably must hold at
+ * least as many elements as all layer weights combined — confirm at call sites.
+ *
+ * @tparam I Index of the layer to process next (0 starts with the first layer).
+ * @param initializeRule The rule used to initialize the network weights.
+ * @param weights Matrix whose memory holds the weights of all layers.
+ * @param network The network (tuple of layers) whose weights are set.
+ * @param offset Element offset into weights at which layer I starts.
+ */
+template<size_t I = 0, typename InitializationRuleType, typename... Tp>
+typename std::enable_if<I < sizeof...(Tp), void>::type
+NetworkWeights(InitializationRuleType& initializeRule,
+               arma::mat& weights,
+               std::tuple<Tp...>& network,
+               size_t offset = 0);
+
+/**
+ * Terminating overload of NetworkWeights(), selected once I has reached the
+ * number of layers in the network tuple; it does nothing.
+ */
+template<size_t I, typename InitializationRuleType, typename... Tp>
+typename std::enable_if<I == sizeof...(Tp), void>::type
+NetworkWeights(InitializationRuleType& initializeRule,
+               arma::mat& weights,
+               std::tuple<Tp...>& network,
+               size_t offset = 0);
+
+/**
+ * Auxiliary function to set the weights of the specified layer using the given
+ * initialization rule.
+ *
+ * This overload handles layers whose weights are an arma::mat (enabled when
+ * HasWeightsCheck<T, arma::mat&(T::*)()> holds). The layer weights are
+ * replaced by a matrix of the same shape constructed on the memory of weights
+ * starting at offset (copy_aux_mem = false), and are then initialized with
+ * initializeRule. NOTE(review): keeping the layer aliased into weights relies
+ * on Armadillo's move assignment adopting that auxiliary memory — confirm for
+ * the Armadillo version in use.
+ *
+ * @param initializeRule The rule used to initialize the layer weights.
+ * @param layer The layer whose weights are set.
+ * @param weights Matrix whose memory holds the layer weights.
+ * @param offset Element offset into weights at which the layer weights start.
+ * @param output The output parameter of the layer; it is not read and only
+ *        takes part in overload selection.
+ * @return The number of weight elements of the layer.
+ */
+template<typename InitializationRuleType, typename T>
+typename std::enable_if<
+    HasWeightsCheck<T, arma::mat&(T::*)()>::value, size_t>::type
+LayerWeights(InitializationRuleType& initializeRule,
+             T& layer,
+             arma::mat& weights,
+             size_t offset,
+             arma::mat& output);
+
+/**
+ * Overload of LayerWeights() for layers whose weights are an arma::cube
+ * (enabled when HasWeightsCheck<T, arma::cube&(T::*)()> holds). The cube
+ * weights are rebuilt with the same shape on the memory of weights starting
+ * at offset (copy_aux_mem = false) and are then initialized with
+ * initializeRule.
+ *
+ * @return The number of weight elements of the layer.
+ */
+template<typename InitializationRuleType, typename T>
+typename std::enable_if<
+    HasWeightsCheck<T, arma::cube&(T::*)()>::value, size_t>::type
+LayerWeights(InitializationRuleType& initializeRule,
+             T& layer,
+             arma::mat& weights,
+             size_t offset,
+             arma::cube& output);
+
+/**
+ * Fallback overload of LayerWeights() for layers without weights (enabled when
+ * HasWeightsCheck<T, P&(T::*)()> does not hold). Nothing is set and 0 is
+ * returned, so the offset used by NetworkWeights() does not advance.
+ */
+template<typename InitializationRuleType, typename T, typename P>
+typename std::enable_if<
+    !HasWeightsCheck<T, P&(T::*)()>::value, size_t>::type
+LayerWeights(InitializationRuleType& initializeRule,
+             T& layer,
+             arma::mat& weights,
+             size_t offset,
+             P& output);
+
 } // namespace ann
 } // namespace mlpack
 
diff --git a/src/mlpack/methods/ann/network_util_impl.hpp b/src/mlpack/methods/ann/network_util_impl.hpp
index 34cc84d..affd880 100644
--- a/src/mlpack/methods/ann/network_util_impl.hpp
+++ b/src/mlpack/methods/ann/network_util_impl.hpp
@@ -204,6 +204,77 @@ LayerInputSize(T& /* unused */, P& /* unused */)
   return 0;
 }
 
+template<size_t I, typename InitializationRuleType, typename... Tp>
+typename std::enable_if<I < sizeof...(Tp), void>::type
+NetworkWeights(InitializationRuleType& initializeRule,
+               arma::mat& weights,
+               std::tuple<Tp...>& network,
+               size_t offset)
+{
+  // Hand the I-th layer its region of weights, beginning at offset.
+  auto& layer = std::get<I>(network);
+  const size_t consumed = LayerWeights(initializeRule, layer, weights, offset,
+      layer.OutputParameter());
+
+  // The next layer starts directly after the elements just consumed.
+  NetworkWeights<I + 1, InitializationRuleType, Tp...>(initializeRule,
+      weights, network, offset + consumed);
+}
+
+template<size_t I, typename InitializationRuleType, typename... Tp>
+typename std::enable_if<I == sizeof...(Tp), void>::type
+NetworkWeights(InitializationRuleType& /* unused */,
+               arma::mat& /* unused */,
+               std::tuple<Tp...>& /* unused */,
+               size_t /* unused */)
+{
+  // End of the recursion: every layer of the network has been handled.
+}
+
+template<typename InitializationRuleType, typename T>
+typename std::enable_if<
+    HasWeightsCheck<T, arma::mat&(T::*)()>::value, size_t>::type
+LayerWeights(InitializationRuleType& initializeRule,
+             T& layer,
+             arma::mat& weights,
+             size_t offset,
+             arma::mat& /* output */)
+{
+  arma::mat& layerWeights = layer.Weights();
+  const arma::uword rows = layerWeights.n_rows;
+  const arma::uword cols = layerWeights.n_cols;
+
+  // Rebuild the layer weights with the same shape on top of this layer's
+  // region of weights (copy_aux_mem = false), then let the rule fill it.
+  layerWeights = arma::mat(weights.memptr() + offset, rows, cols, false,
+      false);
+  initializeRule.Initialize(layerWeights, rows, cols);
+
+  return layerWeights.n_elem;
+}
+
+template<typename InitializationRuleType, typename T>
+typename std::enable_if<
+    HasWeightsCheck<T, arma::cube&(T::*)()>::value, size_t>::type
+LayerWeights(InitializationRuleType& initializeRule,
+             T& layer,
+             arma::mat& weights,
+             size_t offset,
+             arma::cube& /* output */)
+{
+  // Rebuild the cube weights with the same shape on top of this layer's
+  // region of weights (copy_aux_mem = false).
+  layer.Weights() = arma::cube(weights.memptr() + offset,
+      layer.Weights().n_rows, layer.Weights().n_cols,
+      layer.Weights().n_slices, false, false);
+
+  // Forward the slice count too: with only rows and cols the rule sees the
+  // shape of a single slice and cannot reproduce the full cube.
+  // NOTE(review): if the rule assigns a freshly allocated cube to its argument
+  // instead of writing in place, the weights may stop aliasing into weights.
+  initializeRule.Initialize(layer.Weights(), layer.Weights().n_rows,
+      layer.Weights().n_cols, layer.Weights().n_slices);
+
+  return layer.Weights().n_elem;
+}
+
+template<typename InitializationRuleType, typename T, typename P>
+typename std::enable_if<
+    !HasWeightsCheck<T, P&(T::*)()>::value, size_t>::type
+LayerWeights(InitializationRuleType& /* unused */,
+             T& /* unused */,
+             arma::mat& /* unused */,
+             size_t /* unused */,
+             P& /* unused */)
+{
+  // A layer without weights occupies no part of the weights matrix.
+  const size_t consumedElements = 0;
+  return consumedElements;
+}
+
 } // namespace ann
 } // namespace mlpack
 




More information about the mlpack-git mailing list