[mlpack-git] [mlpack] ANN Saving the network and reloading (#531)
Marcus Edel
notifications at github.com
Wed Mar 2 10:33:28 EST 2016
Btw. I used the following code to test:
```
/**
 * Build a feed-forward network with one hidden layer and return it by value.
 *
 * The first linear layer takes trainData.n_rows inputs and the last bias layer
 * has trainLabels.n_rows units, so the network dimensions follow the matrices
 * passed in.
 *
 * @param trainData Matrix whose row count fixes the input dimension.
 * @param trainLabels Matrix whose row count fixes the output dimension.
 * @return The constructed FFN object (type deduced from the layer tuple).
 */
auto GetVanillaNetwork(arma::mat& trainData, arma::mat& trainLabels)
{
  // Number of units in the hidden layer.
  int hiddenUnits = 10;

  // Input -> hidden stage: linear map, bias, then base layer.
  LinearLayer<> firstLinear(trainData.n_rows, hiddenUnits);
  BiasLayer<> firstBias(hiddenUnits);
  BaseLayer<> firstBase;

  // Hidden -> output stage: linear map, bias, then base layer.
  LinearLayer<> secondLinear(hiddenUnits, trainLabels.n_rows);
  BiasLayer<> secondBias(trainLabels.n_rows);
  BaseLayer<> secondBase;

  BinaryClassificationLayer classifier;

  // The layers are copied into a tuple; the tuple type is part of the
  // network type.
  auto layers = std::make_tuple(firstLinear, firstBias, firstBase,
                                secondLinear, secondBias, secondBase);

  using NetworkType = FFN<decltype(layers), decltype(classifier),
                          RandomInitialization, MeanSquaredErrorFunction>;
  NetworkType network(layers, classifier);

  // NOTE(review): the optimizer is constructed against the network, but no
  // training call is made inside this function -- confirm whether that is
  // intentional.
  RMSprop<NetworkType> optimizer(network, 0.01, 0.88, 1e-8,
                                 20 * trainData.n_cols, 1e-18);

  return network;
}
/**
 * Build the vanilla network from the thyroid training split and run a
 * prediction pass over the thyroid test features.
 *
 * In both CSV files the last three rows of the loaded matrix are treated as
 * labels and all preceding rows as features.
 */
BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
{
  // Load the training dataset and split it into features and labels.
  arma::mat dataset;
  data::Load("thyroid_train.csv", dataset, true);

  arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4,
      dataset.n_cols - 1);
  arma::mat trainLabels = dataset.submat(dataset.n_rows - 3, 0,
      dataset.n_rows - 1, dataset.n_cols - 1);

  // Load the test dataset (reusing the same matrix) and split it the same way.
  data::Load("thyroid_test.csv", dataset, true);

  arma::mat testData = dataset.submat(0, 0, dataset.n_rows - 4,
      dataset.n_cols - 1);
  arma::mat testLabels = dataset.submat(dataset.n_rows - 3, 0,
      dataset.n_rows - 1, dataset.n_cols - 1);

  auto net = GetVanillaNetwork(trainData, trainLabels);

  // Predict on the feature rows only. The network's input layer is sized from
  // trainData.n_rows, so passing the full `dataset` (which still contains the
  // three label rows) would not match the input dimension.
  arma::mat prediction;
  net.Predict(testData, prediction);
}
```
---
Reply to this email directly or view it on GitHub:
https://github.com/mlpack/mlpack/issues/531#issuecomment-191289070
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <https://mailman.cc.gatech.edu/pipermail/mlpack-git/attachments/20160302/57a05e85/attachment.html>
More information about the mlpack-git
mailing list