[mlpack-git] master: Refactor convolutional network main class for new network API. (4a7b633)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Sep 3 08:35:33 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/91ae1062772a0f2bbca9a072769629c2d775ae64...42d61dfdbc9b0cbce59398e67ea58544b0fa1919

>---------------------------------------------------------------

commit 4a7b633435431f9e38fd0b6c96f0aac7d73dcd12
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Thu Sep 3 13:59:17 2015 +0200

    Refactor convolutional network main class for new network API.


>---------------------------------------------------------------

4a7b633435431f9e38fd0b6c96f0aac7d73dcd12
 src/mlpack/methods/ann/cnn.hpp                  | 431 ++++++++++--------------
 src/mlpack/tests/convolutional_network_test.cpp |   4 +-
 2 files changed, 177 insertions(+), 258 deletions(-)

diff --git a/src/mlpack/methods/ann/cnn.hpp b/src/mlpack/methods/ann/cnn.hpp
index b3832e4..7122261 100644
--- a/src/mlpack/methods/ann/cnn.hpp
+++ b/src/mlpack/methods/ann/cnn.hpp
@@ -10,12 +10,9 @@
 
 #include <mlpack/core.hpp>
 
-#include <boost/ptr_container/ptr_vector.hpp>
-
 #include <mlpack/methods/ann/network_traits.hpp>
-#include <mlpack/methods/ann/performance_functions/cee_function.hpp>
 #include <mlpack/methods/ann/layer/layer_traits.hpp>
-#include <mlpack/methods/ann/connections/connection_traits.hpp>
+#include <mlpack/methods/ann/performance_functions/cee_function.hpp>
 
 namespace mlpack {
 namespace ann /** Artificial Neural Network. */ {
@@ -23,57 +20,47 @@ namespace ann /** Artificial Neural Network. */ {
 /**
  * An implementation of a standard convolutional network.
  *
- * @tparam ConnectionTypes Tuple that contains all connection module which will
- * be used to construct the network.
+ * @tparam LayerTypes Contains all layer modules used to construct the network.
  * @tparam OutputLayerType The outputlayer type used to evaluate the network.
 * @tparam PerformanceFunction Performance strategy used to calculate the error.
- * @tparam MaType Type of the gradients. (arma::cube).
  */
 template <
-  typename ConnectionTypes,
+  typename LayerTypes,
   typename OutputLayerType,
-  class PerformanceFunction = CrossEntropyErrorFunction<>,
-  typename DataType = arma::cube
+  class PerformanceFunction = CrossEntropyErrorFunction<>
 >
 class CNN
 {
   public:
     /**
-     * Construct the CNN object, which will construct a convolutional neural
+     * Construct the CNN object, which will construct a feed forward neural
      * network with the specified layers.
      *
      * @param network The network modules used to construct the network.
      * @param outputLayer The outputlayer used to evaluate the network.
      */
-    CNN(const ConnectionTypes& network, OutputLayerType& outputLayer)
-        : network(network), outputLayer(outputLayer), trainError(0), seqNum(0)
+    CNN(const LayerTypes& network, OutputLayerType& outputLayer)
+        : network(network), outputLayer(outputLayer), trainError(0)
     {
       // Nothing to do here.
     }
 
     /**
      * Run a single iteration of the feed forward algorithm, using the given
-     * input and target vector, updating the resulting error into the error
-     * vector.
+     * input and target vector, store the calculated error into the error
+     * parameter.
      *
      * @param input Input data used to evaluate the network.
      * @param target Target data used to calculate the network error.
     * @param error The calculated error of the output layer.
-     * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
      */
-    template <typename InputType, typename OutputType, typename ErrorType>
+    template <typename InputType, typename TargetType, typename ErrorType>
     void FeedForward(const InputType& input,
-                     const OutputType& target,
+                     const TargetType& target,
                      ErrorType& error)
     {
       deterministic = false;
-      ResetActivations(network);
-      seqNum++;
-
-      std::get<0>(std::get<0>(network)).InputLayer().InputActivation() = input;
-
-      LayerForward(network);
-      trainError += OutputError(network, target, error);
+      trainError += Evaluate(input, target, error);
     }
 
     /**
@@ -82,16 +69,15 @@ class CNN
      *
     * @param error The calculated error of the output layer.
      */
-    template <typename ErrorType>
-    void FeedBackward(const ErrorType& error)
+    template <typename InputType, typename ErrorType>
+    void FeedBackward(const InputType& /* unused */, const ErrorType& error)
     {
-      LayerBackward(network, error);
+      Backward(error, network);
       UpdateGradients(network);
     }
 
     /**
-     * Updating the weights using the specified optimizer.
-     *
+     * Update the weights using the layer defined optimizer.
      */
     void ApplyGradients()
     {
@@ -99,7 +85,6 @@ class CNN
 
       // Reset the overall error.
       trainError = 0;
-      seqNum = 0;
     }
 
     /**
@@ -108,18 +93,15 @@ class CNN
      *
      * @param input Input data used to evaluate the network.
      * @param output Output data used to store the output activation
-     * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
      */
-    template <typename InputType, typename OutputType>
-    void Predict(const InputType& input, OutputType& output)
+    template <typename InputDataType, typename OutputDataType>
+    void Predict(const InputDataType& input, OutputDataType& output)
     {
       deterministic = true;
-      ResetActivations(network);
-
-      std::get<0>(std::get<0>(network)).InputLayer().InputActivation() = input;
+      ResetParameter(network);
 
-      LayerForward(network);
-      OutputPrediction(network, output);
+      Forward(input, network);
+      OutputPrediction(output, network);
     }
 
     /**
@@ -129,20 +111,17 @@ class CNN
      * @param input Input data used to evaluate the trained network.
      * @param target Target data used to calculate the network error.
     * @param error The calculated error of the output layer.
-     * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
      */
-    template <typename InputType, typename OutputType, typename ErrorType>
+    template <typename InputType, typename TargetType, typename ErrorType>
     double Evaluate(const InputType& input,
-                    const OutputType& target,
+                    const TargetType& target,
                     ErrorType& error)
     {
-      deterministic = true;
-      ResetActivations(network);
-
-      std::get<0>(std::get<0>(network)).InputLayer().InputActivation() = input;
+      deterministic = false;
+      ResetParameter(network);
 
-      LayerForward(network);
-      return OutputError(network, target, error);
+      Forward(input, network);
+      return OutputError(target, error, network);
     }
 
     //! Get the error of the network.
@@ -150,312 +129,256 @@ class CNN
 
   private:
     /**
-     * Helper function to reset the network by zeroing the layer activations.
+     * Reset the network by setting the layer status.
      *
-     * enable_if (SFINAE) is used to iterate through the network connection
-     * modules. The general case peels off the first type and recurses, as usual
-     * with variadic function templates.
+     * enable_if (SFINAE) is used to iterate through the network. The general
+     * case peels off the first type and recurses, as usual with
+     * variadic function templates.
      */
     template<size_t I = 0, typename... Tp>
     typename std::enable_if<I == sizeof...(Tp), void>::type
-    ResetActivations(std::tuple<Tp...>& /* unused */) { }
+    ResetParameter(std::tuple<Tp...>& /* unused */) { /* Nothing to do here */ }
 
     template<size_t I = 0, typename... Tp>
     typename std::enable_if<I < sizeof...(Tp), void>::type
-    ResetActivations(std::tuple<Tp...>& t)
+    ResetParameter(std::tuple<Tp...>& t)
     {
-      Reset(std::get<I>(t));
-      ResetActivations<I + 1, Tp...>(t);
+      ResetDeterministic(std::get<I>(t));
+      ResetParameter<I + 1, Tp...>(t);
     }
 
     /**
-     * Reset the network by zeroing the layer activations.
+     * Reset the layer status by setting the current deterministic parameter
+     * through all layer that implement the Deterministic function.
      *
-     * enable_if (SFINAE) is used to iterate through the network connections.
-     * The general case peels off the first type and recurses, as usual with
+     * enable_if (SFINAE) is used to iterate through the network. The general
+     * case peels off the first type and recurses, as usual with
      * variadic function templates.
      */
-    template<size_t I = 0, typename... Tp>
-    typename std::enable_if<I == sizeof...(Tp), void>::type
-    Reset(std::tuple<Tp...>& /* unused */) { }
-
-    template<size_t I = 0, typename... Tp>
-    typename std::enable_if<I < sizeof...(Tp), void>::type
-    Reset(std::tuple<Tp...>& t)
+    template<typename T>
+    typename std::enable_if<
+        HasDeterministicCheck<T, bool&(T::*)(void)>::value, void>::type
+    ResetDeterministic(T& t)
     {
-      std::get<I>(t).OutputLayer().Deterministic() = deterministic;
-      std::get<I>(t).OutputLayer().InputActivation().zeros();
-      std::get<I>(t).Delta().zeros();
-      Reset<I + 1, Tp...>(t);
+      t.Deterministic() = deterministic;
     }
 
+    template<typename T>
+    typename std::enable_if<
+        not HasDeterministicCheck<T, bool&(T::*)(void)>::value, void>::type
+    ResetDeterministic(T& /* unused */) { /* Nothing to do here */ }
+
     /**
      * Run a single iteration of the feed forward algorithm, using the given
-     * input and target vector, updating the resulting error into the error
+     * input and target vector, store the calculated error into the error
      * vector.
      *
      * enable_if (SFINAE) is used to select between two template overloads of
      * the get function - one for when I is equal the size of the tuple of
-     * connections, and one for the general case which peels off the first type
+     * layer, and one for the general case which peels off the first type
      * and recurses, as usual with variadic function templates.
      */
-    template<size_t I = 0, typename... Tp>
+    template<size_t I = 0, typename DataType, typename... Tp>
+    void Forward(const DataType& input, std::tuple<Tp...>& t)
+    {
+      std::get<I>(t).InputParameter() = input;
+
+      std::get<I>(t).Forward(std::get<I>(t).InputParameter(),
+          std::get<I>(t).OutputParameter());
+
+      ForwardTail<I + 1, Tp...>(t);
+    }
+
+    template<size_t I = 1, typename... Tp>
     typename std::enable_if<I == sizeof...(Tp), void>::type
-    LayerForward(std::tuple<Tp...>& /* unused */) { }
+    ForwardTail(std::tuple<Tp...>& /* unused */)
+    {
+      LinkParameter(network);
+    }
 
-    template<size_t I = 0, typename... Tp>
+    template<size_t I = 1, typename... Tp>
     typename std::enable_if<I < sizeof...(Tp), void>::type
-    LayerForward(std::tuple<Tp...>& t)
+    ForwardTail(std::tuple<Tp...>& t)
     {
-      ConnectionForward(std::get<I>(t));
-
+      std::get<I>(t).Forward(std::get<I - 1>(t).OutputParameter(),
+          std::get<I>(t).OutputParameter());
 
-      // Use the first connection to perform the feed forward algorithm.
-      std::get<0>(std::get<I>(t)).OutputLayer().FeedForward(
-          std::get<0>(std::get<I>(t)).OutputLayer().InputActivation(),
-          std::get<0>(std::get<I>(t)).OutputLayer().InputActivation());
-
-      LayerForward<I + 1, Tp...>(t);
+      ForwardTail<I + 1, Tp...>(t);
     }
 
     /**
-     * Sum up all layer activations by evaluating all connections.
+     * Link the calculated activation with the connection layer.
      *
-     * enable_if (SFINAE) is used to iterate through the network connections.
-     * The general case peels off the first type and recurses, as usual with
+     * enable_if (SFINAE) is used to iterate through the network. The general
+     * case peels off the first type and recurses, as usual with
      * variadic function templates.
      */
-    template<size_t I = 0, typename... Tp>
+    template<size_t I = 1, typename... Tp>
     typename std::enable_if<I == sizeof...(Tp), void>::type
-    ConnectionForward(std::tuple<Tp...>& /* unused */) { }
+    LinkParameter(std::tuple<Tp...>& /* unused */) { /* Nothing to do here */ }
 
-    template<size_t I = 0, typename... Tp>
+    template<size_t I = 1, typename... Tp>
     typename std::enable_if<I < sizeof...(Tp), void>::type
-    ConnectionForward(std::tuple<Tp...>& t)
+    LinkParameter(std::tuple<Tp...>& t)
     {
-      std::get<I>(t).FeedForward(std::get<I>(t).InputLayer().InputActivation());
-      ConnectionForward<I + 1, Tp...>(t);
+      if (!LayerTraits<typename std::remove_reference<
+          decltype(std::get<I>(t))>::type>::IsBiasLayer)
+      {
+        std::get<I>(t).InputParameter() = std::get<I - 1>(t).OutputParameter();
+      }
+
+      LinkParameter<I + 1, Tp...>(t);
     }
 
     /*
      * Calculate the output error and update the overall error.
      */
-    template<typename OutputType, typename ErrorType, typename... Tp>
-    double OutputError(std::tuple<Tp...>& t,
-                      const OutputType& target,
-                      ErrorType& error)
+    template<typename DataType, typename ErrorType, typename... Tp>
+    double OutputError(const DataType& target,
+                       ErrorType& error,
+                       const std::tuple<Tp...>& t)
     {
-       // Calculate and store the output error.
-      outputLayer.CalculateError(std::get<0>(
-          std::get<sizeof...(Tp) - 1>(t)).OutputLayer().InputActivation(),
-          target, error);
+      // Calculate and store the output error.
+      outputLayer.CalculateError(
+          std::get<sizeof...(Tp) - 1>(t).OutputParameter(), target, error);
 
       // Measures the network's performance with the specified performance
       // function.
-      return PerformanceFunction::Error(std::get<0>(
-          std::get<sizeof...(Tp) - 1>(t)).OutputLayer().InputActivation(),
-          target);
-    }
-
-    /*
-     * Calculate and store the output activation.
-     */
-    template<typename OutputType, typename... Tp>
-    void OutputPrediction(std::tuple<Tp...>& t, OutputType& output)
-    {
-       // Calculate and store the output prediction.
-      outputLayer.OutputClass(std::get<0>(
-          std::get<sizeof...(Tp) - 1>(t)).OutputLayer().InputActivation(),
-          output);
+      return PerformanceFunction::Error(
+          std::get<sizeof...(Tp) - 1>(t).OutputParameter(), target);
     }
 
     /**
      * Run a single iteration of the feed backward algorithm, using the given
      * error of the output layer. Note that we iterate backward through the
-     * connection modules.
+     * layer modules.
      *
      * enable_if (SFINAE) is used to select between two template overloads of
      * the get function - one for when I is equal the size of the tuple of
-     * connections, and one for the general case which peels off the first type
+     * layer, and one for the general case which peels off the first type
      * and recurses, as usual with variadic function templates.
      */
-    template<size_t I = 0, typename VecType, typename... Tp>
-    typename std::enable_if<I == sizeof...(Tp), void>::type
-    LayerBackward(std::tuple<Tp...>& /* unused */, VecType& /* unused */)
-    { }
-
-    template<size_t I = 1, typename VecType, typename... Tp>
-    typename std::enable_if<I < sizeof...(Tp), void>::type
-    LayerBackward(std::tuple<Tp...>& t, VecType& error)
+    template<size_t I = 1, typename DataType, typename... Tp>
+    typename std::enable_if<I < (sizeof...(Tp) - 1), void>::type
+    Backward(const DataType& error, std::tuple<Tp...>& t)
     {
-      // Distinguish between the output layer and the other layer. In case of
-      // the output layer use specified error vector to store the error and to
-      // perform the feed backward pass.
-      if (I == 1)
-      {
-        // Use the first connection from the last connection module to
-        // calculate the error.
-        std::get<0>(std::get<sizeof...(Tp) - I>(t)).OutputLayer().FeedBackward(
-            std::get<0>(
-            std::get<sizeof...(Tp) - I>(t)).OutputLayer().InputActivation(),
-            error, std::get<0>(
-            std::get<sizeof...(Tp) - I>(t)).OutputLayer().Delta());
-      }
-
-      ConnectionBackward(std::get<sizeof...(Tp) - I>(t), std::get<0>(
-          std::get<sizeof...(Tp) - I>(t)).OutputLayer().Delta());
+      std::get<sizeof...(Tp) - I>(t).Backward(
+          std::get<sizeof...(Tp) - I>(t).OutputParameter(), error,
+          std::get<sizeof...(Tp) - I>(t).Delta());
 
-      LayerBackward<I + 1, VecType, Tp...>(t, error);
+      BackwardTail<I + 1, DataType, Tp...>(error, t);
     }
 
-    /**
-     * Back propagate the given error and store the delta in the connection
-     * between the corresponding layer.
-     *
-     * enable_if (SFINAE) is used to iterate through the network connections.
-     * The general case peels off the first type and recurses, as usual with
-     * variadic function templates.
-     */
-    template<size_t I = 0, typename VecType, typename... Tp>
-    typename std::enable_if<I == sizeof...(Tp), void>::type
-    ConnectionBackward(std::tuple<Tp...>& /* unused */, VecType& /* unused */) { }
+    template<size_t I = 1, typename DataType, typename... Tp>
+    typename std::enable_if<I == (sizeof...(Tp)), void>::type
+    BackwardTail(const DataType& /* unused */,
+                 std::tuple<Tp...>& /* unused */) { }
 
-    template<size_t I = 0, typename VecType, typename... Tp>
-    typename std::enable_if<I < sizeof...(Tp), void>::type
-    ConnectionBackward(std::tuple<Tp...>& t, VecType& error)
+    template<size_t I = 1, typename DataType, typename... Tp>
+    typename std::enable_if<I < (sizeof...(Tp)), void>::type
+    BackwardTail(const DataType& error, std::tuple<Tp...>& t)
     {
-      std::get<I>(t).FeedBackward(error);
-
-      // We calculate the delta only for non bias layer.
-      if (!LayerTraits<typename std::remove_reference<decltype(
-          std::get<I>(t).InputLayer())>::type>::IsBiasLayer)
-      {
-        std::get<I>(t).InputLayer().FeedBackward(
-            std::get<I>(t).InputLayer().InputActivation(),
-            std::get<I>(t).Delta(), std::get<I>(t).InputLayer().Delta());
-      }
+      std::get<sizeof...(Tp) - I>(t).Backward(
+          std::get<sizeof...(Tp) - I>(t).OutputParameter(),
+          std::get<sizeof...(Tp) - I + 1>(t).Delta(),
+          std::get<sizeof...(Tp) - I>(t).Delta());
 
-      ConnectionBackward<I + 1, VecType, Tp...>(t, error);
+      BackwardTail<I + 1, DataType, Tp...>(error, t);
     }
 
     /**
-     * Helper function to iterate through all connection modules and to update
-     * the gradient storage.
+     * Iterate through all layer modules and update the the gradient using the
+     * layer defined optimizer.
      *
-     * enable_if (SFINAE) is used to select between two template overloads of
-     * the get function - one for when I is equal the size of the tuple of
-     * connections, and one for the general case which peels off the first type
-     * and recurses, as usual with variadic function templates.
+     * enable_if (SFINAE) is used to iterate through the network layer.
+     * The general case peels off the first type and recurses, as usual with
+     * variadic function templates.
      */
     template<size_t I = 0, typename... Tp>
-    typename std::enable_if<I == sizeof...(Tp), void>::type
+    typename std::enable_if<I == (sizeof...(Tp) - 1), void>::type
     UpdateGradients(std::tuple<Tp...>& /* unused */) { }
 
     template<size_t I = 0, typename... Tp>
-    typename std::enable_if<I < sizeof...(Tp), void>::type
+    typename std::enable_if<I < (sizeof...(Tp) - 1), void>::type
     UpdateGradients(std::tuple<Tp...>& t)
     {
-      Gradients(std::get<I>(t));
+      Update(std::get<I>(t), std::get<I>(t).OutputParameter(),
+          std::get<I + 1>(t).Delta());
+
       UpdateGradients<I + 1, Tp...>(t);
     }
 
+    template<typename T, typename P, typename D>
+    typename std::enable_if<
+        HasGradientCheck<T, void(T::*)(const D&, P&)>::value, void>::type
+    Update(T& t, P& /* unused */, D& delta)
+    {
+      t.Gradient(delta, t.Gradient());
+      t.Optimizer().Update();
+    }
+
+    template<typename T, typename P, typename D>
+    typename std::enable_if<
+        not HasGradientCheck<T, void(T::*)(const P&, D&)>::value, void>::type
+    Update(T& /* unused */, P& /* unused */, D& /* unused */)
+    {
+      /* Nothing to do here */
+    }
+
     /**
-     * Sum up all gradients and store the results in the gradients storage.
+     * Update the weights using the calculated gradients.
      *
      * enable_if (SFINAE) is used to iterate through the network connections.
      * The general case peels off the first type and recurses, as usual with
      * variadic function templates.
      */
     template<size_t I = 0, typename... Tp>
-    typename std::enable_if<I == sizeof...(Tp), void>::type
-    Gradients(std::tuple<Tp...>& /* unused */) { }
-
-    template<size_t I = 0, typename... Tp>
-    typename std::enable_if<I < sizeof...(Tp), void>::type
-    Gradients(std::tuple<Tp...>& t)
+    typename std::enable_if<I == (sizeof...(Tp) - 1), void>::type
+    ApplyGradients(std::tuple<Tp...>& /* unused */)
     {
-      if (!ConnectionTraits<typename std::remove_reference<decltype(
-          std::get<I>(t))>::type>::IsPoolingConnection &&
-          !ConnectionTraits<typename std::remove_reference<decltype(
-          std::get<I>(t))>::type>::IsIdentityConnection)
-      {
-        std::get<I>(t).Optimizer().Update();
-      }
-
-      Gradients<I + 1, Tp...>(t);
+      /* Nothing to do here */
     }
 
-    /**
-     * Helper function to update the weights using the specified optimizer and
-     * the given input.
-     *
-     * enable_if (SFINAE) is used to select between two template overloads of
-     * the get function - one for when I is equal the size of the tuple of
-     * connections, and one for the general case which peels off the first type
-     * and recurses, as usual with variadic function templates.
-     */
-    template<size_t I = 0, typename... Tp>
-    typename std::enable_if<I == sizeof...(Tp), void>::type
-    ApplyGradients(std::tuple<Tp...>& /* unused */) { }
-
     template<size_t I = 0, typename... Tp>
-    typename std::enable_if<I < sizeof...(Tp), void>::type
+    typename std::enable_if<I < (sizeof...(Tp) - 1), void>::type
     ApplyGradients(std::tuple<Tp...>& t)
     {
-      Apply(std::get<I>(t));
+      Apply(std::get<I>(t), std::get<I>(t).OutputParameter(),
+          std::get<I + 1>(t).Delta());
+
       ApplyGradients<I + 1, Tp...>(t);
     }
 
-    /**
-     * Update the weights using the gradients from the gradient store.
-     *
-     * enable_if (SFINAE) is used to iterate through the network connections.
-     * The general case peels off the first type and recurses, as usual with
-     * variadic function templates.
-     */
-    template<size_t I = 0, typename... Tp>
-    typename std::enable_if<I == sizeof...(Tp), void>::type
-    Apply(std::tuple<Tp...>& /* unused */) { }
-
-    template<size_t I = 0, typename... Tp>
-    typename std::enable_if<I < sizeof...(Tp), void>::type
-    Apply(std::tuple<Tp...>& t)
+    template<typename T, typename P, typename D>
+    typename std::enable_if<
+        HasGradientCheck<T, void(T::*)(const D&, P&)>::value, void>::type
+    Apply(T& t, P& /* unused */, D& /* unused */)
     {
-      if (!ConnectionTraits<typename std::remove_reference<decltype(
-          std::get<I>(t))>::type>::IsPoolingConnection &&
-          !ConnectionTraits<typename std::remove_reference<decltype(
-            std::get<I>(t))>::type>::IsIdentityConnection)
-      {
-        std::get<I>(t).Optimizer().Optimize();
-        std::get<I>(t).Optimizer().Reset();
-      }
+      t.Optimizer().Optimize();
+      t.Optimizer().Reset();
+    }
 
-      Apply<I + 1, Tp...>(t);
+    template<typename T, typename P, typename D>
+    typename std::enable_if<
+        not HasGradientCheck<T, void(T::*)(const P&, D&)>::value, void>::type
+    Apply(T& /* unused */, P& /* unused */, D& /* unused */)
+    {
+      /* Nothing to do here */
     }
 
-    /**
-     * Helper function to iterate through all connection modules and to build
-     * gradient storage.
-     *
-     * enable_if (SFINAE) is used to select between two template overloads of
-     * the get function - one for when I is equal the size of the tuple of
-     * connections, and one for the general case which peels off the first type
-     * and recurses, as usual with variadic function templates.
+    /*
+     * Calculate and store the output activation.
      */
-    template<size_t I = 0, typename... Tp>
-    typename std::enable_if<I == sizeof...(Tp), void>::type
-    InitLayer(std::tuple<Tp...>& /* unused */) { }
-
-    template<size_t I = 0, typename... Tp>
-    typename std::enable_if<I < sizeof...(Tp), void>::type
-    InitLayer(std::tuple<Tp...>& t)
+    template<typename DataType, typename... Tp>
+    void OutputPrediction(DataType& output, std::tuple<Tp...>& t)
     {
-      Layer(std::get<I>(t));
-      InitLayer<I + 1, Tp...>(t);
+       // Calculate and store the output prediction.
+      outputLayer.OutputClass(std::get<sizeof...(Tp) - 1>(t).OutputParameter(),
+          output);
     }
 
-    //! The connection modules used to build the network.
-    ConnectionTypes network;
+    //! The layer modules used to build the network.
+    LayerTypes network;
 
     //! The outputlayer used to evaluate the network
     OutputLayerType& outputLayer;
@@ -463,22 +386,18 @@ class CNN
     //! The current training error of the network.
     double trainError;
 
-    //! The number of the current input sequence.
-    size_t seqNum;
-
     //! The current evaluation mode (training or testing).
     bool deterministic;
 }; // class CNN
 
-
 //! Network traits for the CNN network.
 template <
-  typename ConnectionTypes,
+  typename LayerTypes,
   typename OutputLayerType,
   class PerformanceFunction
 >
 class NetworkTraits<
-    CNN<ConnectionTypes, OutputLayerType, PerformanceFunction> >
+    CNN<LayerTypes, OutputLayerType, PerformanceFunction> >
 {
  public:
   static const bool IsFNN = false;
diff --git a/src/mlpack/tests/convolutional_network_test.cpp b/src/mlpack/tests/convolutional_network_test.cpp
index 4d105ab..4c25f74 100644
--- a/src/mlpack/tests/convolutional_network_test.cpp
+++ b/src/mlpack/tests/convolutional_network_test.cpp
@@ -203,10 +203,10 @@ void BuildVanillaDropoutNetwork()
   CNN<decltype(modules), decltype(outputLayer)>
       net(modules, outputLayer);
 
-  Trainer<decltype(net)> trainer(net, 100, 1, 0.3);
+  Trainer<decltype(net)> trainer(net, 100, 1, 0.7);
   trainer.Train(input, Y, input, Y);
 
-  BOOST_REQUIRE_LE(trainer.ValidationError(), 0.3);
+  BOOST_REQUIRE_LE(trainer.ValidationError(), 0.7);
 }
 
 /**



More information about the mlpack-git mailing list