12 #ifndef MLPACK_METHODS_ANN_LAYER_LSTM_HPP 13 #define MLPACK_METHODS_ANN_LAYER_LSTM_HPP 59 typename InputDataType = arma::mat,
60 typename OutputDataType = arma::mat
75 LSTM(
const size_t inSize,
77 const size_t rho = std::numeric_limits<size_t>::max());
86 template<
typename InputType,
typename OutputType>
87 void Forward(InputType&& input, OutputType&& output);
98 template<
typename InputType,
typename OutputType>
101 OutputType&& cellState,
102 bool useCellState =
false);
113 template<
typename InputType,
typename ErrorType,
typename GradientType>
114 void Backward(
const InputType&& input,
138 template<
typename InputType,
typename ErrorType,
typename GradientType>
141 GradientType&& gradient);
144 size_t Rho()
const {
return rho; }
146 size_t&
Rho() {
return rho; }
159 OutputDataType
const&
Delta()
const {
return delta; }
161 OutputDataType&
Delta() {
return delta; }
164 OutputDataType
const&
Gradient()
const {
return grad; }
171 template<
typename Archive>
172 void serialize(Archive& ar,
const unsigned int );
194 OutputDataType weights;
197 OutputDataType prevOutput;
207 size_t gradientStepIdx;
210 OutputDataType cellActivationError;
213 OutputDataType delta;
219 OutputDataType outputParameter;
222 OutputDataType output2GateInputWeight;
225 OutputDataType input2GateInputWeight;
228 OutputDataType input2GateInputBias;
231 OutputDataType cell2GateInputWeight;
234 OutputDataType output2GateForgetWeight;
237 OutputDataType input2GateForgetWeight;
240 OutputDataType input2GateForgetBias;
243 OutputDataType cell2GateForgetWeight;
246 OutputDataType output2GateOutputWeight;
249 OutputDataType input2GateOutputWeight;
252 OutputDataType input2GateOutputBias;
255 OutputDataType cell2GateOutputWeight;
258 OutputDataType inputGate;
261 OutputDataType forgetGate;
264 OutputDataType hiddenLayer;
267 OutputDataType outputGate;
270 OutputDataType inputGateActivation;
273 OutputDataType forgetGateActivation;
276 OutputDataType outputGateActivation;
279 OutputDataType hiddenLayerActivation;
282 OutputDataType input2HiddenWeight;
285 OutputDataType input2HiddenBias;
288 OutputDataType output2HiddenWeight;
294 OutputDataType cellActivation;
297 OutputDataType forgetGateError;
300 OutputDataType outputGateError;
303 OutputDataType prevError;
306 OutputDataType outParameter;
309 OutputDataType inputCellError;
312 OutputDataType inputGateError;
315 OutputDataType hiddenError;
328 #include "lstm_impl.hpp" OutputDataType const & OutputParameter() const
Get the output parameter.
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType & Gradient()
Modify the gradient.
OutputDataType & OutputParameter()
Modify the output parameter.
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
size_t Rho() const
Get the maximum number of steps to backpropagate through time (BPTT).
OutputDataType const & Parameters() const
Get the parameters.
void Forward(InputType &&input, OutputType &&output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
OutputDataType const & Delta() const
Get the delta.
void Backward(const InputType &&input, ErrorType &&gy, GradientType &&g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed-forward pass.
LSTM()
Create the LSTM object.
OutputDataType const & Gradient() const
Get the gradient.
void ResetCell(const size_t size)
OutputDataType & Delta()
Modify the delta.
OutputDataType & Parameters()
Modify the parameters.
size_t & Rho()
Modify the maximum number of steps to backpropagate through time (BPTT).