poisson_nll_loss.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTIONS_POISSON_NLL_LOSS_HPP
14 #define MLPACK_METHODS_ANN_LOSS_FUNCTIONS_POISSON_NLL_LOSS_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace ann {
20 
32 template <
33  typename InputDataType = arma::mat,
34  typename OutputDataType = arma::mat
35 >
37 {
38  public:
50  PoissonNLLLoss(const bool logInput = true,
51  const bool full = false,
52  const typename InputDataType::elem_type eps = 1e-08,
53  const bool mean = true);
54 
62  template<typename InputType, typename TargetType>
63  typename InputDataType::elem_type Forward(const InputType& input,
64  const TargetType& target);
65 
77  template<typename InputType, typename TargetType, typename OutputType>
78  void Backward(const InputType& input,
79  const TargetType& target,
80  OutputType& output);
81 
83  InputDataType& InputParameter() const { return inputParameter; }
85  InputDataType& InputParameter() { return inputParameter; }
86 
88  OutputDataType& OutputParameter() const { return outputParameter; }
90  OutputDataType& OutputParameter() { return outputParameter; }
91 
94  bool LogInput() const { return logInput; }
97  bool& LogInput() { return logInput; }
98 
101  bool Full() const { return full; }
104  bool& Full() { return full; }
105 
108  typename InputDataType::elem_type Eps() const { return eps; }
111  typename InputDataType::elem_type& Eps() { return eps; }
112 
115  bool Mean() const { return mean; }
118  bool& Mean() { return mean; }
119 
123  template<typename Archive>
124  void serialize(Archive& ar, const unsigned int /* version */);
125 
126  private:
128  template<typename eT>
129  void CheckProbs(const arma::Mat<eT>& probs)
130  {
131  for (size_t i = 0; i < probs.size(); ++i)
132  {
133  if (probs[i] > 1.0 || probs[i] < 0.0)
134  Log::Fatal << "Probabilities cannot be greater than 1 "
135  << "or smaller than 0." << std::endl;
136  }
137  }
138 
140  InputDataType inputParameter;
141 
143  OutputDataType outputParameter;
144 
146  bool logInput;
147 
149  // approximation term.
150  bool full;
151 
153  typename InputDataType::elem_type eps;
154 
156  bool mean;
157 }; // class PoissonNLLLoss
158 
159 } // namespace ann
160 } // namespace mlpack
161 
162 // Include implementation.
163 #include "poisson_nll_loss_impl.hpp"
164 
165 #endif
OutputDataType & OutputParameter()
Modify the output parameter.
bool & Full()
Modify the value of full.
Implementation of the Poisson negative log likelihood loss.
InputDataType::elem_type Eps() const
Get the value of eps.
Linear algebra utility functions, generally performed on matrices or vectors.
bool & Mean()
Modify the value of mean.
InputDataType & InputParameter() const
Get the input parameter.
The core includes that mlpack expects; standard C++ includes and Armadillo.
bool Full() const
Get the value of full.
OutputDataType & OutputParameter() const
Get the output parameter.
bool Mean() const
Get the value of mean.
bool LogInput() const
Get the value of logInput.
void Backward(const InputType &input, const TargetType &target, OutputType &output)
Ordinary feed backward pass of a neural network.
InputDataType::elem_type & Eps()
Modify the value of eps.
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
InputDataType::elem_type Forward(const InputType &input, const TargetType &target)
Computes the Poisson negative log likelihood loss.
PoissonNLLLoss(const bool logInput=true, const bool full=false, const typename InputDataType::elem_type eps=1e-08, const bool mean=true)
Create the PoissonNLLLoss object.
void serialize(Archive &ar, const unsigned int)
Serialize the loss function.
bool & LogInput()
Modify the value of logInput.
InputDataType & InputParameter()
Modify the input parameter.