[mlpack-git] [mlpack] improve speed of SparseAutoencoder and make it more flexible (#451)

stereomatchingkiss notifications at github.com
Sat Dec 5 01:54:29 EST 2015


>Opening another pull request to discuss the implementation details sounds great.

Then I will open one after I finish the autoencoder (if no one has opened one yet).

>but instead just using the pre-existing SparseAutoencoder typedef

I think this is a good idea, and I prefer to do it this way.

By the way, these are the layers (example 1) I am trying to feed into the FFN. Do you have any suggestions?

    const arma::mat trainData = arma::randu<arma::mat>(4,4);
    const size_t visibleSize = trainData.n_rows;
    const size_t hiddenSize = trainData.n_rows / 2;
    const double range = sqrt(6) / sqrt(visibleSize + hiddenSize + 1);

    LinearLayer<RMSPROP, RandomInitialization>
            hiddenLayer(visibleSize, hiddenSize, {-range, range});
    BiasLayer<RMSPROP, RandomInitialization> hiddenBiasLayer(hiddenSize);
    BaseLayer<LogisticFunction> hiddenBaseLayer;

    LinearLayer<RMSPROP, RandomInitialization>
            outputLayer(hiddenSize, visibleSize, {-range, range});
    BiasLayer<RMSPROP, RandomInitialization> outputBiasLayer(visibleSize);
    BaseLayer<LogisticFunction> outputBaseLayer;

    auto network = std::tie(hiddenLayer, hiddenBiasLayer, hiddenBaseLayer,
                            outputLayer, outputBiasLayer, outputBaseLayer);
    auto resultLayer = OneHotLayer();
    FFN<decltype(network), OneHotLayer, SparseErrorFunction>
            ffn(network, resultLayer);
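To sanity-check the shapes in example 1, here is a plain C++ sketch of the same 4-2-4 forward pass (logistic(W * x + b) per layer) with the same uniform initialization range. It does not use mlpack or Armadillo: `Matrix`, `InitRange`, `RandomWeights`, `LogisticLayer` and `Reconstruct` are hypothetical helpers, and the weight layout (output size x input size) is an assumption. The assertions in `LogisticLayer` show the constraint that matters here: each bias vector needs as many entries as its layer has outputs, so the output bias needs `visibleSize` entries.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Column-major dense matrix; a stand-in for arma::mat in this sketch.
struct Matrix {
  std::size_t rows, cols;
  std::vector<double> data;
  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0) {}
  double operator()(std::size_t i, std::size_t j) const { return data[j * rows + i]; }
};

// range = sqrt(6) / sqrt(visibleSize + hiddenSize + 1), as in example 1.
double InitRange(std::size_t visibleSize, std::size_t hiddenSize) {
  return std::sqrt(6.0) /
         std::sqrt(static_cast<double>(visibleSize + hiddenSize + 1));
}

// outSize x inSize weights drawn uniformly from [-range, range).
Matrix RandomWeights(std::size_t outSize, std::size_t inSize, double range,
                     std::mt19937& rng) {
  std::uniform_real_distribution<double> dist(-range, range);
  Matrix w(outSize, inSize);
  for (double& v : w.data) v = dist(rng);
  return w;
}

// One layer: logistic(W * x + b).
std::vector<double> LogisticLayer(const Matrix& w, const std::vector<double>& b,
                                  const std::vector<double>& x) {
  assert(w.cols == x.size());  // input size must match the weight columns
  assert(w.rows == b.size());  // bias size must match the layer's output size
  std::vector<double> y(w.rows);
  for (std::size_t i = 0; i < w.rows; ++i) {
    double sum = b[i];
    for (std::size_t j = 0; j < w.cols; ++j) sum += w(i, j) * x[j];
    y[i] = 1.0 / (1.0 + std::exp(-sum));
  }
  return y;
}

// visibleSize -> hiddenSize -> visibleSize forward pass.
std::vector<double> Reconstruct(const Matrix& w1, const std::vector<double>& b1,
                                const Matrix& w2, const std::vector<double>& b2,
                                const std::vector<double>& x) {
  return LogisticLayer(w2, b2, LogisticLayer(w1, b1, x));
}
```

Passing a bias with only `hiddenSize` entries as `b2` (the analogue of sizing the output bias layer with `hiddenSize`) trips the second assertion.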

>We have to implement a new performance function to calculate the reconstruction error

This may require changing some implementation details of the ann module.

There are two problems that need to be solved:

1. With the current implementation, the performance function of the ann module cannot store any state.
2. The performance function cannot access the weights w1 and w2 of the network.

    template<typename DataType, typename ErrorType, typename... Tp>
    double OutputError(const DataType& target,
                       ErrorType& error,
                       const std::tuple<Tp...>& t)
    {
      // Calculate and store the output error.
      outputLayer.CalculateError(
          std::get<sizeof...(Tp) - 1>(t).OutputParameter(), target, error);

      // Measures the network's performance with the specified performance
      // function.
      return PerformanceFunction::Error(
          std::get<sizeof...(Tp) - 1>(t).OutputParameter(), target);
    }

As you can see, the Error API only receives the target (for the autoencoder, this is the original training data) and the activation of the output layer (see example 1); it does not receive w1 and w2, the weights of the hidden layer and the output layer. But we need them to calculate the weight decay term. Besides, I expect users will also want to set the weights of the KL divergence and weight decay terms.
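For reference, the usual sparse autoencoder objective (the form in the UFLDL sparse autoencoder notes; the notation here is my own) makes these dependencies explicit. Here m is the number of training columns x^(i) (the target), x̂^(i) is the output-layer activation, a_j^(i) is the activation of hidden unit j, s is hiddenSize, and ρ is the target sparsity:

```latex
J(W_1, W_2, b_1, b_2) =
  \frac{1}{2m} \sum_{i=1}^{m} \bigl\| \hat{x}^{(i)} - x^{(i)} \bigr\|_2^2
  + \frac{\lambda}{2} \left( \|W_1\|_F^2 + \|W_2\|_F^2 \right)
  + \beta \sum_{j=1}^{s} \mathrm{KL}\left(\rho \,\|\, \hat{\rho}_j\right)

\mathrm{KL}(\rho \,\|\, \hat{\rho}_j) =
  \rho \log \frac{\rho}{\hat{\rho}_j}
  + (1 - \rho) \log \frac{1 - \rho}{1 - \hat{\rho}_j},
\qquad
\hat{\rho}_j = \frac{1}{m} \sum_{i=1}^{m} a_j^{(i)}
```

Only the first term can be computed from what Error() currently receives; the weight decay term needs W1, W2 and λ, and the KL term needs β, ρ and the hidden-layer activations.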

If the FFN could store a PerformanceFunction instance, this problem would be solved. Rather than calling the static Error function, we would call it like this:

    performanceFunction.Error(std::get<sizeof...(Tp) - 1>(t).OutputParameter(), target);

and change the constructor to:

    FFN(const LayerTypes& network, OutputLayerType& outputLayer,
        PerformanceFunction performanceFunction = PerformanceFunction())
        : network(network), outputLayer(outputLayer), trainError(0),
          // New data member: PerformanceFunction performance;
          performance(std::move(performanceFunction))
    {
      // Nothing to do here.
    }
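A stripped-down sketch of this idea in plain C++, independent of mlpack: the network stores its performance function by value and calls Error() on that instance, so the function object can carry parameters. `TinyNetwork`, `MeanSquaredError` and `ScaledError` are hypothetical names for illustration, and the vector-based signatures are simplifications of the ann module's.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Stateless default performance function (mean squared error) with a
// static Error(), like the existing ann performance functions.
struct MeanSquaredError {
  static double Error(const std::vector<double>& output,
                      const std::vector<double>& target) {
    assert(output.size() == target.size());
    if (output.empty()) return 0.0;
    double sum = 0.0;
    for (std::size_t i = 0; i < output.size(); ++i) {
      const double d = output[i] - target[i];
      sum += d * d;
    }
    return sum / static_cast<double>(output.size());
  }
};

// Stateful performance function: the user chooses the weight.
struct ScaledError {
  explicit ScaledError(double scale = 1.0) : scale(scale) {}
  double Error(const std::vector<double>& output,
               const std::vector<double>& target) const {
    return scale * MeanSquaredError::Error(output, target);
  }
  double scale;
};

// Hypothetical stand-in for FFN that stores the performance function.
template<typename PerformanceFunction = MeanSquaredError>
class TinyNetwork {
 public:
  explicit TinyNetwork(
      PerformanceFunction performanceFunction = PerformanceFunction())
      : performance(std::move(performanceFunction)) {}

  double OutputError(const std::vector<double>& output,
                     const std::vector<double>& target) const {
    // Called on the stored instance; this also works for a static Error().
    return performance.Error(output, target);
  }

 private:
  PerformanceFunction performance;
};
```

Defaulting the parameter to `PerformanceFunction()` keeps the default valid for any default-constructible performance function, and since a static member function can be called through an instance, existing stateless performance functions would keep working unchanged.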


And call it like this:

    const double beta = 0.3;     // Weight of the KL divergence term.
    const double lambda = 0.001; // Weight of the weight decay term.
    FFN<decltype(network), OneHotLayer, SparseErrorFunction>
        ffn(network, resultLayer,
            SparseErrorFunction(hiddenLayer.Weights(), outputLayer.Weights(),
                                beta, lambda));
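One way the proposed SparseErrorFunction could hold its state, sketched in plain C++ with std::vector standing in for arma::mat: it keeps references (std::reference_wrapper) to the weight matrices rather than copies, so the weight decay term sees the current weights after each optimizer update. The reconstruction and weight decay terms follow the standard sparse autoencoder objective; the KL term is left out because it also needs the hidden-layer activations, which this Error() signature does not receive. The class layout, `WeightMatrix` alias and `Beta()` accessor are assumptions, not an existing mlpack API.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Stand-in for arma::mat: weights stored as a flat vector.
using WeightMatrix = std::vector<double>;

// Hypothetical stateful performance function for the sparse autoencoder.
class SparseErrorFunction {
 public:
  SparseErrorFunction(const WeightMatrix& w1, const WeightMatrix& w2,
                      double beta, double lambda)
      : w1(w1), w2(w2), beta(beta), lambda(lambda) {}

  // 0.5 * ||output - target||^2 + 0.5 * lambda * (||W1||^2 + ||W2||^2).
  double Error(const std::vector<double>& output,
               const std::vector<double>& target) const {
    assert(output.size() == target.size());
    double reconstruction = 0.0;
    for (std::size_t i = 0; i < output.size(); ++i) {
      const double d = output[i] - target[i];
      reconstruction += d * d;
    }
    const double weightDecay =
        0.5 * lambda * (SquaredNorm(w1.get()) + SquaredNorm(w2.get()));
    return 0.5 * reconstruction + weightDecay;
  }

  // Weight of the KL divergence term; stored for it, but unused in this sketch.
  double Beta() const { return beta; }

 private:
  static double SquaredNorm(const WeightMatrix& w) {
    double sum = 0.0;
    for (double v : w) sum += v * v;
    return sum;
  }

  // References, not copies: later optimizer updates are visible here.
  std::reference_wrapper<const WeightMatrix> w1, w2;
  double beta, lambda;
};
```

The referenced weights must outlive the function object; copying the function object (as the FFN would when storing it by value) copies the references, so it still points at the same layers.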


I think this approach is more flexible. My other question: do you have any plans to provide move constructors and move assignment operators for the ann module? Most of the classes would not need any changes if Armadillo already implemented move semantics, so maybe a better approach is to ask the Armadillo developers to implement them; I have no idea why they do not provide them.
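A small sketch of why most classes would need no changes: if a class only holds members that are themselves movable and declares none of the copy or move operations or a destructor, the compiler generates the move constructor and move assignment operator (the "rule of zero"). `LinearLayerSketch` and `MoveAndCheck` are hypothetical, and std::vector plays the part of a movable matrix type; nothing here assumes which Armadillo version, if any, provides move operations.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>
#include <vector>

// Stand-in for a layer whose data members are all matrix-like objects.
// No copy/move operations or destructor are declared, so the compiler
// generates them from the members' own operations.
struct LinearLayerSketch {
  std::vector<double> weights;
  std::vector<double> gradient;
};

static_assert(std::is_nothrow_move_constructible<LinearLayerSketch>::value,
              "implicit move constructor is available and noexcept");
static_assert(std::is_move_assignable<LinearLayerSketch>::value,
              "implicit move assignment is available");

// Moves a layer and checks that the weight buffer was transferred, not copied.
LinearLayerSketch MoveAndCheck(LinearLayerSketch& layer) {
  const double* buffer = layer.weights.data();
  LinearLayerSketch moved = std::move(layer);
  assert(moved.weights.data() == buffer);
  return moved;
}
```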


---
Reply to this email directly or view it on GitHub:
https://github.com/mlpack/mlpack/pull/451#issuecomment-162156786