recurrent.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_RECURRENT_HPP
14 #define MLPACK_METHODS_ANN_LAYER_RECURRENT_HPP
15 
16 #include <mlpack/core.hpp>
17 
18 #include "../visitor/delete_visitor.hpp"
19 #include "../visitor/delta_visitor.hpp"
20 #include "../visitor/copy_visitor.hpp"
21 #include "../visitor/output_parameter_visitor.hpp"
22 
23 #include "layer_types.hpp"
24 #include "add_merge.hpp"
25 #include "sequential.hpp"
26 
27 namespace mlpack {
28 namespace ann {
29 
39 template <
40  typename InputDataType = arma::mat,
41  typename OutputDataType = arma::mat,
42  typename... CustomLayers
43 >
44 class Recurrent
45 {
46  public:
51  Recurrent();
52 
54  Recurrent(const Recurrent&);
55 
65  template<typename StartModuleType,
66  typename InputModuleType,
67  typename FeedbackModuleType,
68  typename TransferModuleType>
69  Recurrent(const StartModuleType& start,
70  const InputModuleType& input,
71  const FeedbackModuleType& feedback,
72  const TransferModuleType& transfer,
73  const size_t rho);
74 
82  template<typename eT>
83  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
84 
94  template<typename eT>
95  void Backward(const arma::Mat<eT>& /* input */,
96  const arma::Mat<eT>& gy,
97  arma::Mat<eT>& g);
98 
99  /*
100  * Calculate the gradient using the output delta and the input activation.
101  *
102  * @param input The input parameter used for calculating the gradient.
103  * @param error The calculated error.
104  * @param gradient The calculated gradient.
105  */
106  template<typename eT>
107  void Gradient(const arma::Mat<eT>& input,
108  const arma::Mat<eT>& error,
109  arma::Mat<eT>& /* gradient */);
110 
112  std::vector<LayerTypes<CustomLayers...> >& Model() { return network; }
113 
115  bool Deterministic() const { return deterministic; }
117  bool& Deterministic() { return deterministic; }
118 
120  OutputDataType const& Parameters() const { return parameters; }
122  OutputDataType& Parameters() { return parameters; }
123 
125  OutputDataType const& OutputParameter() const { return outputParameter; }
127  OutputDataType& OutputParameter() { return outputParameter; }
128 
130  OutputDataType const& Delta() const { return delta; }
132  OutputDataType& Delta() { return delta; }
133 
135  OutputDataType const& Gradient() const { return gradient; }
137  OutputDataType& Gradient() { return gradient; }
138 
140  size_t const& Rho() const { return rho; }
141 
145  template<typename Archive>
146  void serialize(Archive& ar, const unsigned int /* version */);
147 
148  private:
150  DeleteVisitor deleteVisitor;
151 
153  CopyVisitor<CustomLayers...> copyVisitor;
154 
156  LayerTypes<CustomLayers...> startModule;
157 
159  LayerTypes<CustomLayers...> inputModule;
160 
162  LayerTypes<CustomLayers...> feedbackModule;
163 
165  LayerTypes<CustomLayers...> transferModule;
166 
168  size_t rho;
169 
171  size_t forwardStep;
172 
174  size_t backwardStep;
175 
177  size_t gradientStep;
178 
180  bool deterministic;
181 
184  bool ownsLayer;
185 
187  OutputDataType parameters;
188 
190  LayerTypes<CustomLayers...> initialModule;
191 
193  LayerTypes<CustomLayers...> recurrentModule;
194 
196  std::vector<LayerTypes<CustomLayers...> > network;
197 
199  LayerTypes<CustomLayers...> mergeModule;
200 
202  DeltaVisitor deltaVisitor;
203 
205  OutputParameterVisitor outputParameterVisitor;
206 
208  std::vector<arma::mat> feedbackOutputParameter;
209 
211  OutputDataType delta;
212 
214  OutputDataType gradient;
215 
217  OutputDataType outputParameter;
218 
220  arma::mat recurrentError;
221 }; // class Recurrent
222 
223 } // namespace ann
224 } // namespace mlpack
225 
226 // Include implementation.
227 #include "recurrent_impl.hpp"
228 
229 #endif
DeleteVisitor executes the destructor of the instantiated object.
OutputDataType const & Delta() const
Get the delta.
Definition: recurrent.hpp:130
Linear algebra utility functions, generally performed on matrices or vectors.
std::vector< LayerTypes< CustomLayers... > > & Model()
Get the model modules.
Definition: recurrent.hpp:112
bool & Deterministic()
Modify the value of the deterministic parameter.
Definition: recurrent.hpp:117
This visitor is to support copy constructor for neural network module.
boost::variant< AdaptiveMaxPooling< arma::mat, arma::mat > *, AdaptiveMeanPooling< arma::mat, arma::mat > *, Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, AlphaDropout< arma::mat, arma::mat > *, AtrousConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< SoftplusFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, BatchNorm< arma::mat, arma::mat > *, BilinearInterpolation< arma::mat, arma::mat > *, CELU< arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, Concatenate< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, CReLU< arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, ELU< arma::mat, arma::mat > *, FastLSTM< arma::mat, arma::mat > *, FlexibleReLU< arma::mat, arma::mat > *, GRU< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LayerNorm< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat, NoRegularizer > *, LinearNoBias< arma::mat, arma::mat, NoRegularizer > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MiniBatchDiscrimination< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, MultiplyMerge< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, NoisyLinear< arma::mat, 
arma::mat > *, Padding< arma::mat, arma::mat > *, PReLU< arma::mat, arma::mat > *, Softmax< arma::mat, arma::mat > *, SpatialDropout< arma::mat, arma::mat > *, TransposedConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, WeightNorm< arma::mat, arma::mat > *, MoreTypes, CustomLayers *... > LayerTypes
OutputDataType const & Parameters() const
Get the parameters.
Definition: recurrent.hpp:120
size_t const & Rho() const
Get the number of steps to backpropagate through time.
Definition: recurrent.hpp:140
OutputDataType & Delta()
Modify the delta.
Definition: recurrent.hpp:132
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType const & Gradient() const
Get the gradient.
Definition: recurrent.hpp:135
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
OutputDataType & Gradient()
Modify the gradient.
Definition: recurrent.hpp:137
OutputParameterVisitor exposes the output parameter of the given module.
OutputDataType & Parameters()
Modify the parameters.
Definition: recurrent.hpp:122
Recurrent()
Default constructor—this will create a Recurrent object that can't be used, so be careful! Make sure to set all the parameters before use.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
DeltaVisitor exposes the delta parameter of the given module.
bool Deterministic() const
The value of the deterministic parameter.
Definition: recurrent.hpp:115
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: recurrent.hpp:125
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: recurrent.hpp:127