kl_divergence.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_KL_DIVERGENCE_HPP
14 #define MLPACK_METHODS_ANN_LOSS_FUNCTION_KL_DIVERGENCE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace ann {
20 
41 template <
42  typename InputDataType = arma::mat,
43  typename OutputDataType = arma::mat
44 >
46 {
47  public:
54  KLDivergence(const bool takeMean = false);
55 
62  template<typename InputType, typename TargetType>
63  double Forward(const InputType&& input, const TargetType&& target);
64 
72  template<typename InputType, typename TargetType, typename OutputType>
73  void Backward(const InputType&& input,
74  const TargetType&& target,
75  OutputType&& output);
76 
78  InputDataType& InputParameter() const { return inputParameter; }
80  InputDataType& InputParameter() { return inputParameter; }
81 
83  OutputDataType& OutputParameter() const { return outputParameter; }
85  OutputDataType& OutputParameter() { return outputParameter; }
86 
88  OutputDataType& Delta() const { return delta; }
90  OutputDataType& Delta() { return delta; }
91 
93  bool TakeMean() const { return takeMean; }
95  bool& TakeMean() { return takeMean; }
96 
100  template<typename Archive>
101  void serialize(Archive& ar, const unsigned int /* version */);
102 
103  private:
105  OutputDataType delta;
106 
108  InputDataType inputParameter;
109 
111  OutputDataType outputParameter;
112 
114  bool takeMean;
115 }; // class KLDivergence
116 
117 } // namespace ann
118 } // namespace mlpack
119 
120 // include implementation
121 #include "kl_divergence_impl.hpp"
122 
123 #endif
void Backward(const InputType &&input, const TargetType &&target, OutputType &&output)
Ordinary feed-backward pass of a neural network.
double Forward(const InputType &&input, const TargetType &&target)
Computes the Kullback–Leibler divergence error function.
.hpp
Definition: add_to_po.hpp:21
OutputDataType & Delta()
Modify the delta.
The core includes that mlpack expects; standard C++ includes and Armadillo.
bool TakeMean() const
Get the value of takeMean.
OutputDataType & OutputParameter()
Modify the output parameter.
bool & TakeMean()
Modify the value of takeMean.
InputDataType & InputParameter() const
Get the input parameter.
KLDivergence(const bool takeMean=false)
Create the Kullback–Leibler Divergence object with the specified parameters.
OutputDataType & Delta() const
Get the delta.
void serialize(Archive &ar, const unsigned int)
Serialize the loss function.
InputDataType & InputParameter()
Modify the input parameter.
The Kullback–Leibler divergence is often used for continuous distributions (direct regression)...
OutputDataType & OutputParameter() const
Get the output parameter.