[mlpack-git] master: Adjust FFN class so that it works with the mlpack optimizers; Add Train() method following the design guidelines. (69f97e5)

gitdub at mlpack.org
Fri Feb 19 08:31:43 EST 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f6dd2f7a9752a7db8ec284a938b3e84a13d0bfb2...6205f3e0b62b56452b2a4afc4da24fce5b21e72f

>---------------------------------------------------------------

commit 69f97e50aeddfb83e708475606b74a05bee8034a
Author: marcus <marcus.edel at fu-berlin.de>
Date:   Fri Feb 19 14:31:43 2016 +0100

    Adjust FFN class so that it works with the mlpack optimizers; Add Train() method following the design guidelines.


>---------------------------------------------------------------

69f97e50aeddfb83e708475606b74a05bee8034a
 src/mlpack/methods/ann/ffn.hpp      | 454 +++++++++++++++++-------------------
 src/mlpack/methods/ann/ffn_impl.hpp | 279 ++++++++++++++++++++++
 2 files changed, 493 insertions(+), 240 deletions(-)
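
For readers following the API change, here is a minimal usage sketch of the
new interface. The layer modules, dimensions, and data below are illustrative
assumptions rather than part of this commit; only FFN, Train(), Predict(),
and the RMSprop default come from the code in the diff. Check the layer
headers for the exact constructors before copying this.

    // Hedged sketch: layer choices and sizes are assumptions.
    #include <mlpack/core.hpp>
    #include <mlpack/methods/ann/ffn.hpp>
    #include <mlpack/methods/ann/layer/linear_layer.hpp>
    #include <mlpack/methods/ann/layer/bias_layer.hpp>
    #include <mlpack/methods/ann/layer/base_layer.hpp>
    #include <mlpack/methods/ann/layer/binary_classification_layer.hpp>

    using namespace mlpack;
    using namespace mlpack::ann;

    int main()
    {
      // Columns are points: 10-dimensional inputs, binary targets.
      arma::mat predictors(10, 200, arma::fill::randu);
      arma::mat responses = arma::randi<arma::mat>(1, 200,
          arma::distr_param(0, 1));

      // Build the module tuple the FFN template expects.
      LinearLayer<> inputLayer(10, 8);
      BiasLayer<> biasLayer(8);
      BaseLayer<LogisticFunction> hiddenBase;
      LinearLayer<> hiddenLayer(8, 1);
      BaseLayer<LogisticFunction> outputBase;
      BinaryClassificationLayer classOutputLayer;

      auto modules = std::tie(inputLayer, biasLayer, hiddenBase, hiddenLayer,
          outputBase);

      FFN<decltype(modules), decltype(classOutputLayer),
          NguyenWidrowInitialization, CrossEntropyErrorFunction<>>
          net(modules, classOutputLayer);

      // Train with the default optimizer (RMSprop), then predict.
      net.Train(predictors, responses);

      arma::mat predictions;
      net.Predict(predictors, predictions);
    }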

diff --git a/src/mlpack/methods/ann/ffn.hpp b/src/mlpack/methods/ann/ffn.hpp
index a612108..d4cd35d 100644
--- a/src/mlpack/methods/ann/ffn.hpp
+++ b/src/mlpack/methods/ann/ffn.hpp
@@ -9,147 +9,208 @@
 
 #include <mlpack/core.hpp>
 
-#include <mlpack/methods/ann/network_traits.hpp>
+#include <mlpack/methods/ann/network_util.hpp>
 #include <mlpack/methods/ann/layer/layer_traits.hpp>
+#include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
 #include <mlpack/methods/ann/performance_functions/cee_function.hpp>
-#include <mlpack/methods/ann/performance_functions/sparse_function.hpp>
+#include <mlpack/core/optimizers/rmsprop/rmsprop.hpp>
 
 namespace mlpack {
 namespace ann /** Artificial Neural Network. */ {
 
 /**
- * An implementation of a standard feed forward network.
+ * Implementation of a standard feed forward network.
  *
  * @tparam LayerTypes Contains all layer modules used to construct the network.
  * @tparam OutputLayerType The output layer type used to evaluate the network.
- * @tparam PerformanceFunction Performance strategy used to claculate the error.
+ * @tparam InitializationRuleType Rule used to initialize the weight matrix.
+ * @tparam PerformanceFunction Performance strategy used to calculate the error.
  */
 template <
   typename LayerTypes,
   typename OutputLayerType,
+  typename InitializationRuleType = NguyenWidrowInitialization,
   class PerformanceFunction = CrossEntropyErrorFunction<>
 >
 class FFN
 {
  public:
+  //! Convenience typedef for the internal model construction.
+  using NetworkType = FFN<LayerTypes,
+                          OutputLayerType,
+                          InitializationRuleType,
+                          PerformanceFunction>;
+
   /**
-   * Construct the FFN object, which will construct a feed forward neural
-   * network with the specified layers.
+   * Create the FFN object with the given predictors and responses set (this is
+   * the set that is used to train the network) and the given optimizer.
+   * Optionally, specify which initialization rule and performance function
+   * should be used.
    *
-   * @param network The network modules used to construct the network.
-   * @param outputLayer The outputlayer used to evaluate the network.
-   * @param performanceFunction Performance strategy used to claculate the error.
+   * @param network Network modules used to construct the network.
+   * @param outputLayer Output layer used to evaluate the network.
+   * @param predictors Input training variables.
+   * @param responses Outputs resulting from input training variables.
+   * @param optimizer Instantiated optimizer used to train the model.
+   * @param initializeRule Optional instantiated InitializationRule object
+   *        for initializing the network parameter.
+   * @param performanceFunction Optional instantiated PerformanceFunction
+   *        object used to calculate the error.
    */
-  template<typename Layer, typename OutLayer>
-  FFN(Layer &&network, OutLayer &&outputLayer,
-      PerformanceFunction performanceFunction = PerformanceFunction())
-    : network(std::forward<Layer>(network)),
-      outputLayer(std::forward<OutLayer>(outputLayer)),
-      performanceFunc(std::move(performanceFunction)),
-      trainError(0)
-  {
-    static_assert(std::is_same<typename std::decay<Layer>::type,
-                  LayerTypes>::value,
-                  "The type of network must be LayerTypes.");
-
-    static_assert(std::is_same<typename std::decay<OutLayer>::type,
-                  OutputLayerType>::value,
-                  "The type of outputLayer must be OutputLayerType.");
-  }  
+  template<typename LayerType,
+           typename OutputType,
+           template<typename> class OptimizerType>
+  FFN(LayerType &&network,
+      OutputType &&outputLayer,
+      const arma::mat& predictors,
+      const arma::mat& responses,
+      OptimizerType<NetworkType>& optimizer,
+      InitializationRuleType initializeRule = InitializationRuleType(),
+      PerformanceFunction performanceFunction = PerformanceFunction());
 
   /**
-   * Run a single iteration of the feed forward algorithm, using the given
-   * input and target vector, store the calculated error into the error
-   * parameter.
+   * Create the FFN object with the given predictors and responses set (this is
+   * the set that is used to train the network). Optionally, specify which
+   * initialization rule and performance function should be used.
    *
-   * @param input Input data used to evaluate the network.
-   * @param target Target data used to calculate the network error.
-   * @param error The calulated error of the output layer.
+   * @param network Network modules used to construct the network.
+   * @param outputLayer Output layer used to evaluate the network.
+   * @param predictors Input training variables.
+   * @param responses Outputs resulting from input training variables.
+   * @param initializeRule Optional instantiated InitializationRule object
+   *        for initializing the network parameter.
+   * @param performanceFunction Optional instantiated PerformanceFunction
+   *        object used to calculate the error.
    */
-  template <typename InputType, typename TargetType, typename ErrorType>
-  void FeedForward(const InputType& input,
-                   const TargetType& target,
-                   ErrorType& error)
-  {
-    deterministic = false;
-    trainError = Evaluate(input, target, error);
-  }
+  template<typename LayerType, typename OutputType>
+  FFN(LayerType &&network,
+      OutputType &&outputLayer,
+      const arma::mat& predictors,
+      const arma::mat& responses,
+      InitializationRuleType initializeRule = InitializationRuleType(),
+      PerformanceFunction performanceFunction = PerformanceFunction());
 
   /**
-   * Run a single iteration of the feed backward algorithm, using the given
-   * error of the output layer.
+   * Create the FFN object with an empty predictors and responses set and the
+   * default optimizer. Make sure to call Train(predictors, responses) when
+   * training.
    *
-   * @param error The calulated error of the output layer.
+   * @param network Network modules used to construct the network.
+   * @param outputLayer Output layer used to evaluate the network.
+   * @param initializeRule Optional instantiated InitializationRule object
+   *        for initializing the network parameter.
+   * @param performanceFunction Optional instantiated PerformanceFunction
+   *        object used to calculate the error.
    */
-  template <typename InputType, typename ErrorType>
-  void FeedBackward(const InputType& /* unused */, const ErrorType& error)
-  {
-    Backward<>(error, network);
-    UpdateGradients<>(network);
-  }
+  template<typename LayerType, typename OutputType>
+  FFN(LayerType &&network,
+      OutputType &&outputLayer,
+      InitializationRuleType initializeRule = InitializationRuleType(),
+      PerformanceFunction performanceFunction = PerformanceFunction());
 
   /**
-   * Update the weights using the layer defined optimizer.
+   * Train the feedforward network on the given input data. By default, the
+   * RMSprop optimization algorithm is used, but others can be specified
+   * (such as mlpack::optimization::SGD).
+   *
+   * This will use the existing model parameters as a starting point for the
+   * optimization. If this is not what you want, then you should access the
+   * parameters vector directly with Parameters() and modify it as desired.
+   *
+   * @tparam OptimizerType Type of optimizer to use to train the model.
+   * @param predictors Input training variables.
+   * @param responses Outputs resulting from input training variables.
    */
-  void ApplyGradients()
-  {
-    ApplyGradients<>(network);
+  template<
+      template<typename> class OptimizerType = mlpack::optimization::RMSprop
+  >
+  void Train(const arma::mat& predictors, const arma::mat& responses);
 
-    // Reset the overall error.
-    trainError = 0;
-  }
+  /**
+   * Train the feedforward network with the given instantiated optimizer.
+   * Using this overload allows configuring the instantiated optimizer before
+   * training is performed.
+   *
+   * This will use the existing model parameters as a starting point for the
+   * optimization. If this is not what you want, then you should access the
+   * parameters vector directly with Parameters() and modify it as desired.
+   *
+   * @param optimizer Instantiated optimizer used to train the model.
+   */
+  template<
+      template<typename> class OptimizerType = mlpack::optimization::RMSprop
+  >
+  void Train(OptimizerType<NetworkType>& optimizer);
 
   /**
-   * Evaluate the network using the given input. The output activation is
-   * stored into the output parameter.
+   * Train the feedforward network on the given input data using the given
+   * optimizer.
    *
-   * @param input Input data used to evaluate the network.
-   * @param output Output data used to store the output activation
+   * This will use the existing model parameters as a starting point for the
+   * optimization. If this is not what you want, then you should access the
+   * parameters vector directly with Parameters() and modify it as desired.
+   *
+   * @tparam OptimizerType Type of optimizer to use to train the model.
+   * @param predictors Input training variables.
+   * @param responses Outputs resulting from input training variables.
+   * @param optimizer Instantiated optimizer used to train the model.
    */
-  template <typename DataType>
-  void Predict(const DataType& input, DataType& output)
-  {
-    deterministic = true;
-    ResetParameter(network);
+  template<
+      template<typename> class OptimizerType = mlpack::optimization::RMSprop
+  >
+  void Train(const arma::mat& predictors,
+             const arma::mat& responses,
+             OptimizerType<NetworkType>& optimizer);
 
-    Forward(input, network);
-    OutputPrediction(output, network);
-  }
+  /**
+   * Predict the responses to a given set of predictors. The responses will
+   * reflect the output of the given output layer as returned by the
+   * OutputClass() function.
+   *
+   * @param predictors Input predictors.
+   * @param responses Matrix to put output predictions of responses into.
+   */
+  void Predict(arma::mat& predictors, arma::mat& responses);
 
   /**
-   * Evaluate the trained network using the given input and compare the output
-   * with the given target vector.
+   * Evaluate the feedforward network with the given parameters. This function
+   * is usually called by the optimizer to train the model.
    *
-   * @param input Input data used to evaluate the trained network.
-   * @param target Target data used to calculate the network error.
-   * @param error The calulated error of the output layer.
+   * @param parameters Matrix of model parameters.
+   * @param i Index of point to use for objective function evaluation.
+   * @param deterministic Whether to run the model in testing (deterministic)
+   *        mode; note that some layers act differently in the two modes.
    */
-  template <typename InputType, typename TargetType, typename ErrorType>
-  double Evaluate(const InputType& input,
-                  const TargetType& target,
-                  ErrorType& error)
-  {
-    deterministic = false;
-    ResetParameter(network);
+  double Evaluate(const arma::mat& parameters,
+                  const size_t i,
+                  const bool deterministic = false);
 
-    Forward(input, network);    
-    return OutputError(target, error, network);
-  }
+  /**
+   * Evaluate the gradient of the feedforward network with the given parameters,
+   * and with respect to only one point in the dataset. This is useful for
+   * optimizers such as SGD, which require a separable objective function.
+   *
+   * @param parameters Matrix of the model parameters to be optimized.
+   * @param i Index of point to use for objective function gradient evaluation.
+   * @param gradient Matrix to output gradient into.
+   */
+  void Gradient(const arma::mat& parameters,
+                const size_t i,
+                arma::mat& gradient);
 
-  //! Get the error of the network.
-  double Error() const { return trainError; }
+  //! Return the number of separable functions (the number of predictor points).
+  size_t NumFunctions() const { return numFunctions; }
 
-  //! Get the constructed network object.
-  LayerTypes const& Model() const { return network; }
-  //! Modify the constructed network object.
-  LayerTypes& Model() { return network; }
+  //! Return the initial point for the optimization.
+  const arma::mat& Parameters() const { return parameter; }
+  //! Modify the initial point for the optimization.
+  arma::mat& Parameters() { return parameter; }
 
-  //! Get the output layer object.
-  OutputLayerType const& OutputLayer() const { return outputLayer; }
-  //! Modify the output layer object.
-  OutputLayerType& OutputLayer() { return outputLayer; }
+  //! Serialize the model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
 
- private:
+private:
   /**
    * Reset the network by zeroing the layer activations and by setting the
    * layer status.
@@ -164,26 +225,22 @@ class FFN
 
   template<size_t I = 0, typename... Tp>
   typename std::enable_if<I < sizeof...(Tp), void>::type
-  ResetParameter(std::tuple<Tp...>& t)
+  ResetParameter(std::tuple<Tp...>& network)
   {
-    ResetDeterministic(std::get<I>(t));
-    ResetParameter<I + 1, Tp...>(t);
+    ResetDeterministic(std::get<I>(network));
+    ResetParameter<I + 1, Tp...>(network);
   }
 
   /**
    * Reset the layer status by setting the current deterministic parameter
   * through all layers that implement the Deterministic function.
-   *
-   * enable_if (SFINAE) is used to iterate through the network. The general
-   * case peels off the first type and recurses, as usual with
-   * variadic function templates.
    */
   template<typename T>
   typename std::enable_if<
       HasDeterministicCheck<T, bool&(T::*)(void)>::value, void>::type
-  ResetDeterministic(T& t)
+  ResetDeterministic(T& layer)
   {
-    t.Deterministic() = deterministic;
+    layer.Deterministic() = deterministic;
   }
 
   template<typename T>
@@ -195,46 +252,37 @@ class FFN
    * Run a single iteration of the feed forward algorithm, using the given
   * input and target vector, and store the calculated error in the error
    * vector.
-   *
-   * enable_if (SFINAE) is used to select between two template overloads of
-   * the get function - one for when I is equal the size of the tuple of
-   * layer, and one for the general case which peels off the first type
-   * and recurses, as usual with variadic function templates.
    */
   template<size_t I = 0, typename DataType, typename... Tp>
-  void Forward(const DataType& input, std::tuple<Tp...>& t)
+  void Forward(const DataType& input, std::tuple<Tp...>& network)
   {
-    std::get<I>(t).InputParameter() = input;
+    std::get<I>(network).InputParameter() = input;
 
-    std::get<I>(t).Forward(std::get<I>(t).InputParameter(),
-                           std::get<I>(t).OutputParameter());
+    std::get<I>(network).Forward(std::get<I>(network).InputParameter(),
+        std::get<I>(network).OutputParameter());
     
-    ForwardTail<I + 1, Tp...>(t);
+    ForwardTail<I + 1, Tp...>(network);
   }
 
   template<size_t I = 1, typename... Tp>
   typename std::enable_if<I == sizeof...(Tp), void>::type
-  ForwardTail(std::tuple<Tp...>& /* unused */)
+  ForwardTail(std::tuple<Tp...>& network)
   {
     LinkParameter(network);
   }
 
   template<size_t I = 1, typename... Tp>
   typename std::enable_if<I < sizeof...(Tp), void>::type
-  ForwardTail(std::tuple<Tp...>& t)
+  ForwardTail(std::tuple<Tp...>& network)
   {
-    std::get<I>(t).Forward(std::get<I - 1>(t).OutputParameter(),
-                           std::get<I>(t).OutputParameter());
+    std::get<I>(network).Forward(std::get<I - 1>(network).OutputParameter(),
+                           std::get<I>(network).OutputParameter());
     
-    ForwardTail<I + 1, Tp...>(t);
+    ForwardTail<I + 1, Tp...>(network);
   }
 
   /**
    * Link the calculated activation with the connection layer.
-   *
-   * enable_if (SFINAE) is used to iterate through the network. The general
-   * case peels off the first type and recurses, as usual with
-   * variadic function templates.
    */
   template<size_t I = 1, typename... Tp>
   typename std::enable_if<I == sizeof...(Tp), void>::type
@@ -242,15 +290,16 @@ class FFN
 
   template<size_t I = 1, typename... Tp>
   typename std::enable_if<I < sizeof...(Tp), void>::type
-  LinkParameter(std::tuple<Tp...>& t)
+  LinkParameter(std::tuple<Tp...>& network)
   {
     if (!LayerTraits<typename std::remove_reference<
-        decltype(std::get<I>(t))>::type>::IsBiasLayer)
+        decltype(std::get<I>(network))>::type>::IsBiasLayer)
     {
-      std::get<I>(t).InputParameter() = std::get<I - 1>(t).OutputParameter();
+      std::get<I>(network).InputParameter() = std::get<I - 1>(
+          network).OutputParameter();
     }
 
-    LinkParameter<I + 1, Tp...>(t);
+    LinkParameter<I + 1, Tp...>(network);
   }
 
   /*
@@ -259,11 +308,11 @@ class FFN
   template<typename DataType, typename ErrorType, typename... Tp>
   double OutputError(const DataType& target,
                      ErrorType& error,
-                     const std::tuple<Tp...>& t)
+                     const std::tuple<Tp...>& network)
   {
     // Calculate and store the output error.
     outputLayer.CalculateError(
-        std::get<sizeof...(Tp) - 1>(t).OutputParameter(), target, error);
+        std::get<sizeof...(Tp) - 1>(network).OutputParameter(), target, error);
 
     // Measures the network's performance with the specified performance
     // function.
@@ -274,21 +323,16 @@ class FFN
    * Run a single iteration of the feed backward algorithm, using the given
    * error of the output layer. Note that we iterate backward through the
    * layer modules.
-   *
-   * enable_if (SFINAE) is used to select between two template overloads of
-   * the get function - one for when I is equal the size of the tuple of
-   * layer, and one for the general case which peels off the first type
-   * and recurses, as usual with variadic function templates.
    */
   template<size_t I = 1, typename DataType, typename... Tp>
   typename std::enable_if<I < (sizeof...(Tp) - 1), void>::type
-  Backward(const DataType& error, std::tuple<Tp ...>& t)
+  Backward(const DataType& error, std::tuple<Tp ...>& network)
   {
-    std::get<sizeof...(Tp) - I>(t).Backward(
-        std::get<sizeof...(Tp) - I>(t).OutputParameter(), error,
-        std::get<sizeof...(Tp) - I>(t).Delta());    
+    std::get<sizeof...(Tp) - I>(network).Backward(
+        std::get<sizeof...(Tp) - I>(network).OutputParameter(), error,
+        std::get<sizeof...(Tp) - I>(network).Delta());
 
-    BackwardTail<I + 1, DataType, Tp...>(error, t);
+    BackwardTail<I + 1, DataType, Tp...>(error, network);
   }
 
   template<size_t I = 1, typename DataType, typename... Tp>
@@ -298,23 +342,19 @@ class FFN
 
   template<size_t I = 1, typename DataType, typename... Tp>
   typename std::enable_if<I < (sizeof...(Tp)), void>::type
-  BackwardTail(const DataType& error, std::tuple<Tp...>& t)
+  BackwardTail(const DataType& error, std::tuple<Tp...>& network)
   {    
-    std::get<sizeof...(Tp) - I>(t).Backward(
-        std::get<sizeof...(Tp) - I>(t).OutputParameter(),
-        std::get<sizeof...(Tp) - I + 1>(t).Delta(),
-        std::get<sizeof...(Tp) - I>(t).Delta());   
+    std::get<sizeof...(Tp) - I>(network).Backward(
+        std::get<sizeof...(Tp) - I>(network).OutputParameter(),
+        std::get<sizeof...(Tp) - I + 1>(network).Delta(),
+        std::get<sizeof...(Tp) - I>(network).Delta());
 
-    BackwardTail<I + 1, DataType, Tp...>(error, t);
+    BackwardTail<I + 1, DataType, Tp...>(error, network);
   }
 
   /**
   * Iterate through all layer modules and update the gradient using the
    * layer defined optimizer.
-   *
-   * enable_if (SFINAE) is used to iterate through the network layer.
-   * The general case peels off the first type and recurses, as usual with
-   * variadic function templates.
    */
   template<
       size_t I = 0,
@@ -330,21 +370,20 @@ class FFN
       typename... Tp
   >
   typename std::enable_if<I < Max, void>::type
-  UpdateGradients(std::tuple<Tp...>& t)
+  UpdateGradients(std::tuple<Tp...>& network)
   {   
-    Update(std::get<I>(t), std::get<I>(t).OutputParameter(),
-           std::get<I + 1>(t).Delta());    
+    Update(std::get<I>(network), std::get<I>(network).OutputParameter(),
+           std::get<I + 1>(network).Delta());
 
-    UpdateGradients<I + 1, Max, Tp...>(t);
+    UpdateGradients<I + 1, Max, Tp...>(network);
   }
 
   template<typename T, typename P, typename D>
   typename std::enable_if<
       HasGradientCheck<T, void(T::*)(const D&, P&)>::value, void>::type
-  Update(T& t, P& /* unused */, D& delta)
+  Update(T& layer, P& /* unused */, D& delta)
   {
-    t.Gradient(delta, t.Gradient());
-    t.Optimizer().Update();
+    layer.Gradient(delta, layer.Gradient());
   }
 
   template<typename T, typename P, typename D>
@@ -355,67 +394,18 @@ class FFN
     /* Nothing to do here */
   }
 
-  /**
-   * Update the weights using the calulated gradients.
-   *
-   * enable_if (SFINAE) is used to iterate through the network connections.
-   * The general case peels off the first type and recurses, as usual with
-   * variadic function templates.
-   */
-  template<
-      size_t I = 0,
-      size_t Max = std::tuple_size<LayerTypes>::value - 1,
-      typename... Tp
-  >
-  typename std::enable_if<I == Max, void>::type
-  ApplyGradients(std::tuple<Tp...>& /* unused */)
-  {
-    /* Nothing to do here */
-  }
-
-  template<
-      size_t I = 0,
-      size_t Max = std::tuple_size<LayerTypes>::value - 1,
-      typename... Tp
-  >
-  typename std::enable_if<I < Max, void>::type
-  ApplyGradients(std::tuple<Tp...>& t)
-  {
-    Apply(std::get<I>(t), std::get<I>(t).OutputParameter(),
-          std::get<I + 1>(t).Delta());
-
-    ApplyGradients<I + 1, Max, Tp...>(t);
-  }
-
-  template<typename T, typename P, typename D>
-  typename std::enable_if<
-      HasGradientCheck<T, void(T::*)(const D&, P&)>::value, void>::type
-  Apply(T& t, P& /* unused */, D& /* unused */)
-  {
-    t.Optimizer().Optimize();
-    t.Optimizer().Reset();
-  }
-
-  template<typename T, typename P, typename D>
-  typename std::enable_if<
-      !HasGradientCheck<T, void(T::*)(const P&, D&)>::value, void>::type
-  Apply(T& /* unused */, P& /* unused */, D& /* unused */)
-  {
-    /* Nothing to do here */
-  }
-
   /*
    * Calculate and store the output activation.
    */
   template<typename DataType, typename... Tp>
-  void OutputPrediction(DataType& output, std::tuple<Tp...>& t)
+  void OutputPrediction(DataType& output, std::tuple<Tp...>& network)
   {
     // Calculate and store the output prediction.
-    outputLayer.OutputClass(std::get<sizeof...(Tp) - 1>(t).OutputParameter(),
-        output);
+    outputLayer.OutputClass(std::get<sizeof...(Tp) - 1>(
+        network).OutputParameter(), output);
   }
 
-  //! The layer modules used to build the network.
+  //! Instantiated feedforward network.
   LayerTypes network;
 
   //! The output layer used to evaluate the network.
@@ -424,45 +414,29 @@ class FFN
   //! Performance strategy used to calculate the error.
   PerformanceFunction performanceFunc;
 
-  //! The current training error of the network.
-  double trainError;
-
   //! The current evaluation mode (training or testing).
   bool deterministic;
-}; // class FFN
 
-//! Network traits for the FFN network.
-template <
-  typename LayerTypes,
-  typename OutputLayerType,
-  class PerformanceFunction
->
-class NetworkTraits<
-    FFN<LayerTypes, OutputLayerType, PerformanceFunction> >
-{
- public:
-  static const bool IsFNN = true;
-  static const bool IsRNN = false;
-  static const bool IsCNN = false;
-  static const bool IsSAE = false;
-};
+  //! Matrix of (trained) parameters.
+  arma::mat parameter;
 
-//! Network traits for the FFN network.
-template <
-  typename LayerTypes,
-  typename OutputLayerType
->
-class NetworkTraits<
-    FFN<LayerTypes, OutputLayerType, SparseErrorFunction<arma::mat> > >
-{
- public:
-  static const bool IsFNN = false;
-  static const bool IsRNN = false;
-  static const bool IsCNN = false;
-  static const bool IsSAE = true;
-};
+  //! The matrix of data points (predictors).
+  arma::mat predictors;
+
+  //! The matrix of responses to the input data points.
+  arma::mat responses;
+
+  //! The number of separable functions (the number of predictor points).
+  size_t numFunctions;
+
+  //! Locally stored backward error.
+  arma::mat error;
+}; // class FFN
 
 } // namespace ann
 } // namespace mlpack
 
+// Include implementation.
+#include "ffn_impl.hpp"
+
 #endif
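
The reason the class now plugs into the mlpack optimizers is the
decomposable-function interface declared above: NumFunctions(),
Evaluate(parameters, i, deterministic), and Gradient(parameters, i, gradient).
Below is a rough sketch of the driver loop an SGD-style optimizer runs against
that contract (illustrative only; the real implementations live under
src/mlpack/core/optimizers/):

    // Illustrative driver loop, not mlpack's actual optimizer code.
    template<typename DecomposableFunctionType>
    double SgdLikeOptimize(DecomposableFunctionType& f,
                           arma::mat& parameters,
                           const double stepSize = 0.01,
                           const size_t maxIterations = 100000)
    {
      double objective = 0;
      arma::mat gradient(parameters.n_rows, parameters.n_cols);

      for (size_t it = 0, i = 0; it < maxIterations;
           ++it, i = (i + 1) % f.NumFunctions())
      {
        // Forward pass on point i. FFN::Evaluate() reads the weights through
        // the aliases set up by NetworkWeights(), so the parameters argument
        // itself is unused inside.
        objective = f.Evaluate(parameters, i);

        // Backward pass on point i. FFN::Gradient() maps the gradient matrix
        // onto the per-layer gradients via NetworkGradients() and fills it.
        f.Gradient(parameters, i, gradient);

        parameters -= stepSize * gradient;  // descent step
      }

      return objective;
    }
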
diff --git a/src/mlpack/methods/ann/ffn_impl.hpp b/src/mlpack/methods/ann/ffn_impl.hpp
new file mode 100644
index 0000000..bd7436a
--- /dev/null
+++ b/src/mlpack/methods/ann/ffn_impl.hpp
@@ -0,0 +1,279 @@
+/**
+ * @file ffn_impl.hpp
+ * @author Marcus Edel
+ *
+ * Implementation of the FFN class, which implements feed forward neural networks.
+ */
+#ifndef __MLPACK_METHODS_ANN_FFN_IMPL_HPP
+#define __MLPACK_METHODS_ANN_FFN_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "ffn.hpp"
+
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+
+
+template<typename LayerTypes,
+         typename OutputLayerType,
+         typename InitializationRuleType,
+         typename PerformanceFunction
+>
+template<typename LayerType,
+         typename OutputType,
+         template<typename> class OptimizerType
+>
+FFN<LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
+>::FFN(LayerType &&network,
+       OutputType &&outputLayer,
+       const arma::mat& predictors,
+       const arma::mat& responses,
+       OptimizerType<NetworkType>& optimizer,
+       InitializationRuleType initializeRule,
+       PerformanceFunction performanceFunction) : 
+    network(std::forward<LayerType>(network)),
+    outputLayer(std::forward<OutputType>(outputLayer)),
+    performanceFunc(std::move(performanceFunction)),
+    predictors(predictors),
+    responses(responses),
+    numFunctions(predictors.n_cols)
+{
+  static_assert(std::is_same<typename std::decay<LayerType>::type,
+                  LayerTypes>::value,
+                  "The type of network must be LayerTypes.");
+
+  static_assert(std::is_same<typename std::decay<OutputType>::type,
+                OutputLayerType>::value,
+                "The type of outputLayer must be OutputLayerType.");
+
+  initializeRule.Initialize(parameter, NetworkSize(network), 1);
+  NetworkWeights(parameter, network);
+
+  // Train the model.
+  Timer::Start("ffn_optimization");
+  const double out = optimizer.Optimize(parameter);
+  Timer::Stop("ffn_optimization");
+
+  Log::Info << "FFN::FFN(): final objective of trained model is " << out
+      << "." << std::endl;
+}
+
+template<typename LayerTypes,
+         typename OutputLayerType,
+         typename InitializationRuleType,
+         typename PerformanceFunction
+>
+template<typename LayerType, typename OutputType>
+FFN<LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
+>::FFN(LayerType &&network,
+       OutputType &&outputLayer,
+       const arma::mat& predictors,
+       const arma::mat& responses,
+       InitializationRuleType initializeRule,
+       PerformanceFunction performanceFunction) : 
+    network(std::forward<LayerType>(network)),
+    outputLayer(std::forward<OutputType>(outputLayer)),
+    performanceFunc(std::move(performanceFunction))
+{
+  static_assert(std::is_same<typename std::decay<LayerType>::type,
+                  LayerTypes>::value,
+                  "The type of network must be LayerTypes.");
+
+  static_assert(std::is_same<typename std::decay<OutputType>::type,
+                OutputLayerType>::value,
+                "The type of outputLayer must be OutputLayerType.");
+
+  initializeRule.Initialize(parameter, NetworkSize(network), 1);
+  NetworkWeights(parameter, network);
+
+  Train(predictors, responses);
+}
+
+template<typename LayerTypes,
+         typename OutputLayerType,
+         typename InitializationRuleType,
+         typename PerformanceFunction
+>
+template<typename LayerType, typename OutputType>
+FFN<LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
+>::FFN(LayerType &&network,
+       OutputType &&outputLayer,
+       InitializationRuleType initializeRule,
+       PerformanceFunction performanceFunction) : 
+    network(std::forward<LayerType>(network)),
+    outputLayer(std::forward<OutputType>(outputLayer)),
+    performanceFunc(std::move(performanceFunction))
+{
+  static_assert(std::is_same<typename std::decay<LayerType>::type,
+                  LayerTypes>::value,
+                  "The type of network must be LayerTypes.");
+
+  static_assert(std::is_same<typename std::decay<OutputType>::type,
+                OutputLayerType>::value,
+                "The type of outputLayer must be OutputLayerType.");
+
+  initializeRule.Initialize(parameter, NetworkSize(network), 1);
+  NetworkWeights(parameter, network);
+
+  Log::Debug << parameter << std::endl;
+}
+
+template<typename LayerTypes,
+         typename OutputLayerType,
+         typename InitializationRuleType,
+         typename PerformanceFunction
+>
+template<template<typename> class OptimizerType>
+void FFN<
+LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
+>::Train(const arma::mat& predictors, const arma::mat& responses)
+{
+  numFunctions = predictors.n_cols;
+  this->predictors = predictors;
+  this->responses = responses;
+
+  OptimizerType<NetworkType> optimizer(*this);
+
+  // Train the model.
+  Timer::Start("ffn_optimization");
+  const double out = optimizer.Optimize(parameter);
+  Timer::Stop("ffn_optimization");
+
+  Log::Info << "FFN::Train(): final objective of trained model is " << out
+      << "." << std::endl;
+}
+
+template<typename LayerTypes,
+         typename OutputLayerType,
+         typename InitializationRuleType,
+         typename PerformanceFunction
+>
+template<template<typename> class OptimizerType>
+void FFN<
+LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
+>::Train(const arma::mat& predictors,
+         const arma::mat& responses,
+         OptimizerType<NetworkType>& optimizer)
+{
+  numFunctions = predictors.n_cols;
+  this->predictors = predictors;
+  this->responses = responses;
+
+  // Train the model.
+  Timer::Start("ffn_optimization");
+  const double out = optimizer.Optimize(parameter);
+  Timer::Stop("ffn_optimization");
+
+  Log::Info << "FFN::Train(): final objective of trained model is " << out
+      << "." << std::endl;
+}
+
+template<typename LayerTypes,
+         typename OutputLayerType,
+         typename InitializationRuleType,
+         typename PerformanceFunction
+>
+template<
+    template<typename> class OptimizerType
+>
+void FFN<
+LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
+>::Train(OptimizerType<NetworkType>& optimizer)
+{
+  // Train the model.
+  Timer::Start("ffn_optimization");
+  const double out = optimizer.Optimize(parameter);
+  Timer::Stop("ffn_optimization");
+
+  Log::Info << "FFN::Train(): final objective of trained model is " << out
+      << "." << std::endl;
+}
+
+template<typename LayerTypes,
+         typename OutputLayerType,
+         typename InitializationRuleType,
+         typename PerformanceFunction
+>
+void FFN<
+LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
+>::Predict(arma::mat& predictors, arma::mat& responses)
+{
+  deterministic = true;
+
+  arma::mat responsesTemp;
+  ResetParameter(network);
+  Forward(arma::mat(predictors.colptr(0), predictors.n_rows, 1, false, true),
+      network);
+  OutputPrediction(responsesTemp, network);
+
+  responses = arma::mat(responsesTemp.n_elem, predictors.n_cols);
+  responses.col(0) = responsesTemp.col(0);
+
+  for (size_t i = 1; i < predictors.n_cols; i++)
+  {
+    Forward(arma::mat(predictors.colptr(i), predictors.n_rows, 1, false, true),
+        network);
+
+    responsesTemp = arma::mat(responses.colptr(i), responses.n_rows, 1, false,
+        true);
+    OutputPrediction(responsesTemp, network);
+    responses.col(i) = responsesTemp.col(0);
+  }
+}
+
+template<typename LayerTypes,
+         typename OutputLayerType,
+         typename InitializationRuleType,
+         typename PerformanceFunction
+>
+double FFN<
+LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
+>::Evaluate(const arma::mat& /* unused */,
+            const size_t i,
+            const bool deterministic)
+{
+  this->deterministic = deterministic;
+
+  ResetParameter(network);
+
+  Forward(arma::mat(predictors.colptr(i), predictors.n_rows, 1, false, true),
+      network);
+
+  return OutputError(arma::mat(responses.colptr(i), responses.n_rows, 1, false,
+      true), error, network);
+}
+
+template<typename LayerTypes,
+         typename OutputLayerType,
+         typename InitializationRuleType,
+         typename PerformanceFunction
+>
+void FFN<
+LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
+>::Gradient(const arma::mat& /* unused */,
+            const size_t /* unused */,
+            arma::mat& gradient)
+{
+  NetworkGradients(gradient, network);
+
+  Backward<>(error, network);
+  UpdateGradients<>(network);
+}
+
+template<typename LayerTypes,
+         typename OutputLayerType,
+         typename InitializationRuleType,
+         typename PerformanceFunction
+>
+template<typename Archive>
+void FFN<
+LayerTypes, OutputLayerType, InitializationRuleType, PerformanceFunction
+>::Serialize(Archive& ar, const unsigned int /* version */)
+{
+  ar & data::CreateNVP(parameter, "parameter");
+}
+
+} // namespace ann
+} // namespace mlpack
+
+#endif
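
Continuing the sketch from the top of this mail, the instantiated-optimizer
overloads let the caller tune the optimizer before training. The RMSprop
constructor arguments shown here (step size, alpha, eps, maximum iterations,
tolerance) are assumptions about the usual mlpack optimizer signature; check
src/mlpack/core/optimizers/rmsprop/rmsprop.hpp for the exact parameters.

    // Hypothetical configuration; the argument order is an assumption.
    mlpack::optimization::RMSprop<decltype(net)> optimizer(
        net,     // the decomposable function to optimize
        0.01,    // step size
        0.99,    // alpha (smoothing constant)
        1e-8,    // eps
        10000,   // maximum number of iterations
        1e-5);   // tolerance

    // Either pass new data along with the optimizer...
    net.Train(predictors, responses, optimizer);

    // ...or reuse the predictors/responses stored by an earlier Train() call.
    net.Train(optimizer);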

More information about the mlpack-git mailing list