[mlpack-git] master: Use template parameters for the empty layer class (placeholder) to be useful for all layers, and minor style fixes. (11b4b5e)

gitdub at mlpack.org
Wed Mar 23 12:09:53 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/7199297dd05a1a8dbc6525bdd7fcd13559596e6b...11b4b5e99199a2f360eba220ed0abe183fdae410

>---------------------------------------------------------------

commit 11b4b5e99199a2f360eba220ed0abe183fdae410
Author: marcus <marcus.edel at fu-berlin.de>
Date:   Wed Mar 23 17:09:53 2016 +0100

    Use template parameters for the empty layer class (placeholder) to be useful for all layers, and minor style fixes.


>---------------------------------------------------------------

11b4b5e99199a2f360eba220ed0abe183fdae410
 src/mlpack/methods/ann/layer/dropconnect_layer.hpp | 611 ++++++++++-----------
 src/mlpack/methods/ann/layer/empty_layer.hpp       |  98 +++-
 2 files changed, 354 insertions(+), 355 deletions(-)
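
For readers skimming the patch, here is a minimal standalone sketch of the
DropConnect idea the diff below reworks: each connection is zeroed with
probability ratio and the surviving values are rescaled by 1 / (1 - ratio).
The sketch is not part of the commit; it only assumes Armadillo, and the
sizes and ratio are arbitrary.

    #include <armadillo>

    int main()
    {
      const double ratio = 0.5;                // Probability of zeroing a weight.
      const double scale = 1.0 / (1 - ratio);  // Rescale factor for the output.

      arma::mat weights(4, 3, arma::fill::randu);
      arma::mat input(3, 1, arma::fill::randu);

      // Build a binary mask: keep a weight only if the random draw exceeds ratio.
      arma::mat mask = arma::randu<arma::mat>(weights.n_rows, weights.n_cols);
      mask.transform([&](double val) { return (val > ratio); });

      // Zero the dropped connections and compute the rescaled output.
      arma::mat output = (weights % mask) * input * scale;

      output.print("dropconnect output");
      return 0;
    }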

diff --git a/src/mlpack/methods/ann/layer/dropconnect_layer.hpp b/src/mlpack/methods/ann/layer/dropconnect_layer.hpp
index e1faccd..2c9a651 100644
--- a/src/mlpack/methods/ann/layer/dropconnect_layer.hpp
+++ b/src/mlpack/methods/ann/layer/dropconnect_layer.hpp
@@ -5,392 +5,349 @@
  * Definition of the DropConnectLayer class, which implements a regularizer
  * that randomly sets connections to zero, preventing units from co-adapting.
  */
-#include "empty_layer.hpp"
 #ifndef __MLPACK_METHODS_ANN_LAYER_DROPCONNECT_LAYER_HPP
 #define __MLPACK_METHODS_ANN_LAYER_DROPCONNECT_LAYER_HPP
 
+#include <mlpack/core.hpp>
+
+#include "empty_layer.hpp"
+#include <mlpack/methods/ann/network_util.hpp>
+
 namespace mlpack {
-namespace ann/** Artificial Neural Network. */ {
-  /**
-   *  The DropConnect layer is a regularizer that randomly with probability
-   *  ratio sets the connection values to zero and scales the remaining 
-   *  elements by factor 1 /(1 - ratio). The output is scaled with 1 / (1 - p)
-   *  when deterministic is false. In the deterministic mode(during testing), 
-   *  the layer just computes the output. The output is computed according
-   *  to the input layer. If no input layer is given, it will take a
-   *  linear layer as default.
-   *
-   *  Note:
-   *  During training you should set deterministic to false and during
-   *  testing you should set deterministic to true.
-   *
-   *  For more information, see the following.
-   *  @inproceedings{icml2013_wan13,
-   *  Publisher = {JMLR Workshop and Conference Proceedings},
-   *  Title = {Regularization of Neural Networks using DropConnect},
-   *  Url = {http: // jmlr.org / proceedings / papers / v28 / wan13.pdf},
-   *  Booktitle = {Proceedings of the 30th International Conference on Machine
-   *  Learning(ICML - 13)},
-   *  Author = {Li Wan and Matthew Zeiler and Sixin Zhang and Yann L. Cun and 
-   *  Rob Fergus},
-   *  Number = {3},
-   *  Month = may,
-   *  Volume = {28},
-   *  Editor = {Sanjoy Dasgupta and David Mcallester},
-   *  Year = {2013},
-   *  Pages = {1058 - 1066},
-   *  Abstract = {We introduce DropConnect, a generalization of DropOut, for 
-   *  regularizing large fully - connected layers within neural networks.When
-   *  training with Dropout, a randomly selected subset of activations are set
-   *  to zero within each layer. DropConnect instead sets a  randomly selected
-   *  subset of weights within the network to zero. Each unit thus receives 
-   *  input from a random subset of units in the previous layer. We derive a
-   *  bound on the generalization performance of both Dropout and DropConnect.
-   *  We then evaluate DropConnect on a range of datasets, comparing to Dropout, 
-   *  and show state - of - the - art results on several image recoginition 
-   *  benchmarks can be obtained by aggregating multiple DropConnect - 
-   *  trained models.}
-*}
-*/
+namespace ann /** Artificial Neural Network. */ {
 
+/**
+ * The DropConnect layer is a regularizer that randomly sets connection
+ * values to zero with probability ratio and scales the remaining elements
+ * by the factor 1 / (1 - ratio). The output is scaled with 1 / (1 - ratio)
+ * when deterministic is false. In deterministic mode (during testing), the
+ * layer just computes the output. The output is computed according to the
+ * input layer; if no input layer is given, a linear layer is used by
+ * default.
+ *
+ * Note:
+ * During training you should set deterministic to false and during testing
+ * you should set deterministic to true.
+ *
+ * For more information, see the following.
+ *
+ * @code
+ * @inproceedings{WanICML2013,
+ *   title={Regularization of Neural Networks using DropConnect},
+ *   booktitle = {Proceedings of the 30th International Conference on Machine
+ *                Learning (ICML-13)},
+ *   author = {Li Wan and Matthew Zeiler and Sixin Zhang and Yann L. Cun and
+ *             Rob Fergus},
+ *   year = {2013}
+ * }
+ * @endcode
+ *
+ * @tparam InputLayer Layer used instead of the internal linear layer.
+ * @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
+ *         arma::sp_mat or arma::cube).
+ * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
+ *         arma::sp_mat or arma::cube).
+ */
 template<
-          typename InputLayer = EmptyLayer,
-          typename InputDataType = arma::mat,
-          typename OutputDataType = arma::mat
+    typename InputLayer = EmptyLayer<arma::mat, arma::mat>,
+    typename InputDataType = arma::mat,
+    typename OutputDataType = arma::mat
 >
-  class DropConnectLayer {
-    public:
-   /**
-     * Creates the DropConnect Layer as a Linear Object that takes input size and
-     * output size parameter.
-     *
-     * @param inSize The number of input units.
-     * @param outSize The number of output units.
-     */
-    DropConnectLayer (const size_t inSize, const size_t outSize,
-		      const double ratio = 0.5):
+class DropConnectLayer
+{
+ public:
+  /**
+   * Creates the DropConnectLayer as a linear layer that takes the input
+   * size, output size, and ratio as parameters.
+   *
+   * @param inSize The number of input units.
+   * @param outSize The number of output units.
+   * @param ratio The probability of setting a value to zero.
+   */
+  DropConnectLayer(const size_t inSize,
+                   const size_t outSize,
+                   const double ratio = 0.5) :
       inSize(inSize),
-      outSize(outSize)
-    {
-        scale = 1.0/(1.0 - ratio);
-        uselayer = false;
-        weights.set_size(outSize, inSize);
-    }
+      outSize(outSize),
+      ratio(ratio),
+      scale(1.0 / (1 - ratio)),
+      uselayer(false)
+  {
+    weights.set_size(outSize, inSize);
+  }
 
-    /**
-     * Create the DropConnectLayer object using the specified ratio and rescale
-     * parameter. This takes the
-     *
-     * @param ratio The probability of setting a connection to zero.
-     * @param inputLayer the layer object that the dropconnect connection would take.
-     */
-    template<typename InputLayerType>
-    DropConnectLayer(InputLayerType &&inputLayer,
-                     const double ratio = 0.5) :
-            baseLayer(std::forward<InputLayerType>(inputLayer)),
-            ratio(ratio),
-            scale(1.0/(1 - ratio)),
-            uselayer(true)
-    {
-        static_assert(std::is_same<typename std::decay<InputLayerType>::type,
-                      InputLayer>::value,
-                      "The type of network must be LayerType");
-    }
-    /**
-    * Ordinary feed forward pass of the DropConnect layer.
-    *
-    * @param input Input data used for evaluating the specified function.
-    * @param output Resulting output activation.
-    */
-    template<typename eT>
-    void Forward(const arma::Mat <eT> &input, arma::Mat <eT> &output) {
-      // The DropConnect mask will not be multiplied in the deterministic mode
-      // (during testing).
-      if(uselayer) {
-        if (deterministic)
-        {
-          baseLayer.Forward(input, output);
-        }
-        else {
-          // Scale with input / (1 - ratio) and set values to zero with probability
-          // ratio.
-          mask = arma::randu < arma::Mat <eT> > (baseLayer.Weights().n_rows, baseLayer.Weights().n_cols);
-          mask.transform([&](double val) { return (val > ratio); });
-
-          // Save weights for denoising.
-          denoise = baseLayer.Weights();
-
-          baseLayer.Weights() = baseLayer.Weights() % mask;
-
-          baseLayer.Forward(input, output);
-        }
-      }
-      else{
-        if(deterministic)
-        {
-          output = weights * input;
-        }
-        else {
-          // Scale the input / ( 1 - ratio) and set values to zero with probability ratio
-          mask = arma::randu < arma::Mat <eT> > (weights.n_rows, weights.n_cols);
-          mask.transform([&](double val) { return (val > ratio); });
-
-          // Save weights for denoising.
-          denoise = weights;
-          weights = weights % mask;
-          output = weights * input;
-        }
-
-      }
-      output = output * scale;
-
-    }
-
-    /**
-     * Ordinary feed backward pass of the DropConnect layer.
-     *
-     * @param input The propagated input activation.
-     * @param gy The backpropagated error.
-     * @param g The calculated gradient.
-     */
-    template<typename DataType>
-    void Backward(const DataType & input,
-                  const DataType &gy,
-                  DataType &g)
+  /**
+   * Create the DropConnectLayer object, wrapping the specified input layer
+   * and using the specified ratio.
+   *
+   * @param inputLayer The layer object that the DropConnect layer wraps.
+   * @param ratio The probability of setting a connection to zero.
+   */
+  template<typename InputLayerType>
+  DropConnectLayer(InputLayerType &&inputLayer,
+                   const double ratio = 0.5) :
+      baseLayer(std::forward<InputLayerType>(inputLayer)),
+      ratio(ratio),
+      scale(1.0 / (1 - ratio)),
+      uselayer(true)
+  {
+    static_assert(std::is_same<typename std::decay<InputLayerType>::type,
+                  InputLayer>::value,
+                  "The type of inputLayer must match the InputLayer type.");
+  }
+  /**
+   * Ordinary feed forward pass of the DropConnect layer.
+   *
+   * @param input Input data used for evaluating the specified function.
+   * @param output Resulting output activation.
+   */
+  template<typename eT>
+  void Forward(const arma::Mat<eT> &input, arma::Mat<eT> &output)
+  {
+    // The DropConnect mask will not be multiplied in the deterministic mode
+    // (during testing).
+    if (deterministic)
     {
       if(uselayer)
       {
-        baseLayer.Backward(input, gy, g);
+        baseLayer.Forward(input, output);
       }
       else
       {
-        g = weights.t() * gy;
+        output = weights * input;
       }
     }
-
-    /**
-     * Calculate the gradient using the output delta and the input activation.
-     * @param d The calculated error.
-     * @param g The calculated gradient.
-     */
-    template<typename eT, typename GradientDataType>
-    void Gradient(const arma::Mat<eT>& d, GradientDataType& g)
+    else
     {
       if(uselayer)
       {
-        baseLayer.Gradient(d, g);
+        // Scale with 1 / (1 - ratio) and set values to zero with
+        // probability ratio.
+        mask = arma::randu<arma::Mat<eT> >(baseLayer.Weights().n_rows,
+            baseLayer.Weights().n_cols);
+        mask.transform([&](double val) { return (val > ratio); });
+
+        // Save weights for denoising.
+        denoise = baseLayer.Weights();
+
+        baseLayer.Weights() = baseLayer.Weights() % mask;
 
-        // Denoise the weights.
-        baseLayer.Weights() = denoise;
+        baseLayer.Forward(input, output);
       }
       else
       {
-        g = d * inputParameter.t();
+        // Scale with 1 / (1 - ratio) and set values to zero with
+        // probability ratio.
+        mask = arma::randu<arma::Mat<eT> >(weights.n_rows, weights.n_cols);
+        mask.transform([&](double val) { return (val > ratio); });
 
-	// Denoise the weights.
-        weights = denoise;
-      }
-    }
+        // Save weights for denoising.
+        denoise = weights;
 
-    //! Get the weights.
-      OutputDataType const& Weights() const 
-      { 
-	if(uselayer)
-        {
-	  return baseLayer.Weights(); 
-        }
-	else{
-	  return weights;
-	}
+        weights = weights % mask;
+        output = weights * input;
       }
 
-    //! Modify the weights.
-    OutputDataType& Weights() 
-     {
-       if(uselayer)
-       {
-          return baseLayer.Weights();
-       }
-       else{
-	  return weights;
-       }
-     }
-    
-    //! Get the input parameter.
-    InputDataType &InputParameter() const 
-    {
-      if(uselayer)
-      {
-	  return baseLayer.InputParameter();
-      }
-      else
-      {
-	  return inputParameter;
-      }
+      output = output * scale;
     }
+  }
 
-    //! Modify the input parameter.
-    InputDataType &InputParameter() 
+  /**
+   * Ordinary feed backward pass of the DropConnect layer.
+   *
+   * @param input The propagated input activation.
+   * @param gy The backpropagated error.
+   * @param g The calculated gradient.
+   */
+  template<typename DataType>
+  void Backward(const DataType& input, const DataType& gy, DataType& g)
+  {
+    if(uselayer)
     {
-       if(uselayer)
-      {
- 	  return baseLayer.InputParameter();
-      }
-      else
-      {
-	  return inputParameter;
-      }
+      baseLayer.Backward(input, gy, g);
     }
-
-    //! Get the output parameter.
-    OutputDataType &OutputParameter() const 
+    else
     {
-      if(uselayer)
-      {
-	 return baseLayer.OutputParameter();
-      }
-      else
-      {
-	 return outputParameter;
-      }
+      g = weights.t() * gy;
     }
+  }
 
-    //! Modify the output parameter.
-    OutputDataType &OutputParameter()
+  /**
+   * Calculate the gradient using the output delta and the input activation.
+   *
+   * @param d The calculated error.
+   * @param g The calculated gradient.
+   */
+  template<typename eT, typename GradientDataType>
+  void Gradient(const arma::Mat<eT>& d, GradientDataType& g)
+  {
+    if(uselayer)
     {
-      if(uselayer)
-      {
-	return baseLayer.OutputParameter();
-      }
-      else
-      {
-	return outputParameter;
-      }
+      baseLayer.Gradient(d, g);
+
+      // Denoise the weights.
+      baseLayer.Weights() = denoise;
     }
-    //! Get the delta.
-    OutputDataType const& Delta() const 
+    else
     {
-      if(uselayer)
-      {
-        return baseLayer.Delta();
-      }
-      else
-      {
-        return delta;
-      }
-    }
+      g = d * inputParameter.t();
 
-    //! Modify the delta.
-    OutputDataType& Delta()
-    {
-      if(uselayer)
-      {
-        return baseLayer.Delta();
-      }
-      else
-      {
-        return delta;
-      }
+      // Denoise the weights.
+      weights = denoise;
     }
+  }
+
+  //! Get the weights.
+  OutputDataType const& Weights() const
+  {
+    if(uselayer)
+      return baseLayer.Weights();
+
+    return weights;
+  }
+
+  //! Modify the weights.
+  OutputDataType& Weights()
+  {
+    if(uselayer)
+      return baseLayer.Weights();
+
+    return weights;
+  }
+
+  //! Get the input parameter.
+  InputDataType &InputParameter() const
+  {
+    if(uselayer)
+      return baseLayer.InputParameter();
+
+    return inputParameter;
+  }
+
+  //! Modify the input parameter.
+  InputDataType &InputParameter()
+  {
+    if(uselayer)
+      return baseLayer.InputParameter();
 
-     //! Get the gradient.
-     OutputDataType const& Gradient() const
-     { 
-       if(uselayer)
-       {
-	 return baseLayer.Gradient(); 
-       }
-       else
-       {
-	 return gradient;
-       }
-     }
-
-    //! Modify the gradient.
-    OutputDataType& Gradient()
-    {
-       if(uselayer)
-       {
-	 return baseLayer.Gradient(); 
-       }
-       else
-       {
-	 return gradient;
-       }
-    }
+    return inputParameter;
+  }
 
-    //! The value of the deterministic parameter.
-    bool Deterministic() const { return deterministic; }
+  //! Get the output parameter.
+  OutputDataType &OutputParameter() const
+  {
+    if(uselayer)
+      return baseLayer.OutputParameter();
 
-    //! Modify the value of the deterministic parameter.
-    bool &Deterministic() { return deterministic; }
+    return outputParameter;
+  }
 
-    //! The probability of setting a value to zero.
-    double Ratio() const { return ratio; }
+  //! Modify the output parameter.
+  OutputDataType &OutputParameter()
+  {
+    if(uselayer)
+      return baseLayer.OutputParameter();
 
-    //! Modify the probability of setting a value to zero.
-    void Ratio(const double r) {
-      ratio = r;
-      scale = 1.0 / (1.0 - ratio);
-    }
-    //! Locally stored number of input units.
-    size_t inSize;
+    return outputParameter;
+  }
 
-    //! Locally-stored number of output units.
-    size_t outSize;
+  //! Get the delta.
+  OutputDataType const& Delta() const
+  {
+    if(uselayer)
+      return baseLayer.Delta();
 
-    //! Locally-stored weight object.
-    OutputDataType weights;
+    return delta;
+  }
 
-    //! Locally-stored delta object.
-    OutputDataType delta;
+  //! Modify the delta.
+  OutputDataType& Delta()
+  {
+    if(uselayer)
+      return baseLayer.Delta();
 
-    //! Locally-stored layer object.
-    InputLayer baseLayer;
+    return delta;
+  }
 
-    //! Locally-stored gradient object.
-    OutputDataType gradient;
+  //! Get the gradient.
+  OutputDataType const& Gradient() const
+  {
+    if(uselayer)
+      return baseLayer.Gradient();
 
-    //! Locally-stored input parameter object.
-    InputDataType inputParameter;
+    return gradient;
+  }
 
-    //! Locally-stored output parameter object.
-    OutputDataType outputParameter;
+  //! Modify the gradient.
+  OutputDataType& Gradient()
+  {
+    if(uselayer)
+      return baseLayer.Gradient();
     
-    //! Locally-stored mast object.
-    OutputDataType mask;
+    return gradient;
+  }
 
-    //! The probability of setting a value to zero.
-    double ratio;
+  //! The value of the deterministic parameter.
+  bool Deterministic() const { return deterministic; }
 
-    //! The scale fraction.
-    double scale;
+  //! Modify the value of the deterministic parameter.
+  bool &Deterministic() { return deterministic; }
 
-    //! If true dropout and scaling is disabled, see notes above.
-    bool deterministic;
+  //! The probability of setting a value to zero.
+  double Ratio() const { return ratio; }
 
-    //! If true the default layer is used otherwise a new layer will be created.
-    bool uselayer;
+  //! Modify the probability of setting a value to zero.
+  void Ratio(const double r)
+  {
+    ratio = r;
+    scale = 1.0 / (1.0 - ratio);
+  }
 
-    //! Denoise mask for the weights.
-    OutputDataType denoise;
-  }; // class DropConnectLayer.
-//! Layer Traits for the DropConnectLayer
-template <
-  typename InputLayer,
-  typename InputDataType,
-  typename OutputDataType
->
-class LayerTraits<DropConnectLayer<InputLayer, InputDataType, OutputDataType> >
-{
- public:
-  static const bool IsBinary = false;
-  static const bool IsOutputLayer = false;
-  static const bool IsBiasLayer = false;
-  static const bool IsLSTMLayer = false;
-  static const bool IsConnection = true;
-};
+ private:
+  //! Locally stored number of input units.
+  size_t inSize;
+
+  //! Locally-stored number of output units.
+  size_t outSize;
+
+  //! The probability of setting a value to zero.
+  double ratio;
+
+  //! If true the given input layer is used instead of the internal weights.
+  bool uselayer;
+
+  //! Locally-stored weight object.
+  OutputDataType weights;
+
+  //! Locally-stored delta object.
+  OutputDataType delta;
+
+  //! Locally-stored layer object.
+  InputLayer baseLayer;
+
+  //! Locally-stored gradient object.
+  OutputDataType gradient;
+
+  //! Locally-stored input parameter object.
+  InputDataType inputParameter;
+
+  //! Locally-stored output parameter object.
+  OutputDataType outputParameter;
+
+  //! Locally-stored mask object.
+  OutputDataType mask;
+
+  //! The scale factor, 1 / (1 - ratio).
+  double scale;
+
+  //! If true dropout and scaling is disabled, see notes above.
+  bool deterministic;
+
+  //! Denoise mask for the weights.
+  OutputDataType denoise;
+}; // class DropConnectLayer.
 
 }  // namespace ann
 }  // namespace mlpack
+
 #endif
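
As a usage sketch for the refactored DropConnectLayer above (not part of the
patch; sizes, ratio, and the random data are arbitrary, and it assumes the
mlpack headers at this revision are on the include path): with the templated
EmptyLayer<arma::mat, arma::mat> as the default InputLayer argument, the
layer can be instantiated with empty angle brackets and then falls back to
its internal weight matrix.

    #include <mlpack/core.hpp>
    #include <mlpack/methods/ann/layer/dropconnect_layer.hpp>

    using namespace mlpack::ann;

    int main()
    {
      // InputLayer defaults to EmptyLayer<arma::mat, arma::mat>, so the
      // internal (outSize x inSize) weight matrix is used: 8 inputs,
      // 4 outputs, ratio 0.3.
      DropConnectLayer<> dropConnect(8, 4, 0.3);
      dropConnect.Weights().randu();        // The layer does not initialize weights.
      dropConnect.Deterministic() = false;  // Training mode: apply the mask.

      arma::mat input(8, 1, arma::fill::randu);
      arma::mat output;
      dropConnect.Forward(input, output);

      output.print("output");
      return 0;
    }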
diff --git a/src/mlpack/methods/ann/layer/empty_layer.hpp b/src/mlpack/methods/ann/layer/empty_layer.hpp
index 1450a52..7eb58ec 100644
--- a/src/mlpack/methods/ann/layer/empty_layer.hpp
+++ b/src/mlpack/methods/ann/layer/empty_layer.hpp
@@ -9,9 +9,20 @@
 
 namespace mlpack{
 namespace ann /** Artificial Neural Network. */ {
+
 /**
- * Definition of an empty layer class which does absolutely nothing.
+ * Implementation of the EmptyLayer class. The EmptyLayer class represents a
+ * single layer which is mainly used as a placeholder.
+ *
+ * @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
+ *         arma::sp_mat or arma::cube).
+ * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
+ *         arma::sp_mat or arma::cube).
  */
+template <
+    typename InputDataType = arma::mat,
+    typename OutputDataType = arma::mat
+>
 class EmptyLayer
 {
   public:
@@ -19,63 +30,94 @@ class EmptyLayer
    * Creates the empty layer object. All the methods are
    * empty as well.
    */
-  EmptyLayer()
-  {
-    // nothing to do here.
-  }
-  template<typename eT>
-  void Forward(const arma::Mat<eT>&, arma::Mat<eT>&)
+  EmptyLayer() { /* Nothing to do here. */ }
+
+  /**
+   * Ordinary feed forward pass of a neural network, evaluating the function
+   * f(x) by propagating the activity forward through f.
+   *
+   * @param input Input data used for evaluating the specified function.
+   * @param output Resulting output activation.
+   */
+  template<typename InputType, typename OutputType>
+  void Forward(const InputType& /* input */, OutputType& /* output */)
   {
-    // nothing to do here.
+    /* Nothing to do here. */
   }
 
-  template<typename InputType, typename eT>
-  void Backward(const InputType&,/* unused */
-                const arma::Mat<eT>&,
-		arma::Mat<eT>&)
+  /**
+   * Ordinary feed backward pass of a neural network, calculating the function
+   * f(x) by propagating x backwards through f, using the results from the
+   * feed forward pass.
+   *
+   * @param input The propagated input activation.
+   * @param gy The backpropagated error.
+   * @param g The calculated gradient.
+   */
+  template<typename InputType, typename ErrorType, typename GradientType>
+  void Backward(const InputType& /* input */,
+                const ErrorType& /* gy */,
+                GradientType& /* g */)
   {
-    // nothing to do here.
+    /* Nothing to do here. */
   }
 
-  template<typename eT, typename GradientDataType>
-  void Gradient(const arma::Mat<eT>&, GradientDataType&)
+  /**
+   * Calculate the gradient using the output delta and the input activation.
+   *
+   * @param d The calculated error.
+   * @param g The calculated gradient.
+   */
+  template<typename ErrorType, typename GradientType>
+  void Gradient(const ErrorType& /* d */, GradientType& /* g */)
   {
-    // nothing to do here.
+    /* Nothing to do here. */
   }
 
   //! Get the weights.
-  arma::mat const& Weights() const { return random; }
+  OutputDataType const& Weights() const { return weights; }
   
   //! Modify the weights.
-  arma::mat& Weights() { return random; }
+  OutputDataType& Weights() { return weights; }
   
   //! Get the input parameter.
-  arma::mat const& InputParameter() const { return random; }
+  InputDataType const& InputParameter() const { return inputParameter; }
   
   //! Modify the input parameter.
-  arma::mat& InputParameter() { return random; }
+  InputDataType& InputParameter() { return inputParameter; }
 
   //! Get the output parameter.
-  arma::mat const& OutputParameter() const { return random; }
+  OutputDataType const& OutputParameter() const { return outputParameter; }
 
   //! Modify the output parameter.
-  arma::mat& OutputParameter() { return random; }
+  OutputDataType& OutputParameter() { return outputParameter; }
 
   //! Get the delta.
-  arma::mat const& Delta() const { return random; }
+  OutputDataType const& Delta() const { return delta; }
   
   //! Modify the delta.
-  arma::mat& Delta() { return random; }
+  OutputDataType& Delta() { return delta; }
 
   //! Get the gradient.
-  arma::mat const& Gradient() const { return random; }
+  OutputDataType const& Gradient() const { return gradient; }
 
   //! Modify the gradient.
-  arma::mat& Gradient() { return random; } 
+  OutputDataType& Gradient() { return gradient; }
+  
+  //! Locally-stored weight object.
+  OutputDataType weights;
+
+  //! Locally-stored delta object.
+  OutputDataType delta;
+
+  //! Locally-stored gradient object.
+  OutputDataType gradient;
 
-  //! something random.
-  arma::mat random;
+  //! Locally-stored input parameter object.
+  InputDataType inputParameter;
 
+  //! Locally-stored output parameter object.
+  OutputDataType outputParameter;
 }; // class EmptyLayer
 
 } //namespace ann
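
Finally, a sketch of what the templated placeholder buys: EmptyLayer<
InputDataType, OutputDataType> exposes the same accessor interface as a real
layer (Weights(), Delta(), Gradient(), ...), so it can stand in wherever a
layer type parameter is required but unused, as in the DropConnectLayer
default above. Again not part of the patch; the CountWeights() helper is
purely illustrative.

    #include <mlpack/core.hpp>
    #include <mlpack/methods/ann/layer/empty_layer.hpp>

    using namespace mlpack::ann;

    // Works with any layer type exposing Weights(); shown with the placeholder.
    template<typename LayerType>
    size_t CountWeights(LayerType& layer)
    {
      return layer.Weights().n_elem;
    }

    int main()
    {
      EmptyLayer<arma::mat, arma::mat> placeholder;

      arma::mat input(5, 1, arma::fill::randu);
      arma::mat output;
      placeholder.Forward(input, output);  // Does nothing; output stays empty.

      // The placeholder's weight matrix is default-constructed and empty.
      return CountWeights(placeholder) == 0 ? 0 : 1;
    }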



