rnn.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_RNN_HPP
13 #define MLPACK_METHODS_ANN_RNN_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
21 
23 
28 
29 namespace mlpack {
30 namespace ann {
31 
38 template<
39  typename OutputLayerType = NegativeLogLikelihood<>,
40  typename InitializationRuleType = RandomInitialization,
41  typename... CustomLayers
42 >
43 class RNN
44 {
45  public:
47  using NetworkType = RNN<OutputLayerType,
48  InitializationRuleType,
49  CustomLayers...>;
50 
66  RNN(const size_t rho,
67  const bool single = false,
68  OutputLayerType outputLayer = OutputLayerType(),
69  InitializationRuleType initializeRule = InitializationRuleType());
70 
72  ~RNN();
73 
97  template<typename OptimizerType>
98  void Train(arma::cube predictors,
99  arma::cube responses,
100  OptimizerType& optimizer);
101 
125  template<typename OptimizerType = mlpack::optimization::StandardSGD>
126  void Train(arma::cube predictors, arma::cube responses);
127 
147  void Predict(arma::cube predictors,
148  arma::cube& results,
149  const size_t batchSize = 256);
150 
163  double Evaluate(const arma::mat& parameters,
164  const size_t begin,
165  const size_t batchSize,
166  const bool deterministic);
167 
179  double Evaluate(const arma::mat& parameters,
180  const size_t begin,
181  const size_t batchSize)
182  {
183  return Evaluate(parameters, begin, batchSize, true);
184  }
185 
199  void Gradient(const arma::mat& parameters,
200  const size_t begin,
201  arma::mat& gradient,
202  const size_t batchSize);
203 
208  void Shuffle();
209 
210  /*
211  * Add a new module to the model.
212  *
213  * @param args The layer parameter.
214  */
215  template <class LayerType, class... Args>
216  void Add(Args... args) { network.push_back(new LayerType(args...)); }
217 
218  /*
219  * Add a new module to the model.
220  *
221  * @param layer The Layer to be added to the model.
222  */
223  void Add(LayerTypes<CustomLayers...> layer) { network.push_back(layer); }
224 
226  size_t NumFunctions() const { return numFunctions; }
227 
229  const arma::mat& Parameters() const { return parameter; }
231  arma::mat& Parameters() { return parameter; }
232 
234  const size_t& Rho() const { return rho; }
236  size_t& Rho() { return rho; }
237 
239  const arma::cube& Responses() const { return responses; }
241  arma::cube& Responses() { return responses; }
242 
244  const arma::cube& Predictors() const { return predictors; }
246  arma::cube& Predictors() { return predictors; }
247 
253  void Reset();
254 
258  void ResetParameters();
259 
261  template<typename Archive>
262  void serialize(Archive& ar, const unsigned int /* version */);
263 
264  private:
265  // Helper functions.
272  void Forward(arma::mat&& input);
273 
277  void ResetCells();
278 
283  void Backward();
284 
289  template<typename InputType>
290  void Gradient(InputType&& input);
291 
296  void ResetDeterministic();
297 
301  void ResetGradients(arma::mat& gradient);
302 
304  size_t rho;
305 
307  OutputLayerType outputLayer;
308 
311  InitializationRuleType initializeRule;
312 
314  size_t inputSize;
315 
317  size_t outputSize;
318 
320  size_t targetSize;
321 
323  bool reset;
324 
326  bool single;
327 
329  std::vector<LayerTypes<CustomLayers...> > network;
330 
332  arma::cube predictors;
333 
335  arma::cube responses;
336 
338  arma::mat parameter;
339 
341  size_t numFunctions;
342 
344  arma::mat error;
345 
347  DeltaVisitor deltaVisitor;
348 
350  OutputParameterVisitor outputParameterVisitor;
351 
353  std::vector<arma::mat> moduleOutputParameter;
354 
356  WeightSizeVisitor weightSizeVisitor;
357 
359  ResetVisitor resetVisitor;
360 
362  DeleteVisitor deleteVisitor;
363 
365  bool deterministic;
366 
368  arma::mat currentGradient;
369 }; // class RNN
370 
371 } // namespace ann
372 } // namespace mlpack
373 
376 namespace boost {
377 namespace serialization {
378 
379 template<typename OutputLayerType,
380  typename InitializationRuleType,
381  typename... CustomLayer>
382 struct version<
383  mlpack::ann::RNN<OutputLayerType, InitializationRuleType, CustomLayer...>>
384 {
385  BOOST_STATIC_CONSTANT(int, value = 1);
386 };
387 
388 } // namespace serialization
389 } // namespace boost
390 
391 // Include implementation.
392 #include "rnn_impl.hpp"
393 
394 #endif
DeleteVisitor executes the destructor of the instantiated object.
RNN(const size_t rho, const bool single=false, OutputLayerType outputLayer=OutputLayerType(), InitializationRuleType initializeRule=InitializationRuleType())
Create the RNN object.
void ResetParameters()
Reset the module information (weights/parameters).
boost::variant< Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, AtrousConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, BatchNorm< arma::mat, arma::mat > *, BilinearInterpolation< arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, TransposedConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, AlphaDropout< arma::mat, arma::mat > *, ELU< arma::mat, arma::mat > *, FlexibleReLU< arma::mat, arma::mat > *, Glimpse< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LayerNorm< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat > *, LinearNoBias< arma::mat, arma::mat > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, GRU< arma::mat, arma::mat > *, FastLSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, MultiplyMerge< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, PReLU< arma::mat, arma::mat > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, 
Select< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, CustomLayers *... > LayerTypes
Set the serialization version of the FFN class.
Definition: ffn.hpp:468
.hpp
Definition: add_to_po.hpp:21
void Predict(arma::cube predictors, arma::cube &results, const size_t batchSize=256)
Predict the responses to a given set of predictors.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Reset()
Reset the state of the network.
size_t & Rho()
Modify the maximum length of backpropagation through time.
Definition: rnn.hpp:236
WeightSizeVisitor returns the number of weights of the given module.
const arma::mat & Parameters() const
Return the initial point for the optimization.
Definition: rnn.hpp:229
arma::mat & Parameters()
Modify the initial point for the optimization.
Definition: rnn.hpp:231
double Evaluate(const arma::mat &parameters, const size_t begin, const size_t batchSize, const bool deterministic)
Evaluate the recurrent neural network with the given parameters.
Implementation of a standard recurrent neural network container.
Definition: rnn.hpp:43
Implementation of the base layer.
Definition: base_layer.hpp:47
void Train(arma::cube predictors, arma::cube responses, OptimizerType &optimizer)
Train the recurrent neural network on the given input data using the given optimizer.
double Evaluate(const arma::mat &parameters, const size_t begin, const size_t batchSize)
Evaluate the recurrent neural network with the given parameters.
Definition: rnn.hpp:179
arma::cube & Predictors()
Modify the matrix of data points (predictors).
Definition: rnn.hpp:246
ResetVisitor executes the Reset() function.
OutputParameterVisitor exposes the output parameter of the given module.
const arma::cube & Predictors() const
Get the matrix of data points (predictors).
Definition: rnn.hpp:244
void Add(Args... args)
Definition: rnn.hpp:216
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: rnn.hpp:226
void Gradient(const arma::mat &parameters, const size_t begin, arma::mat &gradient, const size_t batchSize)
Evaluate the gradient of the recurrent neural network with the given parameters, with respect to the batch of points starting at the given index.
DeltaVisitor exposes the delta parameter of the given module.
const size_t & Rho() const
Return the maximum length of backpropagation through time.
Definition: rnn.hpp:234
void Add(LayerTypes< CustomLayers... > layer)
Definition: rnn.hpp:223
arma::cube & Responses()
Modify the matrix of responses to the input data points.
Definition: rnn.hpp:241
~RNN()
Destructor to release allocated memory.
const arma::cube & Responses() const
Get the matrix of responses to the input data points.
Definition: rnn.hpp:239
void Shuffle()
Shuffle the order of function visitation.