dice_loss.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTIONS_DICE_LOSS_HPP
13 #define MLPACK_METHODS_ANN_LOSS_FUNCTIONS_DICE_LOSS_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace ann {
19 
44 template <
45  typename InputDataType = arma::mat,
46  typename OutputDataType = arma::mat
47 >
48 class DiceLoss
49 {
50  public:
56  DiceLoss(const double smooth = 1);
57 
64  template<typename InputType, typename TargetType>
65  double Forward(const InputType&& input, const TargetType&& target);
66 
74  template<typename InputType, typename TargetType, typename OutputType>
75  void Backward(const InputType&& input,
76  const TargetType&& target,
77  OutputType&& output);
78 
80  OutputDataType& OutputParameter() const { return outputParameter; }
82  OutputDataType& OutputParameter() { return outputParameter; }
83 
85  double Smooth() const { return smooth; }
87  double& Smooth() { return smooth; }
88 
92  template<typename Archive>
93  void serialize(Archive& ar, const unsigned int /* version */);
94 
95  private:
97  OutputDataType outputParameter;
98 
100  double smooth;
101 }; // class DiceLoss
102 
103 } // namespace ann
104 } // namespace mlpack
105 
106 // Include implementation.
107 #include "dice_loss_impl.hpp"
108 
109 #endif
The dice loss performance function measures the network's performance according to the dice coefficient.
Definition: dice_loss.hpp:48
strip_type.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
double & Smooth()
Modify the smoothing parameter.
Definition: dice_loss.hpp:87
void Backward(const InputType &&input, const TargetType &&target, OutputType &&output)
Ordinary feed backward pass of a neural network.
OutputDataType & OutputParameter() const
Get the output parameter.
Definition: dice_loss.hpp:80
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
double Smooth() const
Get the smoothing parameter.
Definition: dice_loss.hpp:85
DiceLoss(const double smooth=1)
Create the DiceLoss object.
double Forward(const InputType &&input, const TargetType &&target)
Computes the dice loss function.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: dice_loss.hpp:82