11 #ifndef MLPACK_METHODS_ANN_GAN_GAN_HPP 12 #define MLPACK_METHODS_ANN_GAN_GAN_HPP 59 typename InitializationRuleType,
61 typename PolicyType = StandardGAN
80 GAN(arma::mat& trainData,
83 InitializationRuleType& initializeRule,
85 const size_t noiseDim,
86 const size_t batchSize,
87 const size_t generatorUpdateStep,
88 const size_t preTrainSize,
89 const double multiplier,
90 const double clippingParameter = 0.01,
91 const double lambda = 10.0);
/**
 * Train function.
 *
 * @tparam OptimizerType Type of the optimizer object passed in.
 * @param Optimizer Optimizer instance used to fit the GAN's parameters.
 * @return A double; presumably the final objective value reported by the
 *     optimizer. NOTE(review): confirm against gan_impl.hpp.
 */
107 template<
typename OptimizerType>
108 double Train(OptimizerType& Optimizer);
119 template<
typename Policy = PolicyType>
120 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
121 std::is_same<Policy, DCGAN>::value,
double>::type
122 Evaluate(
const arma::mat& parameters,
124 const size_t batchSize);
134 template<
typename Policy = PolicyType>
135 typename std::enable_if<std::is_same<Policy, WGAN>::value,
137 Evaluate(
const arma::mat& parameters,
139 const size_t batchSize);
149 template<
typename Policy = PolicyType>
150 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
152 Evaluate(
const arma::mat& parameters,
154 const size_t batchSize);
166 template<
typename GradType,
typename Policy = PolicyType>
167 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
168 std::is_same<Policy, DCGAN>::value,
double>::type
172 const size_t batchSize);
184 template<
typename GradType,
typename Policy = PolicyType>
185 typename std::enable_if<std::is_same<Policy, WGAN>::value,
190 const size_t batchSize);
202 template<
typename GradType,
typename Policy = PolicyType>
203 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
208 const size_t batchSize);
220 template<
typename Policy = PolicyType>
221 typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
222 std::is_same<Policy, DCGAN>::value,
void>::type
223 Gradient(
const arma::mat& parameters,
226 const size_t batchSize);
238 template<
typename Policy = PolicyType>
239 typename std::enable_if<std::is_same<Policy, WGAN>::value,
void>::type
240 Gradient(
const arma::mat& parameters,
243 const size_t batchSize);
255 template<
typename Policy = PolicyType>
256 typename std::enable_if<std::is_same<Policy, WGANGP>::value,
258 Gradient(
const arma::mat& parameters,
261 const size_t batchSize);
/**
 * This function does a forward pass through the GAN network.
 *
 * @param input Input matrix for the forward pass, taken by rvalue reference.
 */
274 void Forward(arma::mat&& input);
282 void Predict(arma::mat&& input,
/**
 * Get the matrix of responses to the input data points.
 *
 * @return Const reference to the stored `responses` matrix (no copy is made).
 */
303 const arma::mat&
Responses()
const {
return responses; }
/**
 * Serialize the model.
 *
 * @tparam Archive Archive type used for serialization.
 * @param ar Archive the model is serialized with.
 *
 * The second parameter (const unsigned int) is unnamed in this declaration.
 * It is presumably a serialization version number. TODO confirm in the
 * implementation.
 */
313 template<
typename Archive>
314 void serialize(Archive& ar,
const unsigned int );
// Matrix of data points (predictors); exposed through Predictors().
318 arma::mat predictors;
// Weight-initialization rule stored by value. The constructor takes an
// InitializationRuleType&; presumably it is copied here (verify in gan_impl.hpp).
326 InitializationRuleType initializeRule;
// Set from the constructor's generatorUpdateStep argument. Presumably this is
// how often the generator is updated relative to the discriminator (TODO confirm).
340 size_t generatorUpdateStep;
// Set from the constructor's clippingParameter argument (default 0.01).
// Presumably the weight-clipping bound used by WGAN (confirm in wgan_impl.hpp).
346 double clippingParameter;
// NOTE(review): presumably the input batch currently being processed. Confirm.
356 arma::mat currentInput;
// NOTE(review): presumably the target labels for currentInput. Confirm.
358 arma::mat currentTarget;
// NOTE(review): presumably the discriminator gradient buffer. Confirm.
368 arma::mat gradientDiscriminator;
// NOTE(review): presumably the discriminator gradient on generated (noise) samples. Confirm.
370 arma::mat noiseGradientDiscriminator;
// NOTE(review): presumably a gradient-norm term (e.g. for WGAN-GP's penalty). Confirm.
372 arma::mat normGradientDiscriminator;
// NOTE(review): presumably the generator gradient buffer. Confirm.
376 arma::mat gradientGenerator;
385 #include "gan_impl.hpp" 386 #include "wgan_impl.hpp" 387 #include "wgangp_impl.hpp" std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type EvaluateWithGradient(const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
EvaluateWithGradient function for the Standard GAN and DCGAN.
const Model & Discriminator() const
Return the discriminator of the GAN.
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Model & Generator()
Modify the generator of the GAN.
WeightSizeVisitor returns the number of weights of the given module.
void Forward(arma::mat &&input)
This function does a forward pass through the GAN network.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
void Shuffle()
Shuffle the order of function visitation.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, void >::type Gradient(const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
Gradient function for Standard GAN and DCGAN.
arma::mat & Parameters()
Modify the parameters of the network.
ResetVisitor executes the Reset() function.
double Train(OptimizerType &Optimizer)
Train function.
OutputParameterVisitor exposes the output parameter of the given module.
const arma::mat & Responses() const
Get the matrix of responses to the input data points.
std::enable_if< std::is_same< Policy, StandardGAN >::value||std::is_same< Policy, DCGAN >::value, double >::type Evaluate(const arma::mat &parameters, const size_t i, const size_t batchSize)
Evaluate function for the Standard GAN and DCGAN.
arma::mat & Predictors()
Modify the matrix of data points (predictors).
void Predict(arma::mat &&input, arma::mat &output)
This function predicts the output of the network on the given input.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
const Model & Generator() const
Return the generator of the GAN.
const arma::mat & Parameters() const
Return the parameters of the network.
DeltaVisitor exposes the delta parameter of the given module.
The implementation of the standard GAN module.
arma::mat & Responses()
Modify the matrix of responses to the input data points.
GAN(arma::mat &trainData, Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
Constructor for GAN class.
Model & Discriminator()
Modify the discriminator of the GAN.
const arma::mat & Predictors() const
Get the matrix of data points (predictors).