rnn.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_RNN_HPP
13 #define MLPACK_METHODS_ANN_RNN_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
21 
23 
27 
28 #include <ensmallen.hpp>
29 
30 namespace mlpack {
31 namespace ann {
32 
39 template<
40  typename OutputLayerType = NegativeLogLikelihood<>,
41  typename InitializationRuleType = RandomInitialization,
42  typename... CustomLayers
43 >
44 class RNN
45 {
46  public:
48  using NetworkType = RNN<OutputLayerType,
49  InitializationRuleType,
50  CustomLayers...>;
51 
67  RNN(const size_t rho,
68  const bool single = false,
69  OutputLayerType outputLayer = OutputLayerType(),
70  InitializationRuleType initializeRule = InitializationRuleType());
71 
73  ~RNN();
74 
102  template<typename OptimizerType, typename... CallbackTypes>
103  double Train(arma::cube predictors,
104  arma::cube responses,
105  OptimizerType& optimizer,
106  CallbackTypes&&... callbacks);
107 
135  template<typename OptimizerType = ens::StandardSGD, typename... CallbackTypes>
136  double Train(arma::cube predictors,
137  arma::cube responses,
138  CallbackTypes&&... callbacks);
139 
159  void Predict(arma::cube predictors,
160  arma::cube& results,
161  const size_t batchSize = 256);
162 
175  double Evaluate(const arma::mat& parameters,
176  const size_t begin,
177  const size_t batchSize,
178  const bool deterministic);
179 
191  double Evaluate(const arma::mat& parameters,
192  const size_t begin,
193  const size_t batchSize);
194 
206  template<typename GradType>
207  double EvaluateWithGradient(const arma::mat& parameters,
208  const size_t begin,
209  GradType& gradient,
210  const size_t batchSize);
211 
225  void Gradient(const arma::mat& parameters,
226  const size_t begin,
227  arma::mat& gradient,
228  const size_t batchSize);
229 
234  void Shuffle();
235 
236  /*
237  * Add a new module to the model.
238  *
239  * @param args The layer parameter.
240  */
241  template <class LayerType, class... Args>
242  void Add(Args... args) { network.push_back(new LayerType(args...)); }
243 
244  /*
245  * Add a new module to the model.
246  *
247  * @param layer The Layer to be added to the model.
248  */
249  void Add(LayerTypes<CustomLayers...> layer) { network.push_back(layer); }
250 
252  size_t NumFunctions() const { return numFunctions; }
253 
255  const arma::mat& Parameters() const { return parameter; }
257  arma::mat& Parameters() { return parameter; }
258 
260  const size_t& Rho() const { return rho; }
262  size_t& Rho() { return rho; }
263 
265  const arma::cube& Responses() const { return responses; }
267  arma::cube& Responses() { return responses; }
268 
270  const arma::cube& Predictors() const { return predictors; }
272  arma::cube& Predictors() { return predictors; }
273 
279  void Reset();
280 
284  void ResetParameters();
285 
287  template<typename Archive>
288  void serialize(Archive& ar, const unsigned int /* version */);
289 
290  private:
291  // Helper functions.
298  void Forward(arma::mat&& input);
299 
303  void ResetCells();
304 
309  void Backward();
310 
315  template<typename InputType>
316  void Gradient(InputType&& input);
317 
322  void ResetDeterministic();
323 
327  void ResetGradients(arma::mat& gradient);
328 
330  size_t rho;
331 
333  OutputLayerType outputLayer;
334 
337  InitializationRuleType initializeRule;
338 
340  size_t inputSize;
341 
343  size_t outputSize;
344 
346  size_t targetSize;
347 
349  bool reset;
350 
352  bool single;
353 
355  std::vector<LayerTypes<CustomLayers...> > network;
356 
358  arma::cube predictors;
359 
361  arma::cube responses;
362 
364  arma::mat parameter;
365 
367  size_t numFunctions;
368 
370  arma::mat error;
371 
373  DeltaVisitor deltaVisitor;
374 
376  OutputParameterVisitor outputParameterVisitor;
377 
379  std::vector<arma::mat> moduleOutputParameter;
380 
382  WeightSizeVisitor weightSizeVisitor;
383 
385  ResetVisitor resetVisitor;
386 
388  DeleteVisitor deleteVisitor;
389 
391  bool deterministic;
392 
394  arma::mat currentGradient;
395 
396  // The BRN class should have access to internal members.
397  template<
398  typename OutputLayerType1,
399  typename MergeLayerType1,
400  typename MergeOutputType1,
401  typename InitializationRuleType1,
402  typename... CustomLayers1
403  >
404  friend class BRNN;
405 }; // class RNN
406 
407 } // namespace ann
408 } // namespace mlpack
409 
412 namespace boost {
413 namespace serialization {
414 
415 template<typename OutputLayerType,
416  typename InitializationRuleType,
417  typename... CustomLayer>
418 struct version<
419  mlpack::ann::RNN<OutputLayerType, InitializationRuleType, CustomLayer...>>
420 {
421  BOOST_STATIC_CONSTANT(int, value = 1);
422 };
423 
424 } // namespace serialization
425 } // namespace boost
426 
427 // Include implementation.
428 #include "rnn_impl.hpp"
429 
430 #endif
DeleteVisitor executes the destructor of the instantiated object.
RNN(const size_t rho, const bool single=false, OutputLayerType outputLayer=OutputLayerType(), InitializationRuleType initializeRule=InitializationRuleType())
Create the RNN object.
void ResetParameters()
Reset the module information (weights/parameters).
Set the serialization version of the RNN class.
Definition: rnn.hpp:418
double Train(arma::cube predictors, arma::cube responses, OptimizerType &optimizer, CallbackTypes &&... callbacks)
Train the recurrent neural network on the given input data using the given optimizer.
strip_type.hpp
Definition: add_to_po.hpp:21
void Predict(arma::cube predictors, arma::cube &results, const size_t batchSize=256)
Predict the responses to a given set of predictors.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Reset()
Reset the state of the network.
size_t & Rho()
Modify the maximum length of backpropagation through time.
Definition: rnn.hpp:262
WeightSizeVisitor returns the number of weights of the given module.
const arma::mat & Parameters() const
Return the initial point for the optimization.
Definition: rnn.hpp:255
arma::mat & Parameters()
Modify the initial point for the optimization.
Definition: rnn.hpp:257
double Evaluate(const arma::mat &parameters, const size_t begin, const size_t batchSize, const bool deterministic)
Evaluate the recurrent neural network with the given parameters.
Implementation of a standard recurrent neural network container.
Definition: rnn.hpp:44
Implementation of the base layer.
Definition: base_layer.hpp:53
arma::cube & Predictors()
Modify the matrix of data points (predictors).
Definition: rnn.hpp:272
ResetVisitor executes the Reset() function.
OutputParameterVisitor exposes the output parameter of the given module.
const arma::cube & Predictors() const
Get the matrix of data points (predictors).
Definition: rnn.hpp:270
void Add(Args... args)
Definition: rnn.hpp:242
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: rnn.hpp:252
void Gradient(const arma::mat &parameters, const size_t begin, arma::mat &gradient, const size_t batchSize)
Evaluate the gradient of the recurrent neural network with the given parameters, and with respect to ...
Implementation of a standard bidirectional recurrent neural network container.
Definition: brnn.hpp:47
double EvaluateWithGradient(const arma::mat &parameters, const size_t begin, GradType &gradient, const size_t batchSize)
Evaluate the recurrent neural network with the given parameters.
boost::variant< Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, AtrousConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, BaseLayer< SoftplusFunction, arma::mat, arma::mat > *, BatchNorm< arma::mat, arma::mat > *, BilinearInterpolation< arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, Concatenate< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, TransposedConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, AlphaDropout< arma::mat, arma::mat > *, ELU< arma::mat, arma::mat > *, FlexibleReLU< arma::mat, arma::mat > *, Glimpse< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LayerNorm< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, CReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat, NoRegularizer > *, LinearNoBias< arma::mat, arma::mat, NoRegularizer > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, GRU< arma::mat, arma::mat > *, FastLSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MiniBatchDiscrimination< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, 
MultiplyMerge< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, Padding< arma::mat, arma::mat > *, PReLU< arma::mat, arma::mat > *, WeightNorm< arma::mat, arma::mat > *, MoreTypes, CustomLayers *... > LayerTypes
DeltaVisitor exposes the delta parameter of the given module.
const size_t & Rho() const
Return the maximum length of backpropagation through time.
Definition: rnn.hpp:260
void Add(LayerTypes< CustomLayers... > layer)
Definition: rnn.hpp:249
arma::cube & Responses()
Modify the matrix of responses to the input data points.
Definition: rnn.hpp:267
~RNN()
Destructor to release allocated memory.
const arma::cube & Responses() const
Get the matrix of responses to the input data points.
Definition: rnn.hpp:265
void Shuffle()
Shuffle the order of function visitation.