<p>I would like to develop a scalable FineTune class and need some suggestions. The proposed API follows:</p>
<pre><code>/**
 * Fine-tune a deep network such as a stacked autoencoder.
 *
 * @tparam LayerTypes Tuple of layer types; each type must provide three functions:
 * - void Gradient(const arma::mat& parameters, arma::mat& gradient);
 * - double Evaluate(const arma::mat& parameters);
 * - arma::mat& GetInitialPoint();
 * See SparseAutoencoderFunction for a reference implementation.
 * @tparam OutputLayerType Type of the output layer; it must implement the same three functions:
 * - void Gradient(const arma::mat& parameters, arma::mat& gradient);
 * - double Evaluate(const arma::mat& parameters);
 * - arma::mat& GetInitialPoint();
 * @tparam FineTuneGradient Functor that computes the gradient of the last layer; it must implement two functions:
 * - template<typename T> void Gradient(const arma::mat&, const arma::mat&, const T&, arma::mat&);
 * - void Deriv(const arma::mat&, arma::mat&);
 */
template<typename LayerTypes, typename OutputLayerType,
typename FineTuneGradient>
class FineTuneFunction
{
public:
using ParamArray =
std::array<arma::mat*, std::tuple_size<LayerTypes>::value + 1>;
static_assert(std::tuple_size<LayerTypes>::value > 1,
"The tuple size of LayerTypes must be greater than 1");
/**
* Construct the class with the given data.
* @param input The input data of the LayerTypes and OutputLayerType.
* @param parameters The parameters of the LayerTypes and OutputLayerType.
* @param layerTypes Tuple of the layers (for now only SparseAutoencoder is supported).
* @param outLayerType The output layer (e.g. softmax).
*/
FineTuneFunction(ParamArray &input,
ParamArray &parameters,
LayerTypes &layerTypes,
OutputLayerType &outLayerType)
: trainData(input),
paramArray(parameters),
layerTypes(layerTypes),
outLayerType(outLayerType),
LayerTypesParamSize(LayerParamTotalSize<>())
{
}
/**
* Evaluates the objective function of the networks using the
* given parameters.
* @param parameters Current values of the model parameters.
*/
double Evaluate(const arma::mat& parameters);
/**
* Evaluates the gradient values of the objective function given the current
* set of parameters. The function performs a feedforward pass and computes
* the error in reconstructing the data points. It then uses the
* backpropagation algorithm to compute the gradient values.
* @param parameters Current values of the model parameters.
* @param gradient Matrix where gradient values will be stored.
*/
void Gradient(const arma::mat& parameters, arma::mat& gradient);
//! Return the initial point for the optimization.
arma::mat& GetInitialPoint();
};
</code></pre>
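<p>The central bookkeeping in Evaluate/Gradient is mapping the single flat parameter matrix the optimizer passes in onto per-layer blocks. As a minimal sketch of that idea (plain C++ with a hypothetical <code>SplitParameters</code> helper and <code>std::vector</code> standing in for <code>arma::mat</code>):</p>

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Split a flat parameter vector into per-layer slices, given the number of
// parameters each layer owns. FineTuneFunction::Evaluate / Gradient would do
// the analogous offset bookkeeping on the arma::mat it receives.
std::vector<std::vector<double>> SplitParameters(
    const std::vector<double>& flat, const std::vector<std::size_t>& sizes)
{
  std::vector<std::vector<double>> slices;
  std::size_t offset = 0;
  for (std::size_t n : sizes)
  {
    // Copy the [offset, offset + n) block out as this layer's parameters.
    slices.emplace_back(flat.begin() + offset, flat.begin() + offset + n);
    offset += n;
  }
  assert(offset == flat.size()); // The sizes must account for every parameter.
  return slices;
}
```

<p>With armadillo the same thing can be done without copying via submatrix views, which is probably preferable for large networks.</p>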
<p>An example of using this class (omitting initialization of the training data):</p>
<pre><code>using namespace mlpack;
arma::mat sae1_input ;
arma::mat sae2_input ;
nn::SparseAutoencoderFunction<> sae1(sae1_input, 3, 2);
nn::SparseAutoencoderFunction<> sae2(sae2_input, 2, 2);
arma::mat sm_input;
arma::Row<size_t> labels(2);
labels(0) = 0;
labels(1) = 1;
regression::SoftmaxRegressionFunction sm(sm_input, labels, 2);
arma::mat sae1_params;
arma::mat sae2_params;
arma::mat sm_params;
//after training, the class will update the params
std::array<arma::mat*, 3> params{
&sae1_params, &sae2_params, &sm_params
};
//the class will change the inputs (all except the first) during training
std::array<arma::mat*, 3> inputs{
&sae1_input, &sae2_input, &sm_input
};
auto layer_types = std::forward_as_tuple(sae1, sae2);
FineTuneFunction<
decltype(layer_types),
decltype(sm),
SoftmaxFineTune
> finetune(inputs, params, layer_types, sm);
//create an L-BFGS optimizer to fine tune the parameters; this is a sketch
//and the optimizer API may differ between mlpack versions
arma::mat iterate = finetune.GetInitialPoint();
mlpack::optimization::L_BFGS<decltype(finetune)> lbfgs(finetune);
lbfgs.Optimize(iterate);
</code></pre>
<p>Besides, the softmax function needs to cache the probabilities; otherwise they have to be recalculated two more times when fine-tuning the parameters. Any suggestions? Thanks.</p>
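<p>One possible approach is to memoize the probabilities inside the function object and invalidate the cache when the parameters change. A minimal sketch of that pattern (the class name, <code>std::vector</code> parameters, and the identity <code>ComputeProbabilities</code> are all hypothetical stand-ins for the real softmax forward pass):</p>

```cpp
#include <vector>

// Sketch of caching the softmax probabilities: Evaluate() and Gradient() both
// need them, so compute once per distinct parameter vector and reuse until the
// parameters change.
class CachedSoftmax
{
 public:
  const std::vector<double>& Probabilities(const std::vector<double>& parameters)
  {
    if (!valid || parameters != cachedParameters)
    {
      cachedParameters = parameters;
      probabilities = ComputeProbabilities(parameters); // The expensive pass.
      valid = true;
      ++recomputes; // Instrumentation only, to observe cache behavior.
    }
    return probabilities;
  }

  int Recomputes() const { return recomputes; }

 private:
  std::vector<double> ComputeProbabilities(const std::vector<double>& p)
  {
    // Stand-in for the real softmax forward pass.
    return p;
  }

  std::vector<double> cachedParameters;
  std::vector<double> probabilities;
  bool valid = false;
  int recomputes = 0;
};
```

<p>Comparing the whole parameter vector on every call costs O(n); a cheaper variant is a dirty flag that the caller sets whenever it writes new parameters, at the price of a slightly larger API surface.</p>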