rnn.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_RNN_HPP
13 #define MLPACK_METHODS_ANN_RNN_HPP
14 
#include <mlpack/prereqs.hpp>

#include <memory>
#include <utility>
16 
21 
23 
27 
28 namespace mlpack {
29 namespace ann {
30 
37 template<
38  typename OutputLayerType = NegativeLogLikelihood<>,
39  typename InitializationRuleType = RandomInitialization,
40  typename... CustomLayers
41 >
42 class RNN
43 {
44  public:
46  using NetworkType = RNN<OutputLayerType,
47  InitializationRuleType,
48  CustomLayers...>;
49 
65  RNN(const size_t rho,
66  const bool single = false,
67  OutputLayerType outputLayer = OutputLayerType(),
68  InitializationRuleType initializeRule = InitializationRuleType());
69 
71  ~RNN();
72 
96  template<typename OptimizerType>
97  void Train(arma::cube predictors,
98  arma::cube responses,
99  OptimizerType& optimizer);
100 
124  template<typename OptimizerType = mlpack::optimization::StandardSGD>
125  void Train(arma::cube predictors, arma::cube responses);
126 
146  void Predict(arma::cube predictors,
147  arma::cube& results,
148  const size_t batchSize = 256);
149 
162  double Evaluate(const arma::mat& parameters,
163  const size_t begin,
164  const size_t batchSize,
165  const bool deterministic);
166 
178  double Evaluate(const arma::mat& parameters,
179  const size_t begin,
180  const size_t batchSize)
181  {
182  return Evaluate(parameters, begin, batchSize, true);
183  }
184 
198  void Gradient(const arma::mat& parameters,
199  const size_t begin,
200  arma::mat& gradient,
201  const size_t batchSize);
202 
207  void Shuffle();
208 
209  /*
210  * Add a new module to the model.
211  *
212  * @param args The layer parameter.
213  */
214  template <class LayerType, class... Args>
215  void Add(Args... args) { network.push_back(new LayerType(args...)); }
216 
217  /*
218  * Add a new module to the model.
219  *
220  * @param layer The Layer to be added to the model.
221  */
222  void Add(LayerTypes<CustomLayers...> layer) { network.push_back(layer); }
223 
225  size_t NumFunctions() const { return numFunctions; }
226 
228  const arma::mat& Parameters() const { return parameter; }
230  arma::mat& Parameters() { return parameter; }
231 
233  const size_t& Rho() const { return rho; }
235  size_t& Rho() { return rho; }
236 
238  const arma::cube& Responses() const { return responses; }
240  arma::cube& Responses() { return responses; }
241 
243  const arma::cube& Predictors() const { return predictors; }
245  arma::cube& Predictors() { return predictors; }
246 
252  void Reset();
253 
257  void ResetParameters();
258 
260  template<typename Archive>
261  void serialize(Archive& ar, const unsigned int /* version */);
262 
263  private:
264  // Helper functions.
271  void Forward(arma::mat&& input);
272 
276  void ResetCells();
277 
282  void Backward();
283 
288  template<typename InputType>
289  void Gradient(InputType&& input);
290 
295  void ResetDeterministic();
296 
300  void ResetGradients(arma::mat& gradient);
301 
303  size_t rho;
304 
306  OutputLayerType outputLayer;
307 
310  InitializationRuleType initializeRule;
311 
313  size_t inputSize;
314 
316  size_t outputSize;
317 
319  size_t targetSize;
320 
322  bool reset;
323 
325  bool single;
326 
328  std::vector<LayerTypes<CustomLayers...> > network;
329 
331  arma::cube predictors;
332 
334  arma::cube responses;
335 
337  arma::mat parameter;
338 
340  size_t numFunctions;
341 
343  arma::mat error;
344 
346  DeltaVisitor deltaVisitor;
347 
349  OutputParameterVisitor outputParameterVisitor;
350 
352  std::vector<arma::mat> moduleOutputParameter;
353 
355  WeightSizeVisitor weightSizeVisitor;
356 
358  ResetVisitor resetVisitor;
359 
361  DeleteVisitor deleteVisitor;
362 
364  bool deterministic;
365 
367  arma::mat currentGradient;
368 }; // class RNN
369 
370 } // namespace ann
371 } // namespace mlpack
372 
373 // Include implementation.
374 #include "rnn_impl.hpp"
375 
376 #endif
DeleteVisitor executes the destructor of the instantiated object.
RNN(const size_t rho, const bool single=false, OutputLayerType outputLayer=OutputLayerType(), InitializationRuleType initializeRule=InitializationRuleType())
Create the RNN object.
void ResetParameters()
Reset the module information (weights/parameters).
.hpp
Definition: add_to_po.hpp:21
void Predict(arma::cube predictors, arma::cube &results, const size_t batchSize=256)
Predict the responses to a given set of predictors.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Reset()
Reset the state of the network.
size_t & Rho()
Modify the maximum length of backpropagation through time.
Definition: rnn.hpp:235
WeightSizeVisitor returns the number of weights of the given module.
const arma::mat & Parameters() const
Return the initial point for the optimization.
Definition: rnn.hpp:228
arma::mat & Parameters()
Modify the initial point for the optimization.
Definition: rnn.hpp:230
double Evaluate(const arma::mat &parameters, const size_t begin, const size_t batchSize, const bool deterministic)
Evaluate the recurrent neural network with the given parameters.
Implementation of a standard recurrent neural network container.
Definition: rnn.hpp:42
void Train(arma::cube predictors, arma::cube responses, OptimizerType &optimizer)
Train the recurrent neural network on the given input data using the given optimizer.
double Evaluate(const arma::mat &parameters, const size_t begin, const size_t batchSize)
Evaluate the recurrent neural network with the given parameters.
Definition: rnn.hpp:178
arma::cube & Predictors()
Modify the cube of data points (predictors).
Definition: rnn.hpp:245
ResetVisitor executes the Reset() function.
OutputParameterVisitor exposes the output parameter of the given module.
const arma::cube & Predictors() const
Get the cube of data points (predictors).
Definition: rnn.hpp:243
void Add(Args... args)
Definition: rnn.hpp:215
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: rnn.hpp:225
void Gradient(const arma::mat &parameters, const size_t begin, arma::mat &gradient, const size_t batchSize)
Evaluate the gradient of the recurrent neural network with the given parameters, and with respect to the batch of batchSize points starting at index begin.
DeltaVisitor exposes the delta parameter of the given module.
const size_t & Rho() const
Return the maximum length of backpropagation through time.
Definition: rnn.hpp:233
void Add(LayerTypes< CustomLayers... > layer)
Definition: rnn.hpp:222
arma::cube & Responses()
Modify the cube of responses to the input data points.
Definition: rnn.hpp:240
~RNN()
Destructor to release allocated memory.
boost::variant< Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, BatchNorm< arma::mat, arma::mat > *, BilinearInterpolation< arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, CrossEntropyError< arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, ELU< arma::mat, arma::mat > *, Glimpse< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat > *, LinearNoBias< arma::mat, arma::mat > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, GRU< arma::mat, arma::mat > *, FastLSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MeanSquaredError< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, PReLU< arma::mat, arma::mat > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, SigmoidCrossEntropyError< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, CustomLayers *... > LayerTypes
const arma::cube & Responses() const
Get the cube of responses to the input data points.
Definition: rnn.hpp:238
void Shuffle()
Shuffle the order of function visitation.