base_layer.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
14 #define MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
15 
16 #include <mlpack/prereqs.hpp>
27 
28 namespace mlpack {
29 namespace ann {
30 
48 template <
49  class ActivationFunction = LogisticFunction,
50  typename InputDataType = arma::mat,
51  typename OutputDataType = arma::mat
52 >
53 class BaseLayer
54 {
55  public:
60  {
61  // Nothing to do here.
62  }
63 
71  template<typename InputType, typename OutputType>
72  void Forward(const InputType&& input, OutputType&& output)
73  {
74  ActivationFunction::Fn(input, output);
75  }
76 
86  template<typename eT>
87  void Backward(const arma::Mat<eT>&& input,
88  arma::Mat<eT>&& gy,
89  arma::Mat<eT>&& g)
90  {
91  arma::Mat<eT> derivative;
92  ActivationFunction::Deriv(input, derivative);
93  g = gy % derivative;
94  }
95 
97  OutputDataType const& OutputParameter() const { return outputParameter; }
99  OutputDataType& OutputParameter() { return outputParameter; }
100 
102  OutputDataType const& Delta() const { return delta; }
104  OutputDataType& Delta() { return delta; }
105 
109  template<typename Archive>
110  void serialize(Archive& /* ar */, const unsigned int /* version */)
111  {
112  /* Nothing to do here */
113  }
114 
115  private:
117  OutputDataType delta;
118 
120  OutputDataType outputParameter;
121 }; // class BaseLayer
122 
123 // Convenience typedefs.
124 
128 template <
129  class ActivationFunction = LogisticFunction,
130  typename InputDataType = arma::mat,
131  typename OutputDataType = arma::mat
132 >
133 using SigmoidLayer = BaseLayer<
134  ActivationFunction, InputDataType, OutputDataType>;
135 
139 template <
140  class ActivationFunction = IdentityFunction,
141  typename InputDataType = arma::mat,
142  typename OutputDataType = arma::mat
143 >
144 using IdentityLayer = BaseLayer<
145  ActivationFunction, InputDataType, OutputDataType>;
146 
150 template <
151  class ActivationFunction = RectifierFunction,
152  typename InputDataType = arma::mat,
153  typename OutputDataType = arma::mat
154 >
155 using ReLULayer = BaseLayer<
156  ActivationFunction, InputDataType, OutputDataType>;
157 
161 template <
162  class ActivationFunction = TanhFunction,
163  typename InputDataType = arma::mat,
164  typename OutputDataType = arma::mat
165 >
166 using TanHLayer = BaseLayer<
167  ActivationFunction, InputDataType, OutputDataType>;
168 
172 template <
173  class ActivationFunction = SoftplusFunction,
174  typename InputDataType = arma::mat,
175  typename OutputDataType = arma::mat
176 >
177 using SoftPlusLayer = BaseLayer<
178  ActivationFunction, InputDataType, OutputDataType>;
179 
183 template <
184  class ActivationFunction = HardSigmoidFunction,
185  typename InputDataType = arma::mat,
186  typename OutputDataType = arma::mat
187 >
189  ActivationFunction, InputDataType, OutputDataType>;
190 
194 template <
195  class ActivationFunction = SwishFunction,
196  typename InputDataType = arma::mat,
197  typename OutputDataType = arma::mat
198 >
200  ActivationFunction, InputDataType, OutputDataType>;
201 
205 template <
206  class ActivationFunction = MishFunction,
207  typename InputDataType = arma::mat,
208  typename OutputDataType = arma::mat
209 >
211  ActivationFunction, InputDataType, OutputDataType>;
212 
216 template <
217  class ActivationFunction = LiSHTFunction,
218  typename InputDataType = arma::mat,
219  typename OutputDataType = arma::mat
220 >
222  ActivationFunction, InputDataType, OutputDataType>;
223 
227 template <
228  class ActivationFunction = GELUFunction,
229  typename InputDataType = arma::mat,
230  typename OutputDataType = arma::mat
231 >
233  ActivationFunction, InputDataType, OutputDataType>;
234 
235 } // namespace ann
236 } // namespace mlpack
237 
238 #endif
The identity function, defined by $f(x) = x$.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: base_layer.hpp:99
BaseLayer()
Create the BaseLayer object.
Definition: base_layer.hpp:59
The LiSHT function, defined by $f(x) = x \tanh(x)$.
OutputDataType & Delta()
Modify the delta.
Definition: base_layer.hpp:104
void serialize(Archive &, const unsigned int)
Serialize the layer.
Definition: base_layer.hpp:110
The tanh function, defined by $f(x) = \tanh(x)$.
strip_type.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: base_layer.hpp:97
OutputDataType const & Delta() const
Get the delta.
Definition: base_layer.hpp:102
Implementation of the base layer.
Definition: base_layer.hpp:53
The logistic function, defined by $f(x) = \frac{1}{1 + e^{-x}}$.
void Forward(const InputType &&input, OutputType &&output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: base_layer.hpp:72
The swish function, defined by $f(x) = \frac{x}{1 + e^{-x}}$.
The softplus function, defined by $f(x) = \ln(1 + e^{x})$.
The hard sigmoid function, defined by $f(x) = \min(1, \max(0, 0.2x + 0.5))$.
The GELU function, defined by $f(x) = 0.5x\left(1 + \tanh\left(\sqrt{2/\pi}\,(x + 0.044715x^{3})\right)\right)$.
The rectifier function, defined by $f(x) = \max(0, x)$.
void Backward(const arma::Mat< eT > &&input, arma::Mat< eT > &&gy, arma::Mat< eT > &&g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: base_layer.hpp:87