13 #ifndef MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP 14 #define MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP 49 class ActivationFunction = LogisticFunction,
50 typename InputDataType = arma::mat,
51 typename OutputDataType = arma::mat
71 template<
typename InputType,
typename OutputType>
72 void Forward(
const InputType&& input, OutputType&& output)
74 ActivationFunction::Fn(input, output);
91 arma::Mat<eT> derivative;
92 ActivationFunction::Deriv(input, derivative);
102 OutputDataType
const&
Delta()
const {
return delta; }
104 OutputDataType&
Delta() {
return delta; }
109 template<
typename Archive>
117 OutputDataType delta;
120 OutputDataType outputParameter;
130 typename InputDataType = arma::mat,
131 typename OutputDataType = arma::mat
134 ActivationFunction, InputDataType, OutputDataType>;
141 typename InputDataType = arma::mat,
142 typename OutputDataType = arma::mat
145 ActivationFunction, InputDataType, OutputDataType>;
152 typename InputDataType = arma::mat,
153 typename OutputDataType = arma::mat
156 ActivationFunction, InputDataType, OutputDataType>;
163 typename InputDataType = arma::mat,
164 typename OutputDataType = arma::mat
167 ActivationFunction, InputDataType, OutputDataType>;
174 typename InputDataType = arma::mat,
175 typename OutputDataType = arma::mat
178 ActivationFunction, InputDataType, OutputDataType>;
185 typename InputDataType = arma::mat,
186 typename OutputDataType = arma::mat
189 ActivationFunction, InputDataType, OutputDataType>;
196 typename InputDataType = arma::mat,
197 typename OutputDataType = arma::mat
200 ActivationFunction, InputDataType, OutputDataType>;
206 class ActivationFunction = MishFunction,
207 typename InputDataType = arma::mat,
208 typename OutputDataType = arma::mat
211 ActivationFunction, InputDataType, OutputDataType>;
218 typename InputDataType = arma::mat,
219 typename OutputDataType = arma::mat
222 ActivationFunction, InputDataType, OutputDataType>;
229 typename InputDataType = arma::mat,
230 typename OutputDataType = arma::mat
233 ActivationFunction, InputDataType, OutputDataType>;
The identity function, defined by $f(x) = x$.
OutputDataType & OutputParameter()
Modify the output parameter.
BaseLayer()
Create the BaseLayer object.
The LiSHT function, defined by $f(x) = x \tanh(x)$.
OutputDataType & Delta()
Modify the delta.
void serialize(Archive &, const unsigned int)
Serialize the layer.
The tanh function, defined by $f(x) = \tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$.
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType const & OutputParameter() const
Get the output parameter.
OutputDataType const & Delta() const
Get the delta.
Implementation of the base layer.
The logistic function, defined by $f(x) = \frac{1}{1 + e^{-x}}$.
void Forward(const InputType &&input, OutputType &&output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
The swish function, defined by $f(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$.
The softplus function, defined by $f(x) = \ln(1 + e^{x})$.
The hard sigmoid function, defined by $f(x) = \min(1, \max(0, 0.2x + 0.5))$.
The GELU function, defined by $f(x) = 0.5x\left(1 + \tanh\left(\sqrt{2/\pi}\,\left(x + 0.044715x^{3}\right)\right)\right)$.
The rectifier function, defined by $f(x) = \max(0, x)$.
void Backward(const arma::Mat< eT > &&input, arma::Mat< eT > &&gy, arma::Mat< eT > &&g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.