[mlpack-git] [mlpack] ANN Saving the network and reloading (#531)

Marcus Edel notifications at github.com
Wed Mar 2 10:33:28 EST 2016


By the way, I used the following code for testing:

```
/**
 * Build a small feed-forward network, train it on the given data, and return
 * it by value.
 *
 * Topology: Linear -> Bias -> Base (hidden, 10 units) -> Linear -> Bias ->
 * Base, followed by a BinaryClassificationLayer output, random weight
 * initialization and a mean-squared-error performance function.
 *
 * @param trainData Training predictors; the input layer size is taken from
 *     trainData.n_rows, so each column is one data point.
 * @param trainLabels Training responses; the output layer size is taken from
 *     trainLabels.n_rows.
 * @return The trained network.
 */
auto GetVanillaNetwork(arma::mat& trainData, arma::mat& trainLabels)
{
  const size_t hiddenLayerSize = 10;

  LinearLayer<> inputLayer(trainData.n_rows, hiddenLayerSize);
  BiasLayer<> inputBiasLayer(hiddenLayerSize);
  BaseLayer<> inputBaseLayer;

  LinearLayer<> hiddenLayer1(hiddenLayerSize, trainLabels.n_rows);
  BiasLayer<> hiddenBiasLayer1(trainLabels.n_rows);
  BaseLayer<> outputLayer;

  BinaryClassificationLayer classOutputLayer;

  auto modules = std::make_tuple(inputLayer, inputBiasLayer, inputBaseLayer,
                                 hiddenLayer1, hiddenBiasLayer1, outputLayer);

  FFN<decltype(modules), decltype(classOutputLayer), RandomInitialization,
      MeanSquaredErrorFunction> net(modules, classOutputLayer);

  // NOTE(review): 20 * trainData.n_cols presumably caps the iterations at
  // roughly 20 passes over the data; confirm against the RMSprop signature.
  RMSprop<decltype(net)> opt(net, 0.01, 0.88, 1e-8,
      20 * trainData.n_cols, 1e-18);

  // Without this call the optimizer is never used, and the returned network
  // only holds its random initial weights.
  net.Train(trainData, trainLabels, opt);

  return net;
}

BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
{
  // Load the training set. The last three rows hold the labels; all other
  // rows are the features.
  arma::mat dataset;
  data::Load("thyroid_train.csv", dataset, true);

  arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4,
      dataset.n_cols - 1);
  arma::mat trainLabels = dataset.submat(dataset.n_rows - 3, 0,
      dataset.n_rows - 1, dataset.n_cols - 1);

  // Reuse `dataset` for the test set, split the same way.
  data::Load("thyroid_test.csv", dataset, true);

  arma::mat testData = dataset.submat(0, 0, dataset.n_rows - 4,
      dataset.n_cols - 1);
  arma::mat testLabels = dataset.submat(dataset.n_rows - 3, 0,
      dataset.n_rows - 1, dataset.n_cols - 1);

  auto net = GetVanillaNetwork(trainData, trainLabels);

  // Predict on the feature rows only. Passing `dataset` here would also feed
  // in the three label rows, which does not match the network's input size
  // (trainData.n_rows).
  arma::mat prediction;
  net.Predict(testData, prediction);
}
```

---
Reply to this email directly or view it on GitHub:
https://github.com/mlpack/mlpack/issues/531#issuecomment-191289070
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <https://mailman.cc.gatech.edu/pipermail/mlpack-git/attachments/20160302/57a05e85/attachment.html>


More information about the mlpack-git mailing list