fast_lstm.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_FAST_LSTM_HPP
14 #define MLPACK_METHODS_ANN_LAYER_FAST_LSTM_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include <limits>
18 
19 namespace mlpack {
20 namespace ann {
21 
/**
 * Declaration of the FastLSTM layer, a faster variant of the LSTM network
 * layer.  The member functions that are only declared here are implemented
 * outside this class body; the header pulls in fast_lstm_impl.hpp after the
 * class for that purpose.
 *
 * @tparam InputDataType Input data type; its elem_type is the scalar type
 *     accepted by the scalar FastSigmoid() helper (default arma::mat).
 * @tparam OutputDataType Type of the parameters, delta, gradient and every
 *     other matrix stored by the layer; its elem_type is ElemType (default
 *     arma::mat).
 */
62 template <
63  typename InputDataType = arma::mat,
64  typename OutputDataType = arma::mat
65 >
66 class FastLSTM
67 {
68  public:
69  // Convenience typedefs.
    //! Scalar element type of InputDataType.
70  typedef typename InputDataType::elem_type InputElemType;
    //! Scalar element type of OutputDataType.
71  typedef typename OutputDataType::elem_type ElemType;
72 
    //! Create the Fast LSTM object.
74  FastLSTM();
75 
    /**
     * Create the Fast LSTM object with explicit sizes.
     *
     * NOTE(review): the parameter meanings below are taken from the matching
     * accessors (InSize(), OutSize(), Rho()); the constructor body is not
     * visible in this header -- confirm against fast_lstm_impl.hpp.
     *
     * @param inSize Number of input units.
     * @param outSize Number of output units.
     * @param rho Maximum number of steps to backpropagate through time (BPTT);
     *     defaults to the largest value representable by size_t.
     */
83  FastLSTM(const size_t inSize,
84  const size_t outSize,
85  const size_t rho = std::numeric_limits<size_t>::max());
86 
    /**
     * Ordinary feed forward pass of a neural network, evaluating the function
     * f(x) for the given input.
     *
     * @param input Input data (read-only).
     * @param output Object passed by non-const reference that receives the
     *     result.
     */
94  template<typename InputType, typename OutputType>
95  void Forward(const InputType& input, OutputType& output);
96 
    /**
     * Ordinary feed backward pass of a neural network.
     *
     * @param input Read-only input (presumably the propagated input
     *     activation -- confirm against fast_lstm_impl.hpp).
     * @param gy Read-only error (presumably the backpropagated error --
     *     confirm against fast_lstm_impl.hpp).
     * @param g Object passed by non-const reference (presumably receives the
     *     calculated gradient -- confirm against fast_lstm_impl.hpp).
     */
106  template<typename InputType, typename ErrorType, typename GradientType>
107  void Backward(const InputType& input,
108  const ErrorType& gy,
109  GradientType& g);
110 
111  /*
112  * Reset the layer parameter.
113  */
114  void Reset();
115 
116  /*
117  * Resets the cell to accept a new input. This breaks the BPTT chain and
118  * starts a new one.
119  *
120  * @param size The current maximum number of steps through time.
121  */
122  void ResetCell(const size_t size);
123 
124  /*
125  * Calculate the gradient using the output delta and the input activation.
126  *
127  * @param input The input parameter used for calculating the gradient.
128  * @param error The calculated error.
129  * @param gradient The calculated gradient.
130  */
131  template<typename InputType, typename ErrorType, typename GradientType>
132  void Gradient(const InputType& input,
133  const ErrorType& error,
134  GradientType& gradient);
135 
    //! Get the maximum number of steps to backpropagate through time (BPTT).
137  size_t Rho() const { return rho; }
    //! Modify the maximum number of steps to backpropagate through time (BPTT).
139  size_t& Rho() { return rho; }
140 
    //! Get the parameters.
142  OutputDataType const& Parameters() const { return weights; }
    //! Modify the parameters.
144  OutputDataType& Parameters() { return weights; }
145 
    //! Get the output parameter.
147  OutputDataType const& OutputParameter() const { return outputParameter; }
    //! Modify the output parameter.
149  OutputDataType& OutputParameter() { return outputParameter; }
150 
    //! Get the delta.
152  OutputDataType const& Delta() const { return delta; }
    //! Modify the delta.
154  OutputDataType& Delta() { return delta; }
155 
    //! Get the gradient.
157  OutputDataType const& Gradient() const { return grad; }
    //! Modify the gradient.
159  OutputDataType& Gradient() { return grad; }
160 
    //! Get the number of input units.
162  size_t InSize() const { return inSize; }
163 
    //! Get the number of output units.
165  size_t OutSize() const { return outSize; }
166 
    /**
     * Serialize the layer.
     *
     * @param ar Archive to serialize with.
     * @param version Unused (the parameter is left unnamed).
     */
170  template<typename Archive>
171  void serialize(Archive& ar, const unsigned int /* version */);
172 
173  private:
    /**
     * Element-wise overload: applies the scalar FastSigmoid() to each of the
     * input.n_elem elements of input, addressed by linear index, and stores
     * the result at the same index of sigmoids.
     *
     * sigmoids is written through operator() only and is never resized here,
     * so the caller must pass an object that already has at least
     * input.n_elem elements.
     *
     * @param input Values to transform (read-only).
     * @param sigmoids Pre-sized output that receives the transformed values.
     */
180  template<typename InputType, typename OutputType>
181  void FastSigmoid(const InputType& input, OutputType& sigmoids)
182  {
183  for (size_t i = 0; i < input.n_elem; ++i)
184  sigmoids(i) = FastSigmoid(input(i));
185  }
186 
    /**
     * Scalar piecewise approximation used by the element-wise overload.
     *
     * With x = 0.5 * data, a value z with the sign of x is chosen from three
     * segments of |x|: the rational term 1.5 * |x| / (1 + |x|) below 1.7, a
     * linear segment on [1.7, 3), and the constant 0.99505475368673 from 3
     * upwards.  The function returns 0.5 * (z + 1), so the result stays
     * within roughly [0.0025, 0.9975].
     *
     * NOTE(review): z looks like an approximation of tanh(x), which would
     * make the result an approximation of the logistic sigmoid through
     * sigmoid(d) = 0.5 * (tanh(d / 2) + 1) -- confirm against the reference
     * the constants were taken from.
     * NOTE(review): the rational segment tends to 1.5 * 1.7 / 2.7 (about
     * 0.944) as |x| approaches 1.7 while the linear segment starts at about
     * 0.935, so the two pieces do not meet exactly at |x| = 1.7.
     * NOTE(review): for a floating-point ElemType a NaN input fails every
     * comparison below and therefore ends in the negative constant branch
     * instead of propagating NaN.
     *
     * @param data Input value.
     * @return 0.5 * (z + 1) for the piecewise value z described above.
     */
193  ElemType FastSigmoid(const InputElemType data)
194  {
    // Halve the argument once; every segment below is expressed in x.
195  ElemType x = 0.5 * data;
196  ElemType z;
197  if (x >= 0)
198  {
199  if (x < 1.7)
200  z = (1.5 * x / (1 + x));
201  else if (x < 3)
202  z = (0.935409070603099 + 0.0458812946797165 * (x - 1.7));
203  else
204  z = 0.99505475368673;
205  }
206  else
207  {
    // Negative side: evaluate the same segments on -x and negate the result.
208  ElemType xx = -x;
209  if (xx < 1.7)
210  z = -(1.5 * xx / (1 + xx));
211  else if (xx < 3)
212  z = -(0.935409070603099 + 0.0458812946797165 * (xx - 1.7));
213  else
214  z = -0.99505475368673;
215  }
216 
    // Map z from (-1, 1) onto (0, 1).
217  return 0.5 * (z + 1.0);
218  }
219 
    // NOTE(review): only the members returned by the public accessors above
    // are documented below.  The remaining members are not used anywhere in
    // this header, so their roles cannot be stated from here -- see
    // fast_lstm_impl.hpp.

    //! Number of input units, as returned by InSize().
221  size_t inSize;
222 
    //! Number of output units, as returned by OutSize().
224  size_t outSize;
225 
    //! Maximum number of steps to backpropagate through time (BPTT), as
    //! returned by Rho().
227  size_t rho;
228 
230  size_t forwardStep;
231 
233  size_t backwardStep;
234 
236  size_t gradientStep;
237 
    //! Parameters of the layer, as returned by Parameters().
239  OutputDataType weights;
240 
242  OutputDataType prevOutput;
243 
245  size_t batchSize;
246 
248  size_t batchStep;
249 
252  size_t gradientStepIdx;
253 
255  OutputDataType cellActivationError;
256 
    //! Delta object, as returned by Delta().
258  OutputDataType delta;
259 
    //! Gradient object, as returned by Gradient().
261  OutputDataType grad;
262 
    //! Output parameter object, as returned by OutputParameter().
264  OutputDataType outputParameter;
265 
267  OutputDataType output2GateWeight;
268 
270  OutputDataType input2GateWeight;
271 
273  OutputDataType input2GateBias;
274 
276  OutputDataType gate;
277 
279  OutputDataType gateActivation;
280 
282  OutputDataType stateActivation;
283 
285  OutputDataType cell;
286 
288  OutputDataType cellActivation;
289 
291  OutputDataType forgetGateError;
292 
294  OutputDataType prevError;
295 
297  OutputDataType outParameter;
298 
300  size_t rhoSize;
301 
303  size_t bpttSteps;
304 }; // class FastLSTM
305 
306 } // namespace ann
307 } // namespace mlpack
308 
309 // Include implementation.
310 #include "fast_lstm_impl.hpp"
311 
312 #endif
OutputDataType & Gradient()
Modify the gradient.
Definition: fast_lstm.hpp:159
OutputDataType & Delta()
Modify the delta.
Definition: fast_lstm.hpp:154
void Backward(const InputType &input, const ErrorType &gy, GradientType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType::elem_type ElemType
Definition: fast_lstm.hpp:71
OutputDataType const & Parameters() const
Get the parameters.
Definition: fast_lstm.hpp:142
OutputDataType const & Gradient() const
Get the gradient.
Definition: fast_lstm.hpp:157
size_t & Rho()
Modify the maximum number of steps to backpropagate through time (BPTT).
Definition: fast_lstm.hpp:139
FastLSTM()
Create the Fast LSTM object.
OutputDataType const & Delta() const
Get the delta.
Definition: fast_lstm.hpp:152
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: fast_lstm.hpp:147
InputDataType::elem_type InputElemType
Definition: fast_lstm.hpp:70
size_t OutSize() const
Get the number of output units.
Definition: fast_lstm.hpp:165
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: fast_lstm.hpp:149
void ResetCell(const size_t size)
size_t Rho() const
Get the maximum number of steps to backpropagate through time (BPTT).
Definition: fast_lstm.hpp:137
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & Parameters()
Modify the parameters.
Definition: fast_lstm.hpp:144
size_t InSize() const
Get the number of input units.
Definition: fast_lstm.hpp:162
An implementation of a faster version of the LSTM network layer.
Definition: fast_lstm.hpp:66
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.