[mlpack-git] master: Adding complete implementation of dropconnect layer with all the suggested changes. (63a7f62)

gitdub at mlpack.org gitdub at mlpack.org
Wed Mar 23 09:47:24 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/7199297dd05a1a8dbc6525bdd7fcd13559596e6b...11b4b5e99199a2f360eba220ed0abe183fdae410

>---------------------------------------------------------------

commit 63a7f623d799f40e4743f04d7043928c582eb1e2
Author: palashahuja <abhor902 at gmail.com>
Date:   Wed Mar 16 13:03:12 2016 +0530

    Adding complete implementation of dropconnect layer
    with all the suggested changes.


>---------------------------------------------------------------

63a7f623d799f40e4743f04d7043928c582eb1e2
 src/mlpack/methods/ann/layer/CMakeLists.txt        |   2 +
 src/mlpack/methods/ann/layer/dropconnect_layer.hpp | 404 +++++++++++++++++++++
 src/mlpack/methods/ann/layer/empty_layer.hpp       |  84 +++++
 src/mlpack/tests/feedforward_network_test.cpp      | 207 +++++++++++
 4 files changed, 697 insertions(+)

diff --git a/src/mlpack/methods/ann/layer/CMakeLists.txt b/src/mlpack/methods/ann/layer/CMakeLists.txt
index 04e41df..b639cda 100644
--- a/src/mlpack/methods/ann/layer/CMakeLists.txt
+++ b/src/mlpack/methods/ann/layer/CMakeLists.txt
@@ -4,8 +4,10 @@ set(SOURCES
   layer_traits.hpp
   binary_classification_layer.hpp
   base_layer.hpp
+  empty_layer.hpp
   bias_layer.hpp
   dropout_layer.hpp
+  dropconnect_layer.hpp
   hard_tanh_layer.hpp
   leaky_relu_layer.hpp
   linear_layer.hpp
diff --git a/src/mlpack/methods/ann/layer/dropconnect_layer.hpp b/src/mlpack/methods/ann/layer/dropconnect_layer.hpp
new file mode 100644
index 0000000..01175e9
--- /dev/null
+++ b/src/mlpack/methods/ann/layer/dropconnect_layer.hpp
@@ -0,0 +1,404 @@
+/**
+ * @file dropconnect_layer.hpp
+ * @author Palash Ahuja
+ *
+ * Definition of the DropConnectLayer class, which implements a regularizer
+ * that randomly sets connections to zero, preventing units from co-adapting.
+ */
+#ifndef __MLPACK_METHODS_ANN_LAYER_DROPCONNECT_LAYER_HPP
+#define __MLPACK_METHODS_ANN_LAYER_DROPCONNECT_LAYER_HPP
+
+#include <type_traits>
+#include <utility>
+
+#include "empty_layer.hpp"
+
+namespace mlpack {
+namespace ann/** Artificial Neural Network. */ {
+  /**
+   *  The DropConnect layer is a regularizer that randomly, with probability
+   *  ratio, sets connection weights to zero and scales the output by the
+   *  factor 1 / (1 - ratio). The scaling is applied when deterministic is
+   *  false. In deterministic mode (during testing), the layer just computes
+   *  the output. The output is computed by the given input layer; if no
+   *  input layer is given, the layer acts as a linear layer with its own
+   *  weight matrix.
+   *
+   *  Note:
+   *  During training you should set deterministic to false and during
+   *  testing you should set deterministic to true.
+   *
+   *  For more information, see the following.
+   *
+   *  @code
+   *  @inproceedings{icml2013_wan13,
+   *    Publisher = {JMLR Workshop and Conference Proceedings},
+   *    Title = {Regularization of Neural Networks using DropConnect},
+   *    Url = {http://jmlr.org/proceedings/papers/v28/wan13.pdf},
+   *    Booktitle = {Proceedings of the 30th International Conference on
+   *                 Machine Learning (ICML-13)},
+   *    Author = {Li Wan and Matthew Zeiler and Sixin Zhang and Yann LeCun
+   *              and Rob Fergus},
+   *    Number = {3},
+   *    Month = may,
+   *    Volume = {28},
+   *    Editor = {Sanjoy Dasgupta and David Mcallester},
+   *    Year = {2013},
+   *    Pages = {1058--1066},
+   *    Abstract = {We introduce DropConnect, a generalization of Dropout, for
+   *    regularizing large fully-connected layers within neural networks. When
+   *    training with Dropout, a randomly selected subset of activations are
+   *    set to zero within each layer. DropConnect instead sets a randomly
+   *    selected subset of weights within the network to zero. Each unit thus
+   *    receives input from a random subset of units in the previous layer. We
+   *    derive a bound on the generalization performance of both Dropout and
+   *    DropConnect. We then evaluate DropConnect on a range of datasets,
+   *    comparing to Dropout, and show that state-of-the-art results on
+   *    several image recognition benchmarks can be obtained by aggregating
+   *    multiple DropConnect-trained models.}
+   *  }
+   *  @endcode
+   */
+
+template<
+    typename InputLayer = EmptyLayer,
+    typename InputDataType = arma::mat,
+    typename OutputDataType = arma::mat
+>
+class DropConnectLayer
+{
+ public:
+  /**
+   * Create the DropConnectLayer object as a stand-alone linear connection.
+   * The layer owns an outSize x inSize weight matrix and computes
+   * output = weights * input.
+   *
+   * @param inSize The number of input units.
+   * @param outSize The number of output units.
+   * @param ratio The probability of setting a connection to zero.
+   */
+  DropConnectLayer(const size_t inSize,
+                   const size_t outSize,
+                   const double ratio = 0.5) :
+      inSize(inSize),
+      outSize(outSize),
+      ratio(ratio),
+      scale(1.0 / (1.0 - ratio)),
+      deterministic(false),
+      uselayer(false)
+  {
+    // Only the shape is set here; this constructor does not initialize the
+    // weight values.
+    weights.set_size(outSize, inSize);
+  }
+
+  /**
+   * Create the DropConnectLayer object that applies DropConnect to the
+   * weights of the given layer.
+   *
+   * @param inputLayer The layer whose connections are dropped. It is
+   *     forwarded (copied or moved) into this object.
+   * @param ratio The probability of setting a connection to zero.
+   */
+  template<typename InputLayerType>
+  DropConnectLayer(InputLayerType&& inputLayer,
+                   const double ratio = 0.5) :
+      inSize(0),
+      outSize(0),
+      baseLayer(std::forward<InputLayerType>(inputLayer)),
+      ratio(ratio),
+      scale(1.0 / (1.0 - ratio)),
+      deterministic(false),
+      uselayer(true)
+  {
+    static_assert(std::is_same<typename std::decay<InputLayerType>::type,
+        InputLayer>::value, "The type of network must be LayerType");
+  }
+
+  /**
+   * Ordinary feed forward pass of the DropConnect layer.
+   *
+   * In deterministic mode (testing) the unmasked weights are used and the
+   * output is not rescaled. Otherwise each connection is kept with
+   * probability (1 - ratio), the output is multiplied by 1 / (1 - ratio), and
+   * the unmasked weights are saved so that Gradient() can restore them.
+   *
+   * @param input Input data used for evaluating the specified function.
+   * @param output Resulting output activation.
+   */
+  template<typename eT>
+  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output)
+  {
+    if (deterministic)
+    {
+      if (uselayer)
+        baseLayer.Forward(input, output);
+      else
+        output = weights * input;
+
+      return;
+    }
+
+    if (uselayer)
+    {
+      // Sample the mask: 1 keeps a connection, 0 drops it.
+      mask = arma::randu<arma::Mat<eT> >(baseLayer.Weights().n_rows,
+          baseLayer.Weights().n_cols);
+      mask.transform([&](double val) { return (val > ratio); });
+
+      // Save the unmasked weights; Gradient() restores them.
+      denoise = baseLayer.Weights();
+      baseLayer.Weights() = baseLayer.Weights() % mask;
+
+      baseLayer.Forward(input, output);
+    }
+    else
+    {
+      // Sample the mask: 1 keeps a connection, 0 drops it.
+      mask = arma::randu<arma::Mat<eT> >(weights.n_rows, weights.n_cols);
+      mask.transform([&](double val) { return (val > ratio); });
+
+      // Save the unmasked weights; Gradient() restores them.
+      denoise = weights;
+      weights = weights % mask;
+
+      output = weights * input;
+    }
+
+    // A connection survives with probability (1 - ratio), so rescaling keeps
+    // the expected training output equal to the deterministic output.
+    output = output * scale;
+  }
+
+  /**
+   * Ordinary feed backward pass of the DropConnect layer.
+   *
+   * Between Forward() and Gradient() the weights are still masked, so the
+   * error is propagated through the kept connections only.
+   *
+   * @param input The propagated input activation.
+   * @param gy The backpropagated error.
+   * @param g The calculated gradient.
+   */
+  template<typename DataType>
+  void Backward(const DataType& input, const DataType& gy, DataType& g)
+  {
+    if (uselayer)
+      baseLayer.Backward(input, gy, g);
+    else
+      g = weights.t() * gy;
+
+    // Forward() multiplied the output by 'scale' in training mode, so the
+    // propagated error carries the same factor.
+    if (!deterministic)
+      g = g * scale;
+  }
+
+  /**
+   * Calculate the gradient using the output delta and the input activation,
+   * then restore the unmasked weights saved by Forward().
+   *
+   * @param d The calculated error.
+   * @param g The calculated gradient.
+   */
+  template<typename eT, typename GradientDataType>
+  void Gradient(const arma::Mat<eT>& d, GradientDataType& g)
+  {
+    if (uselayer)
+      baseLayer.Gradient(d, g);
+    else
+      g = d * inputParameter.t();
+
+    // Nothing was masked in deterministic mode, so there is nothing to undo
+    // (and 'denoise' may not hold the current weights).
+    if (deterministic)
+      return;
+
+    // Dropped connections did not contribute to the output and get no
+    // update; kept connections carry the 'scale' factor from Forward().
+    g = (g % mask) * scale;
+
+    // Denoise: restore the weights saved in Forward().
+    if (uselayer)
+      baseLayer.Weights() = denoise;
+    else
+      weights = denoise;
+  }
+
+  //! Get the weights.
+  OutputDataType const& Weights() const
+  {
+    if (uselayer)
+      return baseLayer.Weights();
+
+    return weights;
+  }
+
+  //! Modify the weights.
+  OutputDataType& Weights()
+  {
+    if (uselayer)
+      return baseLayer.Weights();
+
+    return weights;
+  }
+
+  //! Get the input parameter.
+  InputDataType const& InputParameter() const
+  {
+    if (uselayer)
+      return baseLayer.InputParameter();
+
+    return inputParameter;
+  }
+
+  //! Modify the input parameter.
+  InputDataType& InputParameter()
+  {
+    if (uselayer)
+      return baseLayer.InputParameter();
+
+    return inputParameter;
+  }
+
+  //! Get the output parameter.
+  OutputDataType const& OutputParameter() const
+  {
+    if (uselayer)
+      return baseLayer.OutputParameter();
+
+    return outputParameter;
+  }
+
+  //! Modify the output parameter.
+  OutputDataType& OutputParameter()
+  {
+    if (uselayer)
+      return baseLayer.OutputParameter();
+
+    return outputParameter;
+  }
+
+  //! Get the delta.
+  OutputDataType const& Delta() const
+  {
+    if (uselayer)
+      return baseLayer.Delta();
+
+    return delta;
+  }
+
+  //! Modify the delta.
+  OutputDataType& Delta()
+  {
+    if (uselayer)
+      return baseLayer.Delta();
+
+    return delta;
+  }
+
+  //! Get the gradient.
+  OutputDataType const& Gradient() const
+  {
+    if (uselayer)
+      return baseLayer.Gradient();
+
+    return gradient;
+  }
+
+  //! Modify the gradient.
+  OutputDataType& Gradient()
+  {
+    if (uselayer)
+      return baseLayer.Gradient();
+
+    return gradient;
+  }
+
+  //! Create a new InputLayer with the given number of input/output units.
+  InputLayer initRule(size_t inSize, size_t outSize)
+  {
+    return InputLayer(inSize, outSize);
+  }
+
+  //! The value of the deterministic parameter.
+  bool Deterministic() const { return deterministic; }
+
+  //! Modify the value of the deterministic parameter.
+  bool& Deterministic() { return deterministic; }
+
+  //! The probability of setting a value to zero.
+  double Ratio() const { return ratio; }
+
+  //! Modify the probability of setting a value to zero (also updates the
+  //! output scale 1 / (1 - ratio)).
+  void Ratio(const double r)
+  {
+    ratio = r;
+    scale = 1.0 / (1.0 - ratio);
+  }
+
+  //! Locally-stored number of input units (0 when wrapping a layer).
+  size_t inSize;
+
+  //! Locally-stored number of output units (0 when wrapping a layer).
+  size_t outSize;
+
+  //! Locally-stored weight object (used when no layer is wrapped).
+  OutputDataType weights;
+
+  //! Locally-stored delta object.
+  OutputDataType delta;
+
+  //! Locally-stored wrapped layer object.
+  InputLayer baseLayer;
+
+  //! Locally-stored gradient object.
+  OutputDataType gradient;
+
+  //! Locally-stored input parameter object.
+  InputDataType inputParameter;
+
+  //! Locally-stored output parameter object.
+  OutputDataType outputParameter;
+
+  //! Locally-stored mask object (1 = connection kept, 0 = dropped).
+  OutputDataType mask;
+
+  //! The probability of setting a value to zero.
+  double ratio;
+
+  //! The scale fraction, 1 / (1 - ratio).
+  double scale;
+
+  //! If true, masking and scaling are disabled; see notes above.
+  bool deterministic;
+
+  //! If true, the wrapped baseLayer is used; otherwise the layer uses its own
+  //! weight matrix.
+  bool uselayer;
+
+  //! Copy of the unmasked weights saved by Forward() and restored by
+  //! Gradient().
+  OutputDataType denoise;
+}; // class DropConnectLayer
+/**
+ * Layer traits for the DropConnectLayer: it is treated as a connection and is
+ * neither a binary, output, bias, nor LSTM layer.
+ */
+template<typename InputLayer, typename InputDataType, typename OutputDataType>
+class LayerTraits<DropConnectLayer<InputLayer, InputDataType, OutputDataType> >
+{
+ public:
+  static const bool IsConnection = true;
+
+  static const bool IsBinary = false;
+  static const bool IsOutputLayer = false;
+  static const bool IsBiasLayer = false;
+  static const bool IsLSTMLayer = false;
+};
+
+}  // namespace ann
+}  // namespace mlpack
+#endif
diff --git a/src/mlpack/methods/ann/layer/empty_layer.hpp b/src/mlpack/methods/ann/layer/empty_layer.hpp
new file mode 100644
index 0000000..1450a52
--- /dev/null
+++ b/src/mlpack/methods/ann/layer/empty_layer.hpp
@@ -0,0 +1,84 @@
+/**
+ * @file empty_layer.hpp
+ * @author Palash Ahuja
+ *
+ * Definition of the EmptyLayer class, a placeholder layer that does nothing.
+ */
+#ifndef __MLPACK_METHODS_ANN_LAYER_EMPTY_LAYER_HPP
+#define __MLPACK_METHODS_ANN_LAYER_EMPTY_LAYER_HPP
+
+namespace mlpack{
+namespace ann /** Artificial Neural Network. */ {
+/**
+ * A placeholder layer. Forward(), Backward() and Gradient() are no-ops, and
+ * every accessor hands out the same member matrix, 'random'. It serves as the
+ * default InputLayer of the DropConnectLayer.
+ */
+class EmptyLayer
+{
+ public:
+  //! Construct the placeholder layer; there is no state to set up.
+  EmptyLayer() { /* Nothing to do here. */ }
+
+  //! Forward pass: intentionally leaves the output untouched.
+  template<typename eT>
+  void Forward(const arma::Mat<eT>& /* input */,
+               arma::Mat<eT>& /* output */)
+  { }
+
+  //! Backward pass: intentionally leaves the gradient untouched.
+  template<typename InputType, typename eT>
+  void Backward(const InputType& /* input */,
+                const arma::Mat<eT>& /* gy */,
+                arma::Mat<eT>& /* g */)
+  { }
+
+  //! Gradient computation: intentionally leaves the gradient untouched.
+  template<typename eT, typename GradientDataType>
+  void Gradient(const arma::Mat<eT>& /* d */, GradientDataType& /* g */)
+  { }
+
+  // Every accessor below, const or not, refers to the same matrix.
+
+  //! Get / modify the weights.
+  arma::mat const& Weights() const { return random; }
+  arma::mat& Weights() { return random; }
+
+  //! Get / modify the input parameter.
+  arma::mat const& InputParameter() const { return random; }
+  arma::mat& InputParameter() { return random; }
+
+  //! Get / modify the output parameter.
+  arma::mat const& OutputParameter() const { return random; }
+  arma::mat& OutputParameter() { return random; }
+
+  //! Get / modify the delta.
+  arma::mat const& Delta() const { return random; }
+  arma::mat& Delta() { return random; }
+
+  //! Get / modify the gradient.
+  arma::mat const& Gradient() const { return random; }
+  arma::mat& Gradient() { return random; }
+
+  //! The single placeholder matrix returned by all accessors.
+  arma::mat random;
+}; // class EmptyLayer
+
+} //namespace ann
+} //namespace mlpack
+
+#endif
diff --git a/src/mlpack/tests/feedforward_network_test.cpp b/src/mlpack/tests/feedforward_network_test.cpp
index e8b6636..3efab53 100644
--- a/src/mlpack/tests/feedforward_network_test.cpp
+++ b/src/mlpack/tests/feedforward_network_test.cpp
@@ -16,6 +16,7 @@
 #include <mlpack/methods/ann/layer/base_layer.hpp>
 #include <mlpack/methods/ann/layer/dropout_layer.hpp>
 #include <mlpack/methods/ann/layer/binary_classification_layer.hpp>
+#include <mlpack/methods/ann/layer/dropconnect_layer.hpp>
 
 #include <mlpack/methods/ann/ffn.hpp>
 #include <mlpack/methods/ann/performance_functions/mse_function.hpp>
@@ -286,4 +287,210 @@ BOOST_AUTO_TEST_CASE(DropoutNetworkTest)
     (dataset, labels, dataset, labels, 8, 30, 0.4);
 }
 
+/**
+ * Train a network whose DropConnectLayer wraps a LinearLayer, then require
+ * that its classification error on the test data does not exceed the given
+ * threshold.
+ */
+template<typename PerformanceFunction,
+         typename OutputLayerType,
+         typename PerformanceFunctionType,
+         typename MatType = arma::mat>
+void BuildDropConnectNetwork(MatType& trainData,
+                             MatType& trainLabels,
+                             MatType& testData,
+                             MatType& testLabels,
+                             const size_t hiddenLayerSize,
+                             const size_t maxEpochs,
+                             const double classificationErrorThreshold)
+{
+  /*
+   * Construct a feed forward network with trainData.n_rows input nodes,
+   * hiddenLayerSize hidden nodes and trainLabels.n_rows output nodes. The
+   * network structure looks like:
+   *
+   * Input         Hidden       DropConnect     Output
+   *  Layer         Layer         Layer        Layer
+   * +-----+       +-----+       +-----+       +-----+
+   * |     |       |     |       |     |       |     |
+   * |     +------>|     +------>|     +------>|     |
+   * |     |     +>|     |       |     |       |     |
+   * +-----+     | +--+--+       +-----+       +-----+
+   *             |
+   *  Bias       |
+   *  Layer      |
+   * +-----+     |
+   * |     |     |
+   * |     +-----+
+   * |     |
+   * +-----+
+   */
+  LinearLayer<> inputLayer(trainData.n_rows, hiddenLayerSize);
+  BiasLayer<> biasLayer(hiddenLayerSize);
+  BaseLayer<PerformanceFunction> hiddenLayer0;
+
+  // The DropConnect layer drops connections of this wrapped linear layer.
+  LinearLayer<> wrappedLinear(hiddenLayerSize, trainLabels.n_rows);
+  DropConnectLayer<decltype(wrappedLinear)> dropConnect(wrappedLinear);
+
+  BaseLayer<PerformanceFunction> outputLayer;
+  OutputLayerType classOutputLayer;
+
+  auto layers = std::tie(inputLayer, biasLayer, hiddenLayer0,
+                         dropConnect, outputLayer);
+  FFN<decltype(layers), decltype(classOutputLayer), RandomInitialization,
+      PerformanceFunctionType> net(layers, classOutputLayer);
+
+  RMSprop<decltype(net)> optimizer(net, 0.01, 0.88, 1e-8,
+      maxEpochs * trainData.n_cols, 1e-18);
+  net.Train(trainData, trainLabels, optimizer);
+
+  MatType predictions;
+  net.Predict(testData, predictions);
+
+  // Count the test points whose predicted label column matches exactly.
+  size_t correct = 0;
+  for (size_t col = 0; col < testData.n_cols; ++col)
+  {
+    if (arma::sum(arma::sum(
+        arma::abs(predictions.col(col) - testLabels.col(col)))) == 0)
+      ++correct;
+  }
+
+  const double classificationError = 1 - double(correct) / testData.n_cols;
+  BOOST_REQUIRE_LE(classificationError, classificationErrorThreshold);
+}
+    
+/**
+ * Train a network whose DropConnectLayer owns its weight matrix (no wrapped
+ * layer), then require that its classification error on the test data does
+ * not exceed the given threshold.
+ */
+template<typename PerformanceFunction,
+         typename OutputLayerType,
+         typename PerformanceFunctionType,
+         typename MatType = arma::mat>
+void BuildDropConnectNetworkLinear(MatType& trainData,
+                                   MatType& trainLabels,
+                                   MatType& testData,
+                                   MatType& testLabels,
+                                   const size_t hiddenLayerSize,
+                                   const size_t maxEpochs,
+                                   const double classificationErrorThreshold)
+{
+  /*
+   * Construct a feed forward network with trainData.n_rows input nodes,
+   * hiddenLayerSize hidden nodes and trainLabels.n_rows output nodes. The
+   * network structure looks like:
+   *
+   * Input         Hidden       DropConnect     Output
+   *  Layer         Layer         Layer        Layer
+   * +-----+       +-----+       +-----+       +-----+
+   * |     |       |     |       |     |       |     |
+   * |     +------>|     +------>|     +------>|     |
+   * |     |     +>|     |       |     |       |     |
+   * +-----+     | +--+--+       +-----+       +-----+
+   *             |
+   *  Bias       |
+   *  Layer      |
+   * +-----+     |
+   * |     |     |
+   * |     +-----+
+   * |     |
+   * +-----+
+   */
+  LinearLayer<> inputLayer(trainData.n_rows, hiddenLayerSize);
+  BiasLayer<> biasLayer(hiddenLayerSize);
+  BaseLayer<PerformanceFunction> hiddenLayer0;
+
+  // Stand-alone DropConnect layer mapping hiddenLayerSize -> label rows.
+  const size_t outputRows = trainLabels.n_rows;
+  DropConnectLayer<> dropConnect(hiddenLayerSize, outputRows);
+
+  BaseLayer<PerformanceFunction> outputLayer;
+  OutputLayerType classOutputLayer;
+
+  auto layers = std::tie(inputLayer, biasLayer, hiddenLayer0,
+                         dropConnect, outputLayer);
+  FFN<decltype(layers), decltype(classOutputLayer), RandomInitialization,
+      PerformanceFunctionType> net(layers, classOutputLayer);
+
+  RMSprop<decltype(net)> optimizer(net, 0.01, 0.88, 1e-8,
+      maxEpochs * trainData.n_cols, 1e-18);
+  net.Train(trainData, trainLabels, optimizer);
+
+  MatType predictions;
+  net.Predict(testData, predictions);
+
+  // Count the test points whose predicted label column matches exactly.
+  size_t correct = 0;
+  for (size_t col = 0; col < testData.n_cols; ++col)
+  {
+    if (arma::sum(arma::sum(
+        arma::abs(predictions.col(col) - testLabels.col(col)))) == 0)
+      ++correct;
+  }
+
+  const double classificationError = 1 - double(correct) / testData.n_cols;
+  BOOST_REQUIRE_LE(classificationError, classificationErrorThreshold);
+}
+/**
+ * Train DropConnect networks (both wrapping a LinearLayer and stand-alone) on
+ * the thyroid dataset and on a small MNIST subset (4s vs 9s).
+ */
+BOOST_AUTO_TEST_CASE(DropConnectNetworkTest)
+{
+  // Load the thyroid dataset; the last three rows of each file are labels.
+  arma::mat dataset;
+  data::Load("thyroid_train.csv", dataset, true);
+
+  arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4,
+      dataset.n_cols - 1);
+  arma::mat trainLabels = dataset.submat(dataset.n_rows - 3, 0,
+      dataset.n_rows - 1, dataset.n_cols - 1);
+
+  data::Load("thyroid_test.csv", dataset, true);
+
+  arma::mat testData = dataset.submat(0, 0, dataset.n_rows - 4,
+      dataset.n_cols - 1);
+  arma::mat testLabels = dataset.submat(dataset.n_rows - 3, 0,
+      dataset.n_rows - 1, dataset.n_cols - 1);
+
+  // DropConnect networks with logistic activation function. Because 92
+  // percent of the patients are not hyperthyroid, the neural network must be
+  // significantly better than 92%.
+  BuildDropConnectNetwork<LogisticFunction,
+                          BinaryClassificationLayer,
+                          MeanSquaredErrorFunction>
+      (trainData, trainLabels, testData, testLabels, 4, 100, 0.1);
+
+  BuildDropConnectNetworkLinear<LogisticFunction,
+                                BinaryClassificationLayer,
+                                MeanSquaredErrorFunction>
+      (trainData, trainLabels, testData, testLabels, 4, 100, 0.1);
+
+  // Fail with a clear message if the MNIST subset cannot be loaded, instead
+  // of training on an empty matrix.
+  BOOST_REQUIRE(dataset.load("mnist_first250_training_4s_and_9s.arm"));
+
+  // Normalize each point since these are images.
+  for (size_t i = 0; i < dataset.n_cols; ++i)
+    dataset.col(i) /= norm(dataset.col(i), 2);
+
+  // The first half of the points is labeled 0, the second half 1.
+  arma::mat labels = arma::zeros(1, dataset.n_cols);
+  labels.submat(0, labels.n_cols / 2, 0, labels.n_cols - 1).fill(1);
+
+  // DropConnect networks with logistic activation function.
+  BuildDropConnectNetwork<LogisticFunction,
+                          BinaryClassificationLayer,
+                          MeanSquaredErrorFunction>
+      (dataset, labels, dataset, labels, 8, 30, 0.4);
+
+  BuildDropConnectNetworkLinear<LogisticFunction,
+                                BinaryClassificationLayer,
+                                MeanSquaredErrorFunction>
+      (dataset, labels, dataset, labels, 8, 30, 0.4);
+
+  // DropConnect networks with tanh activation function.
+  BuildDropConnectNetwork<TanhFunction,
+                          BinaryClassificationLayer,
+                          MeanSquaredErrorFunction>
+      (dataset, labels, dataset, labels, 8, 30, 0.4);
+
+  BuildDropConnectNetworkLinear<TanhFunction,
+                                BinaryClassificationLayer,
+                                MeanSquaredErrorFunction>
+      (dataset, labels, dataset, labels, 8, 30, 0.4);
+}
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-git mailing list