kl_divergence.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_KL_DIVERGENCE_HPP
14 #define MLPACK_METHODS_ANN_LOSS_FUNCTION_KL_DIVERGENCE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace ann {
20 
41 template <
42  typename InputDataType = arma::mat,
43  typename OutputDataType = arma::mat
44 >
46 {
47  public:
54  KLDivergence(const bool takeMean = false);
55 
62  template<typename InputType, typename TargetType>
63  typename InputType::elem_type Forward(const InputType& input,
64  const TargetType& target);
65 
73  template<typename InputType, typename TargetType, typename OutputType>
74  void Backward(const InputType& input,
75  const TargetType& target,
76  OutputType& output);
77 
79  OutputDataType& OutputParameter() const { return outputParameter; }
81  OutputDataType& OutputParameter() { return outputParameter; }
82 
84  bool TakeMean() const { return takeMean; }
86  bool& TakeMean() { return takeMean; }
87 
91  template<typename Archive>
92  void serialize(Archive& ar, const unsigned int /* version */);
93 
94  private:
96  OutputDataType outputParameter;
97 
99  bool takeMean;
100 }; // class KLDivergence
101 
102 } // namespace ann
103 } // namespace mlpack
104 
105 // include implementation
106 #include "kl_divergence_impl.hpp"
107 
108 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
bool TakeMean() const
Get the value of takeMean.
OutputDataType & OutputParameter()
Modify the output parameter.
InputType::elem_type Forward(const InputType &input, const TargetType &target)
Computes the Kullback–Leibler divergence error function.
bool & TakeMean()
Modify the value of takeMean.
KLDivergence(const bool takeMean=false)
Create the Kullback–Leibler divergence object with the specified parameters.
void serialize(Archive &ar, const unsigned int)
Serialize the loss function.
void Backward(const InputType &input, const TargetType &target, OutputType &output)
Ordinary feed backward pass of a neural network.
The Kullback–Leibler divergence is often used for continuous distributions (direct regression).
OutputDataType & OutputParameter() const
Get the output parameter.