copy.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_AUGMENTED_TASKS_COPY_HPP
15 #define MLPACK_METHODS_AUGMENTED_TASKS_COPY_HPP
16 
17 #include <mlpack/prereqs.hpp>
19 
20 namespace mlpack {
21 namespace ann /* Artificial Neural Network */ {
22 namespace augmented /* Augmented neural network */ {
23 namespace tasks /* Task utilities for augmented */ {
24 
/**
 * Generator of instances of the binary sequence copy task.
 *
 * Holds the task settings (maximum sequence length, number of repeats and the
 * separator flag) and declares two Generate() overloads that write into
 * caller-supplied containers. Only declarations are visible here; the member
 * definitions are presumably provided by "copy_impl.hpp" -- TODO confirm.
 */
48 class CopyTask
49 {
50  public:
    /**
     * Creates an instance of the sequence copy task.
     *
     * @param maxLength Maximum length of a sequence.
     * @param nRepeats Number of repeats the model has to perform to complete
     *     the task.
     * @param addSeparator Flag indicating whether the generator should produce
     *     a separator as part of the sequence (defaults to false).
     */
60  CopyTask(const size_t maxLength,
61  const size_t nRepeats,
62  const bool addSeparator = false);
    /**
     * Generate dataset of a given size.
     *
     * Declared const, so generating data does not modify the stored settings.
     *
     * @param input Taken by non-const reference; presumably overwritten with
     *     the generated input sequences, one field element per sample --
     *     TODO confirm against the implementation.
     * @param labels Taken by non-const reference; presumably overwritten with
     *     the expected outputs matching input -- TODO confirm.
     * @param batchSize Presumably the number of samples to generate -- TODO
     *     confirm.
     * @param fixedLength Defaults to false. NOTE(review): the name suggests
     *     that true forces all sequences to a single length; verify in the
     *     implementation.
     */
70  void Generate(arma::field<arma::mat>& input,
71  arma::field<arma::mat>& labels,
72  const size_t batchSize,
73  bool fixedLength = false) const;
74 
    /**
     * Overload of Generate() that writes into plain matrices instead of
     * fields of matrices. It has no fixedLength parameter.
     *
     * NOTE(review): presumably all samples share one length so that they fit
     * into a single matrix -- confirm against the implementation.
     *
     * @param input Taken by non-const reference; presumably overwritten with
     *     the generated input data -- TODO confirm layout.
     * @param labels Taken by non-const reference; presumably overwritten with
     *     the expected outputs -- TODO confirm layout.
     * @param batchSize Presumably the number of samples to generate -- TODO
     *     confirm.
     */
83  void Generate(arma::mat& input,
84  arma::mat& labels,
85  const size_t batchSize) const;
86 
87  private:
88  // Maximum length of a sequence.
89  size_t maxLength;
90  // Number of repeats the model has to perform to complete the task.
91  size_t nRepeats;
92  // Flag indicating whether the generator should produce
93  // a separator as part of the sequence.
94  bool addSeparator;
95 };
96 
97 } // namespace tasks
98 } // namespace augmented
99 } // namespace ann
100 } // namespace mlpack
101 
102 #include "copy_impl.hpp"
103 
104 #endif
strip_type.hpp
Definition: add_to_po.hpp:21
void Generate(arma::field< arma::mat > &input, arma::field< arma::mat > &labels, const size_t batchSize, bool fixedLength=false) const
Generate dataset of a given size.
The core includes that mlpack expects; standard C++ includes and Armadillo.
CopyTask(const size_t maxLength, const size_t nRepeats, const bool addSeparator=false)
Creates an instance of the sequence copy task.
Generator of instances of the binary sequence copy task.
Definition: copy.hpp:48