[mlpack-git] master: Adding complete implementation of dropconnect layer with all the suggested changes. (63a7f62)
gitdub at mlpack.org
Wed Mar 23 09:47:24 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/7199297dd05a1a8dbc6525bdd7fcd13559596e6b...11b4b5e99199a2f360eba220ed0abe183fdae410
>---------------------------------------------------------------
commit 63a7f623d799f40e4743f04d7043928c582eb1e2
Author: palashahuja <abhor902 at gmail.com>
Date: Wed Mar 16 13:03:12 2016 +0530
Adding complete implementation of dropconnect layer
with all the suggested changes.
>---------------------------------------------------------------
63a7f623d799f40e4743f04d7043928c582eb1e2
src/mlpack/methods/ann/layer/CMakeLists.txt | 2 +
src/mlpack/methods/ann/layer/dropconnect_layer.hpp | 404 +++++++++++++++++++++
src/mlpack/methods/ann/layer/empty_layer.hpp | 84 +++++
src/mlpack/tests/feedforward_network_test.cpp | 207 +++++++++++
4 files changed, 697 insertions(+)
diff --git a/src/mlpack/methods/ann/layer/CMakeLists.txt b/src/mlpack/methods/ann/layer/CMakeLists.txt
index 04e41df..b639cda 100644
--- a/src/mlpack/methods/ann/layer/CMakeLists.txt
+++ b/src/mlpack/methods/ann/layer/CMakeLists.txt
@@ -4,8 +4,10 @@ set(SOURCES
layer_traits.hpp
binary_classification_layer.hpp
base_layer.hpp
+ empty_layer.hpp
bias_layer.hpp
dropout_layer.hpp
+ dropconnect_layer.hpp
hard_tanh_layer.hpp
leaky_relu_layer.hpp
linear_layer.hpp
diff --git a/src/mlpack/methods/ann/layer/dropconnect_layer.hpp b/src/mlpack/methods/ann/layer/dropconnect_layer.hpp
new file mode 100644
index 0000000..01175e9
--- /dev/null
+++ b/src/mlpack/methods/ann/layer/dropconnect_layer.hpp
@@ -0,0 +1,404 @@
+/**
+ * @file dropconnect_layer.hpp
+ * @author Palash Ahuja
+ *
+ * Definition of the DropConnectLayer class, which implements a regularizer
+ * that randomly sets connections to zero, preventing units from co-adapting.
+ */
+#include "empty_layer.hpp"
+#ifndef __MLPACK_METHODS_ANN_LAYER_DROPCONNECT_LAYER_HPP
+#define __MLPACK_METHODS_ANN_LAYER_DROPCONNECT_LAYER_HPP
+
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+/**
+ * The DropConnect layer is a regularizer that randomly, with probability
+ * ratio, sets connection values to zero and scales the remaining elements by
+ * the factor 1 / (1 - ratio). The scaling is only applied when deterministic
+ * is false; in deterministic mode (during testing), the layer simply computes
+ * the output. The output is computed according to the given input layer. If
+ * no input layer is given, a linear layer is used by default.
+ *
+ * Note:
+ * During training you should set deterministic to false, and during testing
+ * you should set deterministic to true.
+ *
+ * For more information, see the following.
+ *
+ * @code
+ * @inproceedings{icml2013_wan13,
+ *   Publisher = {JMLR Workshop and Conference Proceedings},
+ *   Title = {Regularization of Neural Networks using DropConnect},
+ *   Url = {http://jmlr.org/proceedings/papers/v28/wan13.pdf},
+ *   Booktitle = {Proceedings of the 30th International Conference on Machine
+ *                Learning (ICML-13)},
+ *   Author = {Li Wan and Matthew Zeiler and Sixin Zhang and Yann L. Cun and
+ *             Rob Fergus},
+ *   Number = {3},
+ *   Month = may,
+ *   Volume = {28},
+ *   Editor = {Sanjoy Dasgupta and David Mcallester},
+ *   Year = {2013},
+ *   Pages = {1058--1066},
+ *   Abstract = {We introduce DropConnect, a generalization of Dropout, for
+ *   regularizing large fully-connected layers within neural networks. When
+ *   training with Dropout, a randomly selected subset of activations are set
+ *   to zero within each layer. DropConnect instead sets a randomly selected
+ *   subset of weights within the network to zero. Each unit thus receives
+ *   input from a random subset of units in the previous layer. We derive a
+ *   bound on the generalization performance of both Dropout and DropConnect.
+ *   We then evaluate DropConnect on a range of datasets, comparing to
+ *   Dropout, and show state-of-the-art results on several image recognition
+ *   benchmarks can be obtained by aggregating multiple DropConnect-trained
+ *   models.}
+ * }
+ * @endcode
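+ *
+ * A minimal usage sketch (the layer sizes and drop ratio below are arbitrary,
+ * chosen only for illustration):
+ *
+ * @code
+ * // Standalone DropConnect layer acting as a 10 -> 2 linear layer whose
+ * // connections are dropped with probability 0.3 during training.
+ * DropConnectLayer<> dropConnect(10, 2, 0.3);
+ *
+ * // DropConnect wrapped around an existing layer object.
+ * LinearLayer<> hidden(10, 2);
+ * DropConnectLayer<decltype(hidden)> dropConnectHidden(hidden, 0.3);
+ * @endcode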
+ */
+
+template<
+ typename InputLayer = EmptyLayer,
+ typename InputDataType = arma::mat,
+ typename OutputDataType = arma::mat
+>
+class DropConnectLayer
+{
+ public:
+  /**
+   * Creates the DropConnectLayer as a standalone linear layer, defined by the
+   * given number of input and output units and the drop ratio.
+   *
+   * @param inSize The number of input units.
+   * @param outSize The number of output units.
+   * @param ratio The probability of setting a connection to zero.
+   */
+  DropConnectLayer(const size_t inSize, const size_t outSize,
+                   const double ratio = 0.5) :
+      inSize(inSize),
+      outSize(outSize),
+      ratio(ratio),
+      scale(1.0 / (1.0 - ratio)),
+      uselayer(false)
+  {
+    weights.set_size(outSize, inSize);
+  }
+
+  /**
+   * Create the DropConnectLayer object, wrapping the given input layer and
+   * using the specified drop ratio.
+   *
+   * @param inputLayer The layer object that the DropConnect layer wraps.
+   * @param ratio The probability of setting a connection to zero.
+   */
+  template<typename InputLayerType>
+  DropConnectLayer(InputLayerType&& inputLayer, const double ratio = 0.5) :
+      baseLayer(std::forward<InputLayerType>(inputLayer)),
+      ratio(ratio),
+      scale(1.0 / (1.0 - ratio)),
+      uselayer(true)
+  {
+    static_assert(std::is_same<typename std::decay<InputLayerType>::type,
+        InputLayer>::value,
+        "The given inputLayer must match the InputLayer template parameter.");
+  }
+
+ /**
+ * Ordinary feed forward pass of the DropConnect layer.
+ *
+ * @param input Input data used for evaluating the specified function.
+ * @param output Resulting output activation.
+ */
+  template<typename eT>
+  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output)
+  {
+    // The DropConnect mask is not applied in deterministic mode (testing).
+    if (uselayer)
+    {
+      if (deterministic)
+      {
+        baseLayer.Forward(input, output);
+      }
+      else
+      {
+        // Set connection weights to zero with probability ratio.
+        mask = arma::randu<arma::Mat<eT> >(baseLayer.Weights().n_rows,
+                                           baseLayer.Weights().n_cols);
+        mask.transform([&](double val) { return (val > ratio); });
+
+        // Save weights for denoising.
+        denoise = baseLayer.Weights();
+
+        baseLayer.Weights() = baseLayer.Weights() % mask;
+
+        baseLayer.Forward(input, output);
+      }
+    }
+    else
+    {
+      if (deterministic)
+      {
+        output = weights * input;
+      }
+      else
+      {
+        // Set connection weights to zero with probability ratio.
+        mask = arma::randu<arma::Mat<eT> >(weights.n_rows, weights.n_cols);
+        mask.transform([&](double val) { return (val > ratio); });
+
+        // Save weights for denoising.
+        denoise = weights;
+        weights = weights % mask;
+        output = weights * input;
+      }
+    }
+
+    // Scale the output by 1 / (1 - ratio) so that the expected output during
+    // training matches the deterministic (testing) pass.
+    if (!deterministic)
+      output = output * scale;
+  }
+
+ /**
+ * Ordinary feed backward pass of the DropConnect layer.
+ *
+ * @param input The propagated input activation.
+ * @param gy The backpropagated error.
+ * @param g The calculated gradient.
+ */
+  template<typename DataType>
+  void Backward(const DataType& input, const DataType& gy, DataType& g)
+  {
+    if (uselayer)
+    {
+      baseLayer.Backward(input, gy, g);
+    }
+    else
+    {
+      g = weights.t() * gy;
+    }
+  }
+
+ /**
+ * Calculate the gradient using the output delta and the input activation.
+ * @param d The calculated error.
+ * @param g The calculated gradient.
+ */
+  template<typename eT, typename GradientDataType>
+  void Gradient(const arma::Mat<eT>& d, GradientDataType& g)
+  {
+    if (uselayer)
+    {
+      baseLayer.Gradient(d, g);
+
+      // Denoise the weights.
+      baseLayer.Weights() = denoise;
+    }
+    else
+    {
+      g = d * inputParameter.t();
+
+      // Denoise the weights.
+      weights = denoise;
+    }
+  }
+
+  //! Get the weights.
+  OutputDataType const& Weights() const
+  {
+    if (uselayer)
+    {
+      return baseLayer.Weights();
+    }
+    else
+    {
+      return weights;
+    }
+  }
+
+  //! Modify the weights.
+  OutputDataType& Weights()
+  {
+    if (uselayer)
+    {
+      return baseLayer.Weights();
+    }
+    else
+    {
+      return weights;
+    }
+  }
+
+  //! Get the input parameter.
+  InputDataType const& InputParameter() const
+  {
+    if (uselayer)
+    {
+      return baseLayer.InputParameter();
+    }
+    else
+    {
+      return inputParameter;
+    }
+  }
+
+  //! Modify the input parameter.
+  InputDataType& InputParameter()
+  {
+    if (uselayer)
+    {
+      return baseLayer.InputParameter();
+    }
+    else
+    {
+      return inputParameter;
+    }
+  }
+
+  //! Get the output parameter.
+  OutputDataType const& OutputParameter() const
+  {
+    if (uselayer)
+    {
+      return baseLayer.OutputParameter();
+    }
+    else
+    {
+      return outputParameter;
+    }
+  }
+
+  //! Modify the output parameter.
+  OutputDataType& OutputParameter()
+  {
+    if (uselayer)
+    {
+      return baseLayer.OutputParameter();
+    }
+    else
+    {
+      return outputParameter;
+    }
+  }
+
+  //! Get the delta.
+  OutputDataType const& Delta() const
+  {
+    if (uselayer)
+    {
+      return baseLayer.Delta();
+    }
+    else
+    {
+      return delta;
+    }
+  }
+
+  //! Modify the delta.
+  OutputDataType& Delta()
+  {
+    if (uselayer)
+    {
+      return baseLayer.Delta();
+    }
+    else
+    {
+      return delta;
+    }
+  }
+
+  //! Get the gradient.
+  OutputDataType const& Gradient() const
+  {
+    if (uselayer)
+    {
+      return baseLayer.Gradient();
+    }
+    else
+    {
+      return gradient;
+    }
+  }
+
+  //! Modify the gradient.
+  OutputDataType& Gradient()
+  {
+    if (uselayer)
+    {
+      return baseLayer.Gradient();
+    }
+    else
+    {
+      return gradient;
+    }
+  }
+
+  //! Construct a default input layer with the given input and output size.
+  InputLayer initRule(const size_t inSize, const size_t outSize)
+  {
+    return InputLayer(inSize, outSize);
+  }
+
+ //! The value of the deterministic parameter.
+ bool Deterministic() const { return deterministic; }
+
+ //! Modify the value of the deterministic parameter.
+ bool &Deterministic() { return deterministic; }
+
+ //! The probability of setting a value to zero.
+ double Ratio() const { return ratio; }
+
+ //! Modify the probability of setting a value to zero.
+ void Ratio(const double r) {
+ ratio = r;
+ scale = 1.0 / (1.0 - ratio);
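+    // For example, with ratio = 0.5 the scale computed above is
+    // 1.0 / (1.0 - 0.5) = 2.0: roughly half of the connections are zeroed
+    // and the surviving ones are doubled, so the expected pre-activation is
+    // unchanged.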
+ }
+
+ //! Locally-stored number of input units.
+ size_t inSize;
+
+ //! Locally-stored number of output units.
+ size_t outSize;
+
+ //! Locally-stored weight object.
+ OutputDataType weights;
+
+ //! Locally-stored delta object.
+ OutputDataType delta;
+
+ //! Locally-stored layer object.
+ InputLayer baseLayer;
+
+ //! Locally-stored gradient object.
+ OutputDataType gradient;
+
+ //! Locally-stored input parameter object.
+ InputDataType inputParameter;
+
+ //! Locally-stored output parameter object.
+ OutputDataType outputParameter;
+
+ //! Locally-stored mask object.
+ OutputDataType mask;
+
+ //! The probability of setting a value to zero.
+ double ratio;
+
+ //! The scale fraction.
+ double scale;
+
+ //! If true, dropping connections and scaling are disabled; see notes above.
+ bool deterministic;
+
+ //! If true, the wrapped base layer is used; otherwise the locally-stored
+ //! weights are used.
+ bool uselayer;
+
+ //! Backup of the weights, used to restore them after the masked forward
+ //! pass.
+ OutputDataType denoise;
+ }; // class DropConnectLayer.
+
+//! Layer traits for the DropConnectLayer.
+template <
+ typename InputLayer,
+ typename InputDataType,
+ typename OutputDataType
+>
+class LayerTraits<DropConnectLayer<InputLayer, InputDataType, OutputDataType> >
+{
+ public:
+ static const bool IsBinary = false;
+ static const bool IsOutputLayer = false;
+ static const bool IsBiasLayer = false;
+ static const bool IsLSTMLayer = false;
+ static const bool IsConnection = true;
+};
+
+} // namespace ann
+} // namespace mlpack
+#endif
diff --git a/src/mlpack/methods/ann/layer/empty_layer.hpp b/src/mlpack/methods/ann/layer/empty_layer.hpp
new file mode 100644
index 0000000..1450a52
--- /dev/null
+++ b/src/mlpack/methods/ann/layer/empty_layer.hpp
@@ -0,0 +1,84 @@
+/**
+ * @file empty_layer.hpp
+ * @author Palash Ahuja
+ *
+ * Definition of the EmptyLayer class, which is basically empty.
+ */
+#ifndef __MLPACK_METHODS_ANN_LAYER_EMPTY_LAYER_HPP
+#define __MLPACK_METHODS_ANN_LAYER_EMPTY_LAYER_HPP
+
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+/**
+ * Definition of an empty layer class that does nothing. It serves as the
+ * default InputLayer template parameter of the DropConnectLayer.
+ */
+class EmptyLayer
+{
+ public:
+ /**
+ * Creates the empty layer object. All the methods are
+ * empty as well.
+ */
+  EmptyLayer()
+  {
+    // Nothing to do here.
+  }
+
+  template<typename eT>
+  void Forward(const arma::Mat<eT>& /* input */, arma::Mat<eT>& /* output */)
+  {
+    // Nothing to do here.
+  }
+
+  template<typename InputType, typename eT>
+  void Backward(const InputType& /* input */,
+                const arma::Mat<eT>& /* gy */,
+                arma::Mat<eT>& /* g */)
+  {
+    // Nothing to do here.
+  }
+
+  template<typename eT, typename GradientDataType>
+  void Gradient(const arma::Mat<eT>& /* d */, GradientDataType& /* g */)
+  {
+    // Nothing to do here.
+  }
+
+ //! Get the weights.
+ arma::mat const& Weights() const { return random; }
+
+ //! Modify the weights.
+ arma::mat& Weights() { return random; }
+
+ //! Get the input parameter.
+ arma::mat const& InputParameter() const { return random; }
+
+ //! Modify the input parameter.
+ arma::mat& InputParameter() { return random; }
+
+ //! Get the output parameter.
+ arma::mat const& OutputParameter() const { return random; }
+
+ //! Modify the output parameter.
+ arma::mat& OutputParameter() { return random; }
+
+ //! Get the delta.
+ arma::mat const& Delta() const { return random; }
+
+ //! Modify the delta.
+ arma::mat& Delta() { return random; }
+
+ //! Get the gradient.
+ arma::mat const& Gradient() const { return random; }
+
+ //! Modify the gradient.
+ arma::mat& Gradient() { return random; }
+
+ //! Placeholder matrix returned by all of the accessors above.
+ arma::mat random;
+
+}; // class EmptyLayer
+
+} // namespace ann
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/tests/feedforward_network_test.cpp b/src/mlpack/tests/feedforward_network_test.cpp
index e8b6636..3efab53 100644
--- a/src/mlpack/tests/feedforward_network_test.cpp
+++ b/src/mlpack/tests/feedforward_network_test.cpp
@@ -16,6 +16,7 @@
#include <mlpack/methods/ann/layer/base_layer.hpp>
#include <mlpack/methods/ann/layer/dropout_layer.hpp>
#include <mlpack/methods/ann/layer/binary_classification_layer.hpp>
+#include <mlpack/methods/ann/layer/dropconnect_layer.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/performance_functions/mse_function.hpp>
@@ -286,4 +287,210 @@ BOOST_AUTO_TEST_CASE(DropoutNetworkTest)
(dataset, labels, dataset, labels, 8, 30, 0.4);
}
+/**
+ * Train and evaluate a DropConnect network (with a base layer) with the
+ * specified structure.
+ */
+ template<
+ typename PerformanceFunction,
+ typename OutputLayerType,
+ typename PerformanceFunctionType,
+ typename MatType = arma::mat
+ >
+ void BuildDropConnectNetwork(MatType& trainData,
+ MatType& trainLabels,
+ MatType& testData,
+ MatType& testLabels,
+ const size_t hiddenLayerSize,
+ const size_t maxEpochs,
+ const double classificationErrorThreshold) {
+ /*
+ * Construct a feed forward network with trainData.n_rows input nodes,
+ * hiddenLayerSize hidden nodes and trainLabels.n_rows output nodes. The
+ * network structure looks like this:
+ *
+ * Input Hidden DropConnect Output
+ * Layer Layer Layer Layer
+ * +-----+ +-----+ +-----+ +-----+
+ * | | | | | | | |
+ * | +------>| +------>| +------>| |
+ * | | +>| | | | | |
+ * +-----+ | +--+--+ +-----+ +-----+
+ * |
+ * Bias |
+ * Layer |
+ * +-----+ |
+ * | | |
+ * | +-----+
+ * | |
+ * +-----+
+ *
+ *
+ */
+ LinearLayer<> inputLayer(trainData.n_rows, hiddenLayerSize);
+ BiasLayer<> biasLayer(hiddenLayerSize);
+ BaseLayer<PerformanceFunction> hiddenLayer0;
+
+ LinearLayer<> hiddenLayer1(hiddenLayerSize, trainLabels.n_rows);
+ DropConnectLayer<decltype(hiddenLayer1)> dropConnectLayer0(hiddenLayer1);
+
+ BaseLayer<PerformanceFunction> outputLayer;
+
+ OutputLayerType classOutputLayer;
+
+ auto modules = std::tie(inputLayer, biasLayer, hiddenLayer0,
+ dropConnectLayer0, outputLayer);
+
+ FFN<decltype(modules), decltype(classOutputLayer), RandomInitialization,
+ PerformanceFunctionType> net(modules, classOutputLayer);
+ RMSprop<decltype(net)> opt(net, 0.01, 0.88, 1e-8,
+ maxEpochs * trainData.n_cols, 1e-18);
+ net.Train(trainData, trainLabels, opt);
+ MatType prediction;
+ net.Predict(testData, prediction);
+
+  size_t correct = 0;
+  for (size_t i = 0; i < testData.n_cols; i++) {
+    if (arma::sum(arma::sum(
+        arma::abs(prediction.col(i) - testLabels.col(i)))) == 0) {
+      correct++;
+    }
+  }
+  double classificationError = 1 - double(correct) / testData.n_cols;
+ BOOST_REQUIRE_LE(classificationError, classificationErrorThreshold);
+}
+
+/**
+ * Train and evaluate a DropConnect network (with a linear layer) with the
+ * specified structure.
+ */
+ template<
+ typename PerformanceFunction,
+ typename OutputLayerType,
+ typename PerformanceFunctionType,
+ typename MatType = arma::mat
+ >
+ void BuildDropConnectNetworkLinear(MatType& trainData,
+ MatType& trainLabels,
+ MatType& testData,
+ MatType& testLabels,
+ const size_t hiddenLayerSize,
+ const size_t maxEpochs,
+ const double classificationErrorThreshold) {
+ /*
+ * Construct a feed forward network with trainData.n_rows input nodes,
+ * hiddenLayerSize hidden nodes and trainLabels.n_rows output nodes. The
+ * network structure looks like this:
+ *
+ * Input Hidden DropConnect Output
+ * Layer Layer Layer Layer
+ * +-----+ +-----+ +-----+ +-----+
+ * | | | | | | | |
+ * | +------>| +------>| +------>| |
+ * | | +>| | | | | |
+ * +-----+ | +--+--+ +-----+ +-----+
+ * |
+ * Bias |
+ * Layer |
+ * +-----+ |
+ * | | |
+ * | +-----+
+ * | |
+ * +-----+
+ *
+ *
+ */
+ LinearLayer<> inputLayer(trainData.n_rows, hiddenLayerSize);
+ BiasLayer<> biasLayer(hiddenLayerSize);
+ BaseLayer<PerformanceFunction> hiddenLayer0;
+ DropConnectLayer<> dropConnectLayer0(hiddenLayerSize, trainLabels.n_rows);
+
+ BaseLayer<PerformanceFunction> outputLayer;
+
+ OutputLayerType classOutputLayer;
+ auto modules = std::tie(inputLayer, biasLayer, hiddenLayer0,
+ dropConnectLayer0, outputLayer);
+
+ FFN<decltype(modules), decltype(classOutputLayer), RandomInitialization,
+ PerformanceFunctionType> net(modules, classOutputLayer);
+ RMSprop<decltype(net)> opt(net, 0.01, 0.88, 1e-8,
+ maxEpochs * trainData.n_cols, 1e-18);
+ net.Train(trainData, trainLabels, opt);
+ MatType prediction;
+ net.Predict(testData, prediction);
+
+  size_t correct = 0;
+  for (size_t i = 0; i < testData.n_cols; i++) {
+    if (arma::sum(arma::sum(
+        arma::abs(prediction.col(i) - testLabels.col(i)))) == 0) {
+      correct++;
+    }
+  }
+  double classificationError = 1 - double(correct) / testData.n_cols;
+ BOOST_REQUIRE_LE(classificationError, classificationErrorThreshold);
+ }
+/**
+ * Train the DropConnect network on a larger dataset.
+ */
+BOOST_AUTO_TEST_CASE(DropConnectNetworkTest)
+{
+ // Load the dataset.
+ arma::mat dataset;
+ data::Load("thyroid_train.csv", dataset, true);
+
+ arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4,
+ dataset.n_cols - 1);
+ arma::mat trainLabels = dataset.submat(dataset.n_rows - 3, 0,
+ dataset.n_rows - 1, dataset.n_cols - 1);
+
+ data::Load("thyroid_test.csv", dataset, true);
+
+ arma::mat testData = dataset.submat(0, 0, dataset.n_rows - 4,
+ dataset.n_cols - 1);
+ arma::mat testLabels = dataset.submat(dataset.n_rows - 3, 0,
+ dataset.n_rows - 1, dataset.n_cols - 1);
+
+ // DropConnect network with logistic activation function.
+ // Because 92 percent of the patients are not hyperthyroid, the network
+ // must be significantly better than 92 percent.
+ BuildDropConnectNetwork<LogisticFunction,
+ BinaryClassificationLayer,
+ MeanSquaredErrorFunction>
+ (trainData, trainLabels, testData, testLabels, 4, 100, 0.1);
+
+ BuildDropConnectNetworkLinear<LogisticFunction,
+ BinaryClassificationLayer,
+ MeanSquaredErrorFunction>
+ (trainData, trainLabels, testData, testLabels, 4, 100, 0.1);
+
+ dataset.load("mnist_first250_training_4s_and_9s.arm");
+
+ // Normalize each point since these are images.
+ for (size_t i = 0; i < dataset.n_cols; ++i)
+ dataset.col(i) /= norm(dataset.col(i), 2);
+
+ arma::mat labels = arma::zeros(1, dataset.n_cols);
+ labels.submat(0, labels.n_cols / 2, 0, labels.n_cols - 1).fill(1);
+
+ // DropConnect network with logistic activation function.
+ BuildDropConnectNetwork<LogisticFunction,
+ BinaryClassificationLayer,
+ MeanSquaredErrorFunction>
+ (dataset, labels, dataset, labels, 8, 30, 0.4);
+
+ BuildDropConnectNetworkLinear<LogisticFunction,
+ BinaryClassificationLayer,
+ MeanSquaredErrorFunction>
+ (dataset, labels, dataset, labels, 8, 30, 0.4);
+
+ // DropConnect network with tanh activation function.
+ BuildDropConnectNetwork<TanhFunction,
+ BinaryClassificationLayer,
+ MeanSquaredErrorFunction>
+ (dataset, labels, dataset, labels, 8, 30, 0.4);
+ BuildDropConnectNetworkLinear<TanhFunction,
+ BinaryClassificationLayer,
+ MeanSquaredErrorFunction>
+ (dataset, labels, dataset, labels, 8, 30, 0.4);
+}
BOOST_AUTO_TEST_SUITE_END();