lstm.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_LSTM_HPP
13 #define MLPACK_METHODS_ANN_LAYER_LSTM_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 #include <limits>
17 
18 namespace mlpack {
19 namespace ann {
20 
54 template <
55  typename InputDataType = arma::mat,
56  typename OutputDataType = arma::mat
57 >
58 class LSTM
59 {
60  public:
62  LSTM();
63 
71  LSTM(const size_t inSize,
72  const size_t outSize,
73  const size_t rho = std::numeric_limits<size_t>::max());
74 
82  template<typename InputType, typename OutputType>
83  void Forward(InputType&& input, OutputType&& output);
84 
94  template<typename InputType, typename OutputType>
95  void Forward(InputType&& input,
96  OutputType&& output,
97  OutputType&& cellState,
98  bool useCellState = false);
99 
109  template<typename InputType, typename ErrorType, typename GradientType>
110  void Backward(const InputType&& input,
111  ErrorType&& gy,
112  GradientType&& g);
113 
114  /*
115  * Reset the layer parameter.
116  */
117  void Reset();
118 
119  /*
120  * Resets the cell to accept a new input. This breaks the BPTT chain starts a
121  * new one.
122  *
123  * @param size The current maximum number of steps through time.
124  */
125  void ResetCell(const size_t size);
126 
127  /*
128  * Calculate the gradient using the output delta and the input activation.
129  *
130  * @param input The input parameter used for calculating the gradient.
131  * @param error The calculated error.
132  * @param gradient The calculated gradient.
133  */
134  template<typename InputType, typename ErrorType, typename GradientType>
135  void Gradient(InputType&& input,
136  ErrorType&& error,
137  GradientType&& gradient);
138 
140  size_t Rho() const { return rho; }
142  size_t& Rho() { return rho; }
143 
145  OutputDataType const& Parameters() const { return weights; }
147  OutputDataType& Parameters() { return weights; }
148 
150  OutputDataType const& OutputParameter() const { return outputParameter; }
152  OutputDataType& OutputParameter() { return outputParameter; }
153 
155  OutputDataType const& Delta() const { return delta; }
157  OutputDataType& Delta() { return delta; }
158 
160  OutputDataType const& Gradient() const { return grad; }
162  OutputDataType& Gradient() { return grad; }
163 
167  template<typename Archive>
168  void serialize(Archive& ar, const unsigned int /* version */);
169 
170  private:
172  size_t inSize;
173 
175  size_t outSize;
176 
178  size_t rho;
179 
181  size_t forwardStep;
182 
184  size_t backwardStep;
185 
187  size_t gradientStep;
188 
190  OutputDataType weights;
191 
193  OutputDataType prevOutput;
194 
196  size_t batchSize;
197 
199  size_t batchStep;
200 
203  size_t gradientStepIdx;
204 
206  OutputDataType cellActivationError;
207 
209  OutputDataType delta;
210 
212  OutputDataType grad;
213 
215  OutputDataType outputParameter;
216 
218  OutputDataType output2GateInputWeight;
219 
221  OutputDataType input2GateInputWeight;
222 
224  OutputDataType input2GateInputBias;
225 
227  OutputDataType cell2GateInputWeight;
228 
230  OutputDataType output2GateForgetWeight;
231 
233  OutputDataType input2GateForgetWeight;
234 
236  OutputDataType input2GateForgetBias;
237 
239  OutputDataType cell2GateForgetWeight;
240 
242  OutputDataType output2GateOutputWeight;
243 
245  OutputDataType input2GateOutputWeight;
246 
248  OutputDataType input2GateOutputBias;
249 
251  OutputDataType cell2GateOutputWeight;
252 
254  OutputDataType inputGate;
255 
257  OutputDataType forgetGate;
258 
260  OutputDataType hiddenLayer;
261 
263  OutputDataType outputGate;
264 
266  OutputDataType inputGateActivation;
267 
269  OutputDataType forgetGateActivation;
270 
272  OutputDataType outputGateActivation;
273 
275  OutputDataType hiddenLayerActivation;
276 
278  OutputDataType input2HiddenWeight;
279 
281  OutputDataType input2HiddenBias;
282 
284  OutputDataType output2HiddenWeight;
285 
287  OutputDataType cell;
288 
290  OutputDataType cellActivation;
291 
293  OutputDataType forgetGateError;
294 
296  OutputDataType outputGateError;
297 
299  OutputDataType prevError;
300 
302  OutputDataType outParameter;
303 
305  OutputDataType inputCellError;
306 
308  OutputDataType inputGateError;
309 
311  OutputDataType hiddenError;
312 
314  size_t rhoSize;
315 
317  size_t bpttSteps;
318 }; // class LSTM
319 
320 } // namespace ann
321 } // namespace mlpack
322 
323 // Include implementation.
324 #include "lstm_impl.hpp"
325 
326 #endif
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: lstm.hpp:150
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType & Gradient()
Modify the gradient.
Definition: lstm.hpp:162
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: lstm.hpp:152
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
size_t Rho() const
Get the maximum number of steps to backpropagate through time (BPTT).
Definition: lstm.hpp:140
OutputDataType const & Parameters() const
Get the parameters.
Definition: lstm.hpp:145
void Forward(InputType &&input, OutputType &&output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
OutputDataType const & Delta() const
Get the delta.
Definition: lstm.hpp:155
void Backward(const InputType &&input, ErrorType &&gy, GradientType &&g)
Ordinary feed-backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed-forward pass.
LSTM()
Create the LSTM object.
OutputDataType const & Gradient() const
Get the gradient.
Definition: lstm.hpp:160
void ResetCell(const size_t size)
Resets the cell to accept a new input. This breaks the BPTT chain and starts a new one.
OutputDataType & Delta()
Modify the delta.
Definition: lstm.hpp:157
OutputDataType & Parameters()
Modify the parameters.
Definition: lstm.hpp:147
size_t & Rho()
Modify the maximum number of steps to backpropagate through time (BPTT).
Definition: lstm.hpp:142