rnn.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_RNN_HPP
13 #define MLPACK_METHODS_ANN_RNN_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
21 
23 
27 
28 #include <ensmallen.hpp>
29 
30 namespace mlpack {
31 namespace ann {
32 
39 template<
40  typename OutputLayerType = NegativeLogLikelihood<>,
41  typename InitializationRuleType = RandomInitialization,
42  typename... CustomLayers
43 >
44 class RNN
45 {
46  public:
48  using NetworkType = RNN<OutputLayerType,
49  InitializationRuleType,
50  CustomLayers...>;
51 
67  RNN(const size_t rho,
68  const bool single = false,
69  OutputLayerType outputLayer = OutputLayerType(),
70  InitializationRuleType initializeRule = InitializationRuleType());
71 
73  ~RNN();
74 
99  template<typename OptimizerType>
100  double Train(arma::cube predictors,
101  arma::cube responses,
102  OptimizerType& optimizer);
103 
128  template<typename OptimizerType = ens::StandardSGD>
129  double Train(arma::cube predictors, arma::cube responses);
130 
150  void Predict(arma::cube predictors,
151  arma::cube& results,
152  const size_t batchSize = 256);
153 
166  double Evaluate(const arma::mat& parameters,
167  const size_t begin,
168  const size_t batchSize,
169  const bool deterministic);
170 
182  double Evaluate(const arma::mat& parameters,
183  const size_t begin,
184  const size_t batchSize);
185 
197  template<typename GradType>
198  double EvaluateWithGradient(const arma::mat& parameters,
199  const size_t begin,
200  GradType& gradient,
201  const size_t batchSize);
202 
216  void Gradient(const arma::mat& parameters,
217  const size_t begin,
218  arma::mat& gradient,
219  const size_t batchSize);
220 
225  void Shuffle();
226 
227  /*
228  * Add a new module to the model.
229  *
230  * @param args The layer parameter.
231  */
232  template <class LayerType, class... Args>
233  void Add(Args... args) { network.push_back(new LayerType(args...)); }
234 
235  /*
236  * Add a new module to the model.
237  *
238  * @param layer The Layer to be added to the model.
239  */
240  void Add(LayerTypes<CustomLayers...> layer) { network.push_back(layer); }
241 
243  size_t NumFunctions() const { return numFunctions; }
244 
246  const arma::mat& Parameters() const { return parameter; }
248  arma::mat& Parameters() { return parameter; }
249 
251  const size_t& Rho() const { return rho; }
253  size_t& Rho() { return rho; }
254 
256  const arma::cube& Responses() const { return responses; }
258  arma::cube& Responses() { return responses; }
259 
261  const arma::cube& Predictors() const { return predictors; }
263  arma::cube& Predictors() { return predictors; }
264 
270  void Reset();
271 
275  void ResetParameters();
276 
278  template<typename Archive>
279  void serialize(Archive& ar, const unsigned int /* version */);
280 
281  private:
282  // Helper functions.
289  void Forward(arma::mat&& input);
290 
294  void ResetCells();
295 
300  void Backward();
301 
306  template<typename InputType>
307  void Gradient(InputType&& input);
308 
313  void ResetDeterministic();
314 
318  void ResetGradients(arma::mat& gradient);
319 
321  size_t rho;
322 
324  OutputLayerType outputLayer;
325 
328  InitializationRuleType initializeRule;
329 
331  size_t inputSize;
332 
334  size_t outputSize;
335 
337  size_t targetSize;
338 
340  bool reset;
341 
343  bool single;
344 
346  std::vector<LayerTypes<CustomLayers...> > network;
347 
349  arma::cube predictors;
350 
352  arma::cube responses;
353 
355  arma::mat parameter;
356 
358  size_t numFunctions;
359 
361  arma::mat error;
362 
364  DeltaVisitor deltaVisitor;
365 
367  OutputParameterVisitor outputParameterVisitor;
368 
370  std::vector<arma::mat> moduleOutputParameter;
371 
373  WeightSizeVisitor weightSizeVisitor;
374 
376  ResetVisitor resetVisitor;
377 
379  DeleteVisitor deleteVisitor;
380 
382  bool deterministic;
383 
385  arma::mat currentGradient;
386 
387  // The BRNN class should have access to internal members.
388  template<
389  typename OutputLayerType1,
390  typename MergeLayerType1,
391  typename MergeOutputType1,
392  typename InitializationRuleType1,
393  typename... CustomLayers1
394  >
395  friend class BRNN;
396 }; // class RNN
397 
398 } // namespace ann
399 } // namespace mlpack
400 
403 namespace boost {
404 namespace serialization {
405 
406 template<typename OutputLayerType,
407  typename InitializationRuleType,
408  typename... CustomLayer>
409 struct version<
410  mlpack::ann::RNN<OutputLayerType, InitializationRuleType, CustomLayer...>>
411 {
412  BOOST_STATIC_CONSTANT(int, value = 1);
413 };
414 
415 } // namespace serialization
416 } // namespace boost
417 
418 // Include implementation.
419 #include "rnn_impl.hpp"
420 
421 #endif
DeleteVisitor executes the destructor of the instantiated object.
RNN(const size_t rho, const bool single=false, OutputLayerType outputLayer=OutputLayerType(), InitializationRuleType initializeRule=InitializationRuleType())
Create the RNN object.
void ResetParameters()
Reset the module information (weights/parameters).
double Train(arma::cube predictors, arma::cube responses, OptimizerType &optimizer)
Train the recurrent neural network on the given input data using the given optimizer.
Set the serialization version of the adaboost class.
Definition: adaboost.hpp:180
.hpp
Definition: add_to_po.hpp:21
void Predict(arma::cube predictors, arma::cube &results, const size_t batchSize=256)
Predict the responses to a given set of predictors.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
The core includes that mlpack expects: standard C++ includes and Armadillo.
void Reset()
Reset the state of the network.
size_t & Rho()
Modify the maximum length of backpropagation through time.
Definition: rnn.hpp:253
WeightSizeVisitor returns the number of weights of the given module.
const arma::mat & Parameters() const
Return the initial point for the optimization.
Definition: rnn.hpp:246
arma::mat & Parameters()
Modify the initial point for the optimization.
Definition: rnn.hpp:248
double Evaluate(const arma::mat &parameters, const size_t begin, const size_t batchSize, const bool deterministic)
Evaluate the recurrent neural network with the given parameters.
Implementation of a standard recurrent neural network container.
Definition: rnn.hpp:44
Implementation of the base layer.
Definition: base_layer.hpp:49
arma::cube & Predictors()
Modify the matrix of data points (predictors).
Definition: rnn.hpp:263
boost::variant< Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, AtrousConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, BaseLayer< SoftplusFunction, arma::mat, arma::mat > *, BatchNorm< arma::mat, arma::mat > *, BilinearInterpolation< arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, Concatenate< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, TransposedConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, AlphaDropout< arma::mat, arma::mat > *, ELU< arma::mat, arma::mat > *, FlexibleReLU< arma::mat, arma::mat > *, Glimpse< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LayerNorm< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat > *, LinearNoBias< arma::mat, arma::mat > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, GRU< arma::mat, arma::mat > *, FastLSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, MultiplyMerge< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, PReLU< arma::mat, arma::mat > *, Recurrent< arma::mat, arma::mat > 
*, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat, false > *, Sequential< arma::mat, arma::mat, true > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, CustomLayers *... > LayerTypes
ResetVisitor executes the Reset() function.
OutputParameterVisitor exposes the output parameter of the given module.
const arma::cube & Predictors() const
Get the matrix of data points (predictors).
Definition: rnn.hpp:261
void Add(Args... args)
Definition: rnn.hpp:233
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: rnn.hpp:243
void Gradient(const arma::mat &parameters, const size_t begin, arma::mat &gradient, const size_t batchSize)
Evaluate the gradient of the recurrent neural network with the given parameters, and with respect to only the batch of data points starting at index begin, storing the result in gradient.
Implementation of a standard bidirectional recurrent neural network container.
Definition: brnn.hpp:47
double EvaluateWithGradient(const arma::mat &parameters, const size_t begin, GradType &gradient, const size_t batchSize)
Evaluate the recurrent neural network with the given parameters, and simultaneously compute the gradient with respect to them.
DeltaVisitor exposes the delta parameter of the given module.
const size_t & Rho() const
Return the maximum length of backpropagation through time.
Definition: rnn.hpp:251
void Add(LayerTypes< CustomLayers... > layer)
Definition: rnn.hpp:240
arma::cube & Responses()
Modify the matrix of responses to the input data points.
Definition: rnn.hpp:258
~RNN()
Destructor to release allocated memory.
const arma::cube & Responses() const
Get the matrix of responses to the input data points.
Definition: rnn.hpp:256
void Shuffle()
Shuffle the order of function visitation.