gan.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_METHODS_ANN_GAN_GAN_HPP
12 #define MLPACK_METHODS_ANN_GAN_GAN_HPP
13 
14 #include <mlpack/core.hpp>
15 
23 
24 
25 namespace mlpack {
26 namespace ann {
27 
57 template<
58  typename Model,
59  typename InitializationRuleType,
60  typename Noise,
61  typename PolicyType = StandardGAN
62 >
63 class GAN
64 {
65  public:
80  GAN(arma::mat& trainData,
81  Model generator,
82  Model discriminator,
83  InitializationRuleType& initializeRule,
84  Noise& noiseFunction,
85  const size_t noiseDim,
86  const size_t batchSize,
87  const size_t generatorUpdateStep,
88  const size_t preTrainSize,
89  const double multiplier,
90  const double clippingParameter = 0.01,
91  const double lambda = 10.0);
92 
94  GAN(const GAN&);
95 
97  GAN(GAN&&);
98 
99  // Reset function.
100  void Reset();
101 
107  template<typename OptimizerType>
108  double Train(OptimizerType& Optimizer);
109 
119  template<typename Policy = PolicyType>
120  typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
121  std::is_same<Policy, DCGAN>::value, double>::type
122  Evaluate(const arma::mat& parameters,
123  const size_t i,
124  const size_t batchSize);
125 
134  template<typename Policy = PolicyType>
135  typename std::enable_if<std::is_same<Policy, WGAN>::value,
136  double>::type
137  Evaluate(const arma::mat& parameters,
138  const size_t i,
139  const size_t batchSize);
140 
149  template<typename Policy = PolicyType>
150  typename std::enable_if<std::is_same<Policy, WGANGP>::value,
151  double>::type
152  Evaluate(const arma::mat& parameters,
153  const size_t i,
154  const size_t batchSize);
155 
166  template<typename GradType, typename Policy = PolicyType>
167  typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
168  std::is_same<Policy, DCGAN>::value, double>::type
169  EvaluateWithGradient(const arma::mat& parameters,
170  const size_t i,
171  GradType& gradient,
172  const size_t batchSize);
173 
184  template<typename GradType, typename Policy = PolicyType>
185  typename std::enable_if<std::is_same<Policy, WGAN>::value,
186  double>::type
187  EvaluateWithGradient(const arma::mat& parameters,
188  const size_t i,
189  GradType& gradient,
190  const size_t batchSize);
191 
202  template<typename GradType, typename Policy = PolicyType>
203  typename std::enable_if<std::is_same<Policy, WGANGP>::value,
204  double>::type
205  EvaluateWithGradient(const arma::mat& parameters,
206  const size_t i,
207  GradType& gradient,
208  const size_t batchSize);
209 
220  template<typename Policy = PolicyType>
221  typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
222  std::is_same<Policy, DCGAN>::value, void>::type
223  Gradient(const arma::mat& parameters,
224  const size_t i,
225  arma::mat& gradient,
226  const size_t batchSize);
227 
238  template<typename Policy = PolicyType>
239  typename std::enable_if<std::is_same<Policy, WGAN>::value, void>::type
240  Gradient(const arma::mat& parameters,
241  const size_t i,
242  arma::mat& gradient,
243  const size_t batchSize);
244 
255  template<typename Policy = PolicyType>
256  typename std::enable_if<std::is_same<Policy, WGANGP>::value,
257  void>::type
258  Gradient(const arma::mat& parameters,
259  const size_t i,
260  arma::mat& gradient,
261  const size_t batchSize);
262 
267  void Shuffle();
268 
274  void Forward(arma::mat&& input);
275 
282  void Predict(arma::mat&& input,
283  arma::mat& output);
284 
286  const arma::mat& Parameters() const { return parameter; }
288  arma::mat& Parameters() { return parameter; }
289 
291  const Model& Generator() const { return generator; }
293  Model& Generator() { return generator; }
295  const Model& Discriminator() const { return discriminator; }
297  Model& Discriminator() { return discriminator; }
298 
300  size_t NumFunctions() const { return numFunctions; }
301 
303  const arma::mat& Responses() const { return responses; }
305  arma::mat& Responses() { return responses; }
306 
308  const arma::mat& Predictors() const { return predictors; }
310  arma::mat& Predictors() { return predictors; }
311 
313  template<typename Archive>
314  void serialize(Archive& ar, const unsigned int /* version */);
315 
316  private:
318  arma::mat predictors;
320  arma::mat parameter;
322  Model generator;
324  Model discriminator;
326  InitializationRuleType initializeRule;
328  Noise noiseFunction;
330  size_t noiseDim;
332  size_t numFunctions;
334  size_t batchSize;
336  size_t counter;
338  size_t currentBatch;
340  size_t generatorUpdateStep;
342  size_t preTrainSize;
344  double multiplier;
346  double clippingParameter;
348  double lambda;
350  bool reset;
352  DeltaVisitor deltaVisitor;
354  arma::mat responses;
356  arma::mat currentInput;
358  arma::mat currentTarget;
360  OutputParameterVisitor outputParameterVisitor;
362  WeightSizeVisitor weightSizeVisitor;
364  ResetVisitor resetVisitor;
366  arma::mat gradient;
368  arma::mat gradientDiscriminator;
370  arma::mat noiseGradientDiscriminator;
372  arma::mat normGradientDiscriminator;
374  arma::mat noise;
376  arma::mat gradientGenerator;
378  arma::mat ganOutput;
379 };
380 
381 } // namespace ann
382 } // namespace mlpack
383 
384 // Include implementation.
385 #include "gan_impl.hpp"
386 #include "wgan_impl.hpp"
387 #include "wgangp_impl.hpp"
388 
389 
390 #endif
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
const Model & Discriminator() const
Return the discriminator of the GAN.
Definition: gan.hpp:295
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: gan.hpp:300
Model & Generator()
Modify the generator of the GAN.
Definition: gan.hpp:293
.hpp
Definition: add_to_po.hpp:21
WeightSizeVisitor returns the number of weights of the given module.
void Forward(arma::mat &&input)
This function does a forward pass through the GAN network.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
void Shuffle()
Shuffle the order of function visitation.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for Standard GAN and DCGAN.
arma::mat & Parameters()
Modify the parameters of the network.
Definition: gan.hpp:288
ResetVisitor executes the Reset() function.
double Train(OptimizerType &Optimizer)
Train function.
OutputParameterVisitor exposes the output parameter of the given module.
const arma::mat & Responses() const
Get the matrix of responses to the input data points.
Definition: gan.hpp:303
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
arma::mat & Predictors()
Modify the matrix of data points (predictors).
Definition: gan.hpp:310
void Predict(arma::mat &&input, arma::mat &output)
This function predicts the output of the network on the given input.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
const Model & Generator() const
Return the generator of the GAN.
Definition: gan.hpp:291
const arma::mat & Parameters() const
Return the parameters of the network.
Definition: gan.hpp:286
DeltaVisitor exposes the delta parameter of the given module.
The implementation of the standard GAN module.
Definition: gan.hpp:63
arma::mat & Responses()
Modify the matrix of responses to the input data points.
Definition: gan.hpp:305
GAN(arma::mat &trainData, Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
Constructor for GAN class.
Model & Discriminator()
Modify the discriminator of the GAN.
Definition: gan.hpp:297
const arma::mat & Predictors() const
Get the matrix of data points (predictors).
Definition: gan.hpp:308