fast_lstm.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_FAST_LSTM_HPP
14 #define MLPACK_METHODS_ANN_LAYER_FAST_LSTM_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include <limits>
18 
19 namespace mlpack {
20 namespace ann {
21 
61 template <
62  typename InputDataType = arma::mat,
63  typename OutputDataType = arma::mat
64 >
65 class FastLSTM
66 {
67  public:
68  // Convenience typedefs.
69  typedef typename InputDataType::elem_type InputElemType;
70  typedef typename OutputDataType::elem_type ElemType;
71 
73  FastLSTM();
74 
82  FastLSTM(const size_t inSize,
83  const size_t outSize,
84  const size_t rho = std::numeric_limits<size_t>::max());
85 
93  template<typename InputType, typename OutputType>
94  void Forward(InputType&& input, OutputType&& output);
95 
105  template<typename InputType, typename ErrorType, typename GradientType>
106  void Backward(const InputType&& input,
107  ErrorType&& gy,
108  GradientType&& g);
109 
110  /*
111  * Reset the layer parameter.
112  */
113  void Reset();
114 
115  /*
116  * Resets the cell to accept a new input. This breaks the BPTT chain starts a
117  * new one.
118  *
119  * @param size The current maximum number of steps through time.
120  */
121  void ResetCell(const size_t size);
122 
123  /*
124  * Calculate the gradient using the output delta and the input activation.
125  *
126  * @param input The input parameter used for calculating the gradient.
127  * @param error The calculated error.
128  * @param gradient The calculated gradient.
129  */
130  template<typename InputType, typename ErrorType, typename GradientType>
131  void Gradient(InputType&& input,
132  ErrorType&& error,
133  GradientType&& gradient);
134 
136  size_t Rho() const { return rho; }
138  size_t& Rho() { return rho; }
139 
141  OutputDataType const& Parameters() const { return weights; }
143  OutputDataType& Parameters() { return weights; }
144 
146  OutputDataType const& OutputParameter() const { return outputParameter; }
148  OutputDataType& OutputParameter() { return outputParameter; }
149 
151  OutputDataType const& Delta() const { return delta; }
153  OutputDataType& Delta() { return delta; }
154 
156  OutputDataType const& Gradient() const { return grad; }
158  OutputDataType& Gradient() { return grad; }
159 
163  template<typename Archive>
164  void serialize(Archive& ar, const unsigned int /* version */);
165 
166  private:
173  template<typename InputType, typename OutputType>
174  void FastSigmoid(InputType&& input, OutputType&& sigmoids)
175  {
176  for (size_t i = 0; i < input.n_elem; ++i)
177  sigmoids(i) = FastSigmoid(input(i));
178  }
179 
186  ElemType FastSigmoid(const InputElemType data)
187  {
188  ElemType x = 0.5 * data;
189  ElemType z;
190  if (x >= 0)
191  {
192  if (x < 1.7)
193  z = (1.5 * x / (1 + x));
194  else if (x < 3)
195  z = (0.935409070603099 + 0.0458812946797165 * (x - 1.7));
196  else
197  z = 0.99505475368673;
198  }
199  else
200  {
201  ElemType xx = -x;
202  if (xx < 1.7)
203  z = -(1.5 * xx / (1 + xx));
204  else if (xx < 3)
205  z = -(0.935409070603099 + 0.0458812946797165 * (xx - 1.7));
206  else
207  z = -0.99505475368673;
208  }
209 
210  return 0.5 * (z + 1.0);
211  }
212 
214  size_t inSize;
215 
217  size_t outSize;
218 
220  size_t rho;
221 
223  size_t forwardStep;
224 
226  size_t backwardStep;
227 
229  size_t gradientStep;
230 
232  OutputDataType weights;
233 
235  OutputDataType prevOutput;
236 
238  size_t batchSize;
239 
241  size_t batchStep;
242 
245  size_t gradientStepIdx;
246 
248  OutputDataType cellActivationError;
249 
251  OutputDataType delta;
252 
254  OutputDataType grad;
255 
257  OutputDataType outputParameter;
258 
260  OutputDataType output2GateWeight;
261 
263  OutputDataType input2GateWeight;
264 
266  OutputDataType input2GateBias;
267 
269  OutputDataType gate;
270 
272  OutputDataType gateActivation;
273 
275  OutputDataType stateActivation;
276 
278  OutputDataType cell;
279 
281  OutputDataType cellActivation;
282 
284  OutputDataType forgetGateError;
285 
287  OutputDataType prevError;
288 
290  OutputDataType outParameter;
291 
293  size_t rhoSize;
294 
296  size_t bpttSteps;
297 }; // class FastLSTM
298 
299 } // namespace ann
300 } // namespace mlpack
301 
302 // Include implementation.
303 #include "fast_lstm_impl.hpp"
304 
305 #endif
OutputDataType & Gradient()
Modify the gradient.
Definition: fast_lstm.hpp:158
OutputDataType & Delta()
Modify the delta.
Definition: fast_lstm.hpp:153
void Forward(InputType &&input, OutputType &&output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activ...
strip_type.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType::elem_type ElemType
Definition: fast_lstm.hpp:70
OutputDataType const & Parameters() const
Get the parameters.
Definition: fast_lstm.hpp:141
OutputDataType const & Gradient() const
Get the gradient.
Definition: fast_lstm.hpp:156
size_t & Rho()
Modify the maximum number of steps to backpropagate through time (BPTT).
Definition: fast_lstm.hpp:138
FastLSTM()
Create the Fast LSTM object.
OutputDataType const & Delta() const
Get the delta.
Definition: fast_lstm.hpp:151
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: fast_lstm.hpp:146
InputDataType::elem_type InputElemType
Definition: fast_lstm.hpp:69
void Backward(const InputType &&input, ErrorType &&gy, GradientType &&g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backw...
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: fast_lstm.hpp:148
void ResetCell(const size_t size)
size_t Rho() const
Get the maximum number of steps to backpropagate through time (BPTT).
Definition: fast_lstm.hpp:136
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & Parameters()
Modify the parameters.
Definition: fast_lstm.hpp:143
An implementation of a faster version of the LSTM network layer.
Definition: fast_lstm.hpp:65