fast_lstm.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_FAST_LSTM_HPP
14 #define MLPACK_METHODS_ANN_LAYER_FAST_LSTM_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include <limits>
18 
19 namespace mlpack {
20 namespace ann {
21 
57 template <
58  typename InputDataType = arma::mat,
59  typename OutputDataType = arma::mat
60 >
61 class FastLSTM
62 {
63  public:
64  // Convenience typedefs.
65  typedef typename InputDataType::elem_type InputElemType;
66  typedef typename OutputDataType::elem_type ElemType;
67 
69  FastLSTM();
70 
78  FastLSTM(const size_t inSize,
79  const size_t outSize,
80  const size_t rho = std::numeric_limits<size_t>::max());
81 
89  template<typename InputType, typename OutputType>
90  void Forward(InputType&& input, OutputType&& output);
91 
101  template<typename InputType, typename ErrorType, typename GradientType>
102  void Backward(const InputType&& input,
103  ErrorType&& gy,
104  GradientType&& g);
105 
106  /*
107  * Reset the layer parameter.
108  */
109  void Reset();
110 
111  /*
112  * Resets the cell to accept a new input. This breaks the BPTT chain starts a
113  * new one.
114  *
115  * @param size The current maximum number of steps through time.
116  */
117  void ResetCell(const size_t size);
118 
119  /*
120  * Calculate the gradient using the output delta and the input activation.
121  *
122  * @param input The input parameter used for calculating the gradient.
123  * @param error The calculated error.
124  * @param gradient The calculated gradient.
125  */
126  template<typename InputType, typename ErrorType, typename GradientType>
127  void Gradient(InputType&& input,
128  ErrorType&& error,
129  GradientType&& gradient);
130 
132  size_t Rho() const { return rho; }
134  size_t& Rho() { return rho; }
135 
137  OutputDataType const& Parameters() const { return weights; }
139  OutputDataType& Parameters() { return weights; }
140 
142  InputDataType const& InputParameter() const { return inputParameter; }
144  InputDataType& InputParameter() { return inputParameter; }
145 
147  OutputDataType const& OutputParameter() const { return outputParameter; }
149  OutputDataType& OutputParameter() { return outputParameter; }
150 
152  OutputDataType const& Delta() const { return delta; }
154  OutputDataType& Delta() { return delta; }
155 
157  OutputDataType const& Gradient() const { return grad; }
159  OutputDataType& Gradient() { return grad; }
160 
164  template<typename Archive>
165  void serialize(Archive& ar, const unsigned int /* version */);
166 
167  private:
174  template<typename InputType, typename OutputType>
175  void FastSigmoid(InputType&& input, OutputType&& sigmoids)
176  {
177  for (size_t i = 0; i < input.n_elem; ++i)
178  sigmoids(i) = FastSigmoid(input(i));
179  }
180 
187  ElemType FastSigmoid(const InputElemType data)
188  {
189  ElemType x = 0.5 * data;
190  ElemType z;
191  if (x >= 0)
192  {
193  if (x < 1.7)
194  z = (1.5 * x / (1 + x));
195  else if (x < 3)
196  z = (0.935409070603099 + 0.0458812946797165 * (x - 1.7));
197  else
198  z = 0.99505475368673;
199  }
200  else
201  {
202  ElemType xx = -x;
203  if (xx < 1.7)
204  z = -(1.5 * xx / (1 + xx));
205  else if (xx < 3)
206  z = -(0.935409070603099 + 0.0458812946797165 * (xx - 1.7));
207  else
208  z = -0.99505475368673;
209  }
210 
211  return 0.5 * (z + 1.0);
212  }
213 
215  size_t inSize;
216 
218  size_t outSize;
219 
221  size_t rho;
222 
224  size_t forwardStep;
225 
227  size_t backwardStep;
228 
230  size_t gradientStep;
231 
233  OutputDataType weights;
234 
236  OutputDataType prevOutput;
237 
239  size_t batchSize;
240 
242  size_t batchStep;
243 
246  size_t gradientStepIdx;
247 
249  OutputDataType cellActivationError;
250 
252  OutputDataType delta;
253 
255  OutputDataType grad;
256 
258  InputDataType inputParameter;
259 
261  OutputDataType outputParameter;
262 
264  OutputDataType output2GateWeight;
265 
267  OutputDataType input2GateWeight;
268 
270  OutputDataType input2GateBias;
271 
273  OutputDataType gate;
274 
276  OutputDataType gateActivation;
277 
279  OutputDataType stateActivation;
280 
282  OutputDataType cell;
283 
285  OutputDataType cellActivation;
286 
288  OutputDataType forgetGateError;
289 
291  OutputDataType prevError;
292 
294  OutputDataType outParameter;
295 
297  size_t rhoSize;
298 
300  size_t bpttSteps;
301 }; // class FastLSTM
302 
303 } // namespace ann
304 } // namespace mlpack
305 
306 // Include implementation.
307 #include "fast_lstm_impl.hpp"
308 
309 #endif
OutputDataType & Gradient()
Modify the gradient.
Definition: fast_lstm.hpp:159
InputDataType const & InputParameter() const
Get the input parameter.
Definition: fast_lstm.hpp:142
OutputDataType & Delta()
Modify the delta.
Definition: fast_lstm.hpp:154
void Forward(InputType &&input, OutputType &&output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType::elem_type ElemType
Definition: fast_lstm.hpp:66
OutputDataType const & Parameters() const
Get the parameters.
Definition: fast_lstm.hpp:137
OutputDataType const & Gradient() const
Get the gradient.
Definition: fast_lstm.hpp:157
size_t & Rho()
Modify the maximum number of steps to backpropagate through time (BPTT).
Definition: fast_lstm.hpp:134
FastLSTM()
Create the Fast LSTM object.
OutputDataType const & Delta() const
Get the delta.
Definition: fast_lstm.hpp:152
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: fast_lstm.hpp:147
InputDataType::elem_type InputElemType
Definition: fast_lstm.hpp:65
void Backward(const InputType &&input, ErrorType &&gy, GradientType &&g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: fast_lstm.hpp:149
void ResetCell(const size_t size)
size_t Rho() const
Get the maximum number of steps to backpropagate through time (BPTT).
Definition: fast_lstm.hpp:132
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & Parameters()
Modify the parameters.
Definition: fast_lstm.hpp:139
InputDataType & InputParameter()
Modify the input parameter.
Definition: fast_lstm.hpp:144
An implementation of a faster version of the LSTM network layer (FastLSTM).
Definition: fast_lstm.hpp:61