reparametrization.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
14 #define MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "layer_types.hpp"
19 #include "../activation_functions/softplus_function.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
51 template <
52  typename InputDataType = arma::mat,
53  typename OutputDataType = arma::mat
54 >
55 class Reparametrization
56 {
57  public:
60 
69  Reparametrization(const size_t latentSize,
70  const bool stochastic = true,
71  const bool includeKl = true,
72  const double beta = 1);
73 
81  template<typename eT>
82  void Forward(const arma::Mat<eT>&& input, arma::Mat<eT>&& output);
83 
93  template<typename eT>
94  void Backward(const arma::Mat<eT>&& input,
95  arma::Mat<eT>&& gy,
96  arma::Mat<eT>&& g);
97 
99  OutputDataType const& OutputParameter() const { return outputParameter; }
101  OutputDataType& OutputParameter() { return outputParameter; }
102 
104  OutputDataType const& Delta() const { return delta; }
106  OutputDataType& Delta() { return delta; }
107 
109  size_t const& OutputSize() const { return latentSize; }
111  size_t& OutputSize() { return latentSize; }
112 
114  double Loss()
115  {
116  if (!includeKl)
117  return 0;
118 
119  return -0.5 * beta * arma::accu(2 * arma::log(stdDev) - arma::pow(stdDev, 2)
120  - arma::pow(mean, 2) + 1) / mean.n_cols;
121  }
122 
126  template<typename Archive>
127  void serialize(Archive& ar, const unsigned int /* version */);
128 
129  private:
131  size_t latentSize;
132 
134  bool stochastic;
135 
137  bool includeKl;
138 
140  double beta;
141 
143  OutputDataType delta;
144 
146  OutputDataType gaussianSample;
147 
149  OutputDataType mean;
150 
153  OutputDataType preStdDev;
154 
156  OutputDataType stdDev;
157 
159  OutputDataType outputParameter;
160 }; // class Reparametrization
161 
162 } // namespace ann
163 } // namespace mlpack
164 
165 // Include implementation.
166 #include "reparametrization_impl.hpp"
167 
168 #endif
void Forward(const arma::Mat< eT > &&input, arma::Mat< eT > &&output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
OutputDataType & Delta()
Modify the delta.
strip_type.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Loss()
Get the KL divergence between the latent Gaussian and the standard normal distribution.
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & OutputParameter()
Modify the output parameter.
size_t const & OutputSize() const
Get the output size.
OutputDataType const & OutputParameter() const
Get the output parameter.
size_t & OutputSize()
Modify the output size.
OutputDataType const & Delta() const
Get the delta.
Reparametrization()
Create the Reparametrization object.
void Backward(const arma::Mat< eT > &&input, arma::Mat< eT > &&gy, arma::Mat< eT > &&g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.