reparametrization.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
14 #define MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "layer_types.hpp"
19 #include "../activation_functions/softplus_function.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
52 template <
53  typename InputDataType = arma::mat,
54  typename OutputDataType = arma::mat
55 >
56 class Reparametrization
57 {
58  public:
61 
70  Reparametrization(const size_t latentSize,
71  const bool stochastic = true,
72  const bool includeKl = true,
73  const double beta = 1);
74 
82  template<typename eT>
83  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
84 
94  template<typename eT>
95  void Backward(const arma::Mat<eT>& input,
96  const arma::Mat<eT>& gy,
97  arma::Mat<eT>& g);
98 
100  OutputDataType const& OutputParameter() const { return outputParameter; }
102  OutputDataType& OutputParameter() { return outputParameter; }
103 
105  OutputDataType const& Delta() const { return delta; }
107  OutputDataType& Delta() { return delta; }
108 
110  size_t const& OutputSize() const { return latentSize; }
112  size_t& OutputSize() { return latentSize; }
113 
115  double Loss()
116  {
117  if (!includeKl)
118  return 0;
119 
120  return -0.5 * beta * arma::accu(2 * arma::log(stdDev) - arma::pow(stdDev, 2)
121  - arma::pow(mean, 2) + 1) / mean.n_cols;
122  }
123 
125  bool Stochastic() const { return stochastic; }
126 
128  bool IncludeKL() const { return includeKl; }
129 
131  double Beta() const { return beta; }
132 
136  template<typename Archive>
137  void serialize(Archive& ar, const unsigned int /* version */);
138 
139  private:
141  size_t latentSize;
142 
144  bool stochastic;
145 
147  bool includeKl;
148 
150  double beta;
151 
153  OutputDataType delta;
154 
156  OutputDataType gaussianSample;
157 
159  OutputDataType mean;
160 
163  OutputDataType preStdDev;
164 
166  OutputDataType stdDev;
167 
169  OutputDataType outputParameter;
170 }; // class Reparametrization
171 
172 } // namespace ann
173 } // namespace mlpack
174 
175 // Include implementation.
176 #include "reparametrization_impl.hpp"
177 
178 #endif
OutputDataType & Delta()
Modify the delta.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
bool IncludeKL() const
Get the value of the includeKl parameter.
bool Stochastic() const
Get the value of the stochastic parameter.
double Loss()
Get the KL divergence from the standard normal distribution.
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & OutputParameter()
Modify the output parameter.
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, computing the gradient g with respect to the input by propagating the error gy backwards through f, using the results of the feed forward pass.
size_t const & OutputSize() const
Get the output size.
OutputDataType const & OutputParameter() const
Get the output parameter.
size_t & OutputSize()
Modify the output size.
OutputDataType const & Delta() const
Get the delta.
Reparametrization()
Create the Reparametrization object.
double Beta() const
Get the value of the beta hyperparameter.