batch_norm.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP
14 #define MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace ann {
20 
52 template <
53  typename InputDataType = arma::mat,
54  typename OutputDataType = arma::mat
55 >
56 class BatchNorm
57 {
58  public:
60  BatchNorm();
61 
68  BatchNorm(const size_t size, const double eps = 1e-8);
69 
73  void Reset();
74 
83  template<typename eT>
84  void Forward(const arma::Mat<eT>&& input, arma::Mat<eT>&& output);
85 
93  template<typename eT>
94  void Backward(const arma::Mat<eT>&& input,
95  arma::Mat<eT>&& gy,
96  arma::Mat<eT>&& g);
97 
105  template<typename eT>
106  void Gradient(const arma::Mat<eT>&& input,
107  arma::Mat<eT>&& error,
108  arma::Mat<eT>&& gradient);
109 
111  OutputDataType const& Parameters() const { return weights; }
113  OutputDataType& Parameters() { return weights; }
114 
116  InputDataType const& InputParameter() const { return inputParameter; }
118  InputDataType& InputParameter() { return inputParameter; }
119 
121  OutputDataType const& OutputParameter() const { return outputParameter; }
123  OutputDataType& OutputParameter() { return outputParameter; }
124 
126  OutputDataType const& Delta() const { return delta; }
128  OutputDataType& Delta() { return delta; }
129 
131  OutputDataType const& Gradient() const { return gradient; }
133  OutputDataType& Gradient() { return gradient; }
134 
136  bool Deterministic() const { return deterministic; }
138  bool& Deterministic() { return deterministic; }
139 
141  OutputDataType TrainingMean() { return runningMean; }
142 
144  OutputDataType TrainingVariance() { return runningVariance / count; }
145 
149  template<typename Archive>
150  void serialize(Archive& ar, const unsigned int /* version */);
151 
152  private:
154  size_t size;
155 
157  double eps;
158 
160  bool loading;
161 
163  OutputDataType gamma;
164 
166  OutputDataType beta;
167 
169  OutputDataType weights;
170 
175  bool deterministic;
176 
178  size_t count;
179 
181  OutputDataType mean;
182 
184  OutputDataType variance;
185 
187  OutputDataType runningMean;
188 
190  OutputDataType runningVariance;
191 
193  OutputDataType gradient;
194 
196  OutputDataType delta;
197 
199  InputDataType inputParameter;
200 
202  OutputDataType outputParameter;
203 
205  OutputDataType normalized;
206 }; // class BatchNorm
207 
208 } // namespace ann
209 } // namespace mlpack
210 
211 // Include the implementation.
212 #include "batch_norm_impl.hpp"
213 
214 #endif
OutputDataType & Gradient()
Modify the gradient.
Definition: batch_norm.hpp:133
InputDataType const & InputParameter() const
Get the input parameter.
Definition: batch_norm.hpp:116
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
void Forward(const arma::Mat< eT > &&input, arma::Mat< eT > &&output)
Forward pass of the Batch Normalization layer.
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects: standard C++ includes and Armadillo.
void Backward(const arma::Mat< eT > &&input, arma::Mat< eT > &&gy, arma::Mat< eT > &&g)
Backward pass through the layer.
OutputDataType & Delta()
Modify the delta.
Definition: batch_norm.hpp:128
OutputDataType TrainingMean()
Get the mean over the training data.
Definition: batch_norm.hpp:141
InputDataType & InputParameter()
Modify the input parameter.
Definition: batch_norm.hpp:118
bool Deterministic() const
Get the value of deterministic parameter.
Definition: batch_norm.hpp:136
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: batch_norm.hpp:121
void Reset()
Reset the layer parameters.
OutputDataType & Parameters()
Modify the parameters.
Definition: batch_norm.hpp:113
bool & Deterministic()
Modify the value of deterministic parameter.
Definition: batch_norm.hpp:138
BatchNorm()
Create the BatchNorm object.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: batch_norm.hpp:123
OutputDataType const & Parameters() const
Get the parameters.
Definition: batch_norm.hpp:111
OutputDataType const & Gradient() const
Get the gradient.
Definition: batch_norm.hpp:131
BatchNorm
Declaration of the Batch Normalization layer class.
Definition: batch_norm.hpp:56
OutputDataType TrainingVariance()
Get the variance over the training data.
Definition: batch_norm.hpp:144
OutputDataType const & Delta() const
Get the delta.
Definition: batch_norm.hpp:126