batch_norm.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP
15 #define MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP
16 
17 #include <mlpack/prereqs.hpp>
18 
19 namespace mlpack {
20 namespace ann {
21 
53 template <
54  typename InputDataType = arma::mat,
55  typename OutputDataType = arma::mat
56 >
57 class BatchNorm
58 {
59  public:
61  BatchNorm();
62 
69  BatchNorm(const size_t size, const double eps = 1e-8);
70 
74  void Reset();
75 
84  template<typename eT>
85  void Forward(const arma::Mat<eT>&& input, arma::Mat<eT>&& output);
86 
94  template<typename eT>
95  void Backward(const arma::Mat<eT>&& input,
96  arma::Mat<eT>&& gy,
97  arma::Mat<eT>&& g);
98 
106  template<typename eT>
107  void Gradient(const arma::Mat<eT>&& input,
108  arma::Mat<eT>&& error,
109  arma::Mat<eT>&& gradient);
110 
112  OutputDataType const& Parameters() const { return weights; }
114  OutputDataType& Parameters() { return weights; }
115 
117  InputDataType const& InputParameter() const { return inputParameter; }
119  InputDataType& InputParameter() { return inputParameter; }
120 
122  OutputDataType const& OutputParameter() const { return outputParameter; }
124  OutputDataType& OutputParameter() { return outputParameter; }
125 
127  OutputDataType const& Delta() const { return delta; }
129  OutputDataType& Delta() { return delta; }
130 
132  OutputDataType const& Gradient() const { return gradient; }
134  OutputDataType& Gradient() { return gradient; }
135 
137  bool Deterministic() const { return deterministic; }
139  bool& Deterministic() { return deterministic; }
140 
142  OutputDataType TrainingMean() { return stats.mean(); }
143 
145  OutputDataType TrainingVariance() { return stats.var(1); }
146 
150  template<typename Archive>
151  void serialize(Archive& ar, const unsigned int /* version */);
152 
153  private:
155  size_t size;
156 
158  double eps;
159 
161  OutputDataType gamma;
162 
164  OutputDataType beta;
165 
167  OutputDataType weights;
168 
173  bool deterministic;
174 
176  OutputDataType mean;
177 
179  OutputDataType variance;
180 
182  arma::running_stat_vec<arma::colvec> stats;
183 
185  OutputDataType gradient;
186 
188  OutputDataType delta;
189 
191  InputDataType inputParameter;
192 
194  OutputDataType outputParameter;
195 
197  OutputDataType normalized;
198 }; // class BatchNorm
199 
200 } // namespace ann
201 } // namespace mlpack
202 
203 // Include the implementation.
204 #include "batch_norm_impl.hpp"
205 
206 #endif
OutputDataType & Gradient()
Modify the gradient.
Definition: batch_norm.hpp:134
InputDataType const & InputParameter() const
Get the input parameter.
Definition: batch_norm.hpp:117
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
void Forward(const arma::Mat< eT > &&input, arma::Mat< eT > &&output)
Forward pass of the Batch Normalization layer.
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Backward(const arma::Mat< eT > &&input, arma::Mat< eT > &&gy, arma::Mat< eT > &&g)
Backward pass through the layer.
OutputDataType & Delta()
Modify the delta.
Definition: batch_norm.hpp:129
OutputDataType TrainingMean()
Get the mean over the training data.
Definition: batch_norm.hpp:142
InputDataType & InputParameter()
Modify the input parameter.
Definition: batch_norm.hpp:119
bool Deterministic() const
Get the value of deterministic parameter.
Definition: batch_norm.hpp:137
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: batch_norm.hpp:122
void Reset()
Reset the layer parameters.
OutputDataType & Parameters()
Modify the parameters.
Definition: batch_norm.hpp:114
bool & Deterministic()
Modify the value of deterministic parameter.
Definition: batch_norm.hpp:139
BatchNorm()
Create the BatchNorm object.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: batch_norm.hpp:124
OutputDataType const & Parameters() const
Get the parameters.
Definition: batch_norm.hpp:112
OutputDataType const & Gradient() const
Get the gradient.
Definition: batch_norm.hpp:132
Declaration of the Batch Normalization layer class.
Definition: batch_norm.hpp:57
OutputDataType TrainingVariance()
Get the variance over the training data.
Definition: batch_norm.hpp:145
OutputDataType const & Delta() const
Get the delta.
Definition: batch_norm.hpp:127