brnn.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_BRNN_HPP
14 #define MLPACK_METHODS_ANN_BRNN_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
20 #include "visitor/copy_visitor.hpp"
23 
28 
29 #include <ensmallen.hpp>
30 
31 namespace mlpack {
32 namespace ann {
33 
/**
 * Implementation of a standard bidirectional recurrent neural network (BRNN)
 * container.  Internally it holds two RNN containers (forwardRNN /
 * backwardRNN, see below) and merges their per-time-step outputs through the
 * merge layer and merge output layer.
 *
 * @tparam OutputLayerType The output layer type used to evaluate the network.
 * @tparam MergeLayerType The layer used to merge the outputs of the forward
 *     and backward RNNs (a Concat layer by default).
 * @tparam MergeOutputType The activation applied to the merged output
 *     (LogSoftMax by default).
 * @tparam InitializationRuleType Rule used to initialize the weight matrix.
 * @tparam CustomLayers Any custom layer types usable inside the network.
 */
template<
  typename OutputLayerType = NegativeLogLikelihood<>,
  typename MergeLayerType = Concat<>,
  typename MergeOutputType = LogSoftMax<>,
  typename InitializationRuleType = RandomInitialization,
  typename... CustomLayers
>
class BRNN
{
 public:
  //! Convenience typedef for the internal model construction.
  using NetworkType = BRNN<OutputLayerType,
                           MergeLayerType,
                           MergeOutputType,
                           InitializationRuleType,
                           CustomLayers...>;

  /**
   * Create the BRNN object.
   *
   * @param rho Maximum length of backpropagation through time (BPTT).
   * @param single Presumably, predict only one response per input sequence
   *     rather than one per time step — TODO(review): confirm against
   *     brnn_impl.hpp.
   * @param outputLayer Output layer used to evaluate the network.
   * @param mergeLayer Layer used to merge the forward and backward RNN
   *     outputs.
   * @param mergeOutput Activation applied to the merged output.
   * @param initializeRule Rule used to initialize the weight matrix.
   */
  BRNN(const size_t rho,
       const bool single = false,
       OutputLayerType outputLayer = OutputLayerType(),
       MergeLayerType mergeLayer = MergeLayerType(),
       MergeOutputType mergeOutput = MergeOutputType(),
       InitializationRuleType initializeRule = InitializationRuleType());

  /**
   * Train the bidirectional recurrent neural network on the given input data
   * using the given optimizer.
   *
   * @tparam OptimizerType Type of optimizer to use to train the model.
   * @param predictors Input training variables (one slice per time step).
   * @param responses Outputs/results from input training variables.
   * @param optimizer Instantiated optimizer used to train the model.
   * @return The final training objective — TODO(review): confirm exact
   *     meaning in brnn_impl.hpp.
   */
  template<typename OptimizerType>
  double Train(arma::cube predictors,
               arma::cube responses,
               OptimizerType& optimizer);

  /**
   * Train the bidirectional recurrent neural network on the given input data,
   * using a default-constructed optimizer of the given type.
   *
   * @tparam OptimizerType Type of optimizer to use (ens::StandardSGD by
   *     default).
   * @param predictors Input training variables.
   * @param responses Outputs/results from input training variables.
   * @return The final training objective — TODO(review): confirm exact
   *     meaning in brnn_impl.hpp.
   */
  template<typename OptimizerType = ens::StandardSGD>
  double Train(arma::cube predictors, arma::cube responses);

  /**
   * Predict the responses to a given set of predictors.
   *
   * @param predictors Input predictors (one slice per time step).
   * @param results Cube to put the predicted responses into.
   * @param batchSize Number of points to process per step.
   */
  void Predict(arma::cube predictors,
               arma::cube& results,
               const size_t batchSize = 256);

  /**
   * Evaluate the bidirectional recurrent neural network with the given
   * parameters.  Intended for use by ensmallen optimizers (separable
   * function API: begin + batchSize select a batch of points).
   *
   * @param parameters Matrix/vector of model parameters.
   * @param begin Index of the first point of the batch to evaluate.
   * @param batchSize Number of points in the batch.
   * @param deterministic Whether to run in deterministic (testing) mode
   *     rather than training mode.
   * @return The evaluated objective for the batch.
   */
  double Evaluate(const arma::mat& parameters,
                  const size_t begin,
                  const size_t batchSize,
                  const bool deterministic);

  /**
   * Evaluate the bidirectional recurrent neural network with the given
   * parameters (non-deterministic overload used by optimizers).
   *
   * @param parameters Matrix/vector of model parameters.
   * @param begin Index of the first point of the batch to evaluate.
   * @param batchSize Number of points in the batch.
   * @return The evaluated objective for the batch.
   */
  double Evaluate(const arma::mat& parameters,
                  const size_t begin,
                  const size_t batchSize);

  /**
   * Evaluate the bidirectional recurrent neural network with the given
   * parameters, computing the gradient at the same time.
   *
   * @tparam GradType Type of the gradient container.
   * @param parameters Matrix/vector of model parameters.
   * @param begin Index of the first point of the batch to evaluate.
   * @param gradient Container to place the computed gradient into.
   * @param batchSize Number of points in the batch.
   * @return The evaluated objective for the batch.
   */
  template<typename GradType>
  double EvaluateWithGradient(const arma::mat& parameters,
                              const size_t begin,
                              GradType& gradient,
                              const size_t batchSize);

  /**
   * Evaluate the gradient of the bidirectional recurrent neural network with
   * the given parameters.
   *
   * @param parameters Matrix/vector of model parameters.
   * @param begin Index of the first point of the batch to evaluate.
   * @param gradient Matrix to place the computed gradient into.
   * @param batchSize Number of points in the batch.
   */
  void Gradient(const arma::mat& parameters,
                const size_t begin,
                arma::mat& gradient,
                const size_t batchSize);

  /**
   * Shuffle the order of function visitation (i.e. shuffle the stored
   * predictors/responses).
   */
  void Shuffle();

  /*
   * Add a new module to the model, constructed in place from the given
   * arguments.
   *
   * @param args The layer parameter(s).
   */
  template <class LayerType, class... Args>
  void Add(Args... args);

  /*
   * Add a new module to the model.
   *
   * @param layer The Layer to be added to the model.
   */
  void Add(LayerTypes<CustomLayers...> layer);

  //! Return the number of separable functions (number of predictor points).
  size_t NumFunctions() const { return numFunctions; }

  //! Return the initial point for the optimization.
  const arma::mat& Parameters() const { return parameter; }
  //! Modify the initial point for the optimization.
  arma::mat& Parameters() { return parameter; }

  //! Return the maximum length of backpropagation through time.
  const size_t& Rho() const { return rho; }
  //! Modify the maximum length of backpropagation through time.
  size_t& Rho() { return rho; }

  //! Get the matrix of responses to the input data points.
  const arma::cube& Responses() const { return responses; }
  //! Modify the matrix of responses to the input data points.
  arma::cube& Responses() { return responses; }

  //! Get the matrix of data points (predictors).
  const arma::cube& Predictors() const { return predictors; }
  //! Modify the matrix of data points (predictors).
  arma::cube& Predictors() { return predictors; }

  /**
   * Reset the state of the network.
   */
  void Reset();

  /**
   * Reset the module information (weights/parameters).
   */
  void ResetParameters();

  //! Serialize the model.
  template<typename Archive>
  void serialize(Archive& ar, const unsigned int /* version */);

 private:
  // Helper functions.

  /**
   * Reset the deterministic flag — presumably propagates `deterministic` to
   * the wrapped RNNs; TODO(review): confirm in brnn_impl.hpp.
   */
  void ResetDeterministic();

  //! Number of steps to backpropagate through time (BPTT).
  size_t rho;

  //! Instantiated output layer used to evaluate the network.
  OutputLayerType outputLayer;

  //! Layer that merges the forward- and backward-RNN outputs.
  LayerTypes<CustomLayers...> mergeLayer;

  //! Activation layer applied to the merged output.
  LayerTypes<CustomLayers...> mergeOutput;

  //! Instantiated InitializationRule object for initializing the network
  //! parameters.
  InitializationRuleType initializeRule;

  //! Cached input size — NOTE(review): exact caching semantics live in
  //! brnn_impl.hpp.
  size_t inputSize;

  //! Cached output size.
  size_t outputSize;

  //! Cached target size.
  size_t targetSize;

  //! Indicator whether the network has been set up (reset).
  bool reset;

  //! Single-response mode flag (see constructor).
  bool single;

  //! The matrix of data points (predictors); one slice per time step.
  arma::cube predictors;

  //! The matrix of responses to the input data points.
  arma::cube responses;

  //! Matrix of (trained) parameters.
  arma::mat parameter;

  //! The number of separable functions (the number of predictor points).
  size_t numFunctions;

  //! The current error for the backward pass.
  arma::mat error;

  //! Visitor exposing the delta parameter of a module.
  DeltaVisitor deltaVisitor;

  //! Visitor exposing the output parameter of a module.
  OutputParameterVisitor outputParameterVisitor;

  //! The current output parameters of the forward RNN, one matrix per module.
  std::vector<arma::mat> forwardRNNOutputParameter;

  //! The current output parameters of the backward RNN, one matrix per
  //! module.
  std::vector<arma::mat> backwardRNNOutputParameter;

  //! Visitor returning the number of weights of a module.
  WeightSizeVisitor weightSizeVisitor;

  //! Visitor executing a module's Reset() function.
  ResetVisitor resetVisitor;

  //! Visitor executing the destructor of the instantiated module.
  DeleteVisitor deleteVisitor;

  //! Visitor supporting the copy constructor for network modules.
  CopyVisitor<CustomLayers...> copyVisitor;

  //! Whether the network runs in deterministic (testing) mode.
  bool deterministic;

  //! Gradient accumulated over the forward RNN.
  arma::mat forwardGradient;

  //! Gradient accumulated over the backward RNN.
  arma::mat backwardGradient;

  //! Combined gradient of the whole network.
  arma::mat totalGradient;

  //! The RNN processing the sequence in the forward direction.
  RNN<OutputLayerType, InitializationRuleType, CustomLayers...> forwardRNN;

  //! The RNN processing the sequence in the backward direction.
  RNN<OutputLayerType, InitializationRuleType, CustomLayers...> backwardRNN;
}; // class BRNN
383 
384 } // namespace ann
385 } // namespace mlpack
386 
388 namespace boost {
389 namespace serialization {
390 
391 template<typename OutputLayerType,
392  typename InitializationRuleType,
393  typename MergeLayerType,
394  typename MergeOutputType,
395  typename... CustomLayer>
396 struct version<
397  mlpack::ann::BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
398  InitializationRuleType, CustomLayer...>>
399 {
400  BOOST_STATIC_CONSTANT(int, value = 1);
401 };
402 
403 } // namespace serialization
404 } // namespace boost
405 
406 // Include implementation.
407 #include "brnn_impl.hpp"
408 
409 #endif
DeleteVisitor executes the destructor of the instantiated object.
void ResetParameters()
Reset the module information (weights/parameters).
Set the serialization version of the adaboost class.
Definition: adaboost.hpp:194
strip_type.hpp
Definition: add_to_po.hpp:21
arma::mat & Parameters()
Modify the initial point for the optimization.
Definition: brnn.hpp:254
const size_t & Rho() const
Return the maximum length of backpropagation through time.
Definition: brnn.hpp:257
void Add(Args... args)
double Train(arma::cube predictors, arma::cube responses, OptimizerType &optimizer)
Train the bidirectional recurrent neural network on the given input data using the given optimizer.
This visitor is to support copy constructor for neural network module.
The core includes that mlpack expects; standard C++ includes and Armadillo.
double EvaluateWithGradient(const arma::mat &parameters, const size_t begin, GradType &gradient, const size_t batchSize)
Evaluate the bidirectional recurrent neural network with the given parameters.
void Gradient(const arma::mat &parameters, const size_t begin, arma::mat &gradient, const size_t batchSize)
Evaluate the gradient of the bidirectional recurrent neural network with the given parameters.
WeightSizeVisitor returns the number of weights of the given module.
arma::cube & Responses()
Modify the matrix of responses to the input data points.
Definition: brnn.hpp:264
const arma::mat & Parameters() const
Return the initial point for the optimization.
Definition: brnn.hpp:252
Implementation of a standard recurrent neural network container.
Definition: rnn.hpp:44
Implementation of the base layer.
Definition: base_layer.hpp:53
BRNN(const size_t rho, const bool single=false, OutputLayerType outputLayer=OutputLayerType(), MergeLayerType mergeLayer=MergeLayerType(), MergeOutputType mergeOutput=MergeOutputType(), InitializationRuleType initializeRule=InitializationRuleType())
Create the BRNN object.
void Reset()
Reset the state of the network.
void Predict(arma::cube predictors, arma::cube &results, const size_t batchSize=256)
Predict the responses to a given set of predictors.
const arma::cube & Predictors() const
Get the matrix of data points (predictors).
Definition: brnn.hpp:267
ResetVisitor executes the Reset() function.
OutputParameterVisitor exposes the output parameter of the given module.
double Evaluate(const arma::mat &parameters, const size_t begin, const size_t batchSize, const bool deterministic)
Evaluate the bidirectional recurrent neural network with the given parameters.
void Shuffle()
Shuffle the order of function visitation.
size_t NumFunctions() const
Return the number of separable functions. (number of predictor points).
Definition: brnn.hpp:249
const arma::cube & Responses() const
Get the matrix of responses to the input data points.
Definition: brnn.hpp:262
Implementation of a standard bidirectional recurrent neural network container.
Definition: brnn.hpp:47
boost::variant< Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, AtrousConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, BaseLayer< SoftplusFunction, arma::mat, arma::mat > *, BatchNorm< arma::mat, arma::mat > *, BilinearInterpolation< arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, Concatenate< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, TransposedConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, AlphaDropout< arma::mat, arma::mat > *, ELU< arma::mat, arma::mat > *, FlexibleReLU< arma::mat, arma::mat > *, Glimpse< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LayerNorm< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, CReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat, NoRegularizer > *, LinearNoBias< arma::mat, arma::mat, NoRegularizer > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, GRU< arma::mat, arma::mat > *, FastLSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MiniBatchDiscrimination< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, 
MultiplyMerge< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, Padding< arma::mat, arma::mat > *, PReLU< arma::mat, arma::mat > *, WeightNorm< arma::mat, arma::mat > *, MoreTypes, CustomLayers *... > LayerTypes
void serialize(Archive &ar, const unsigned int)
Serialize the model.
arma::cube & Predictors()
Modify the matrix of data points (predictors).
Definition: brnn.hpp:269
DeltaVisitor exposes the delta parameter of the given module.
size_t & Rho()
Modify the maximum length of backpropagation through time.
Definition: brnn.hpp:259