soft_margin_loss.hpp
Go to the documentation of this file.
1 
16 #ifndef MLPACK_ANN_LOSS_FUNCTION_SOFT_MARGIN_LOSS_HPP
17 #define MLPACK_ANN_LOSS_FUNCTION_SOFT_MARGIN_LOSS_HPP
18 
19 #include <mlpack/prereqs.hpp>
20 
21 namespace mlpack {
22 namespace ann {
23 
30 template <
31  typename InputDataType = arma::mat,
32  typename OutputDataType = arma::mat
33 >
35 {
36  public:
46  SoftMarginLoss(const bool reduction = true);
47 
54  template<typename InputType, typename TargetType>
55  typename InputType::elem_type Forward(const InputType& input,
56  const TargetType& target);
57 
65  template<typename InputType, typename TargetType, typename OutputType>
66  void Backward(const InputType& input,
67  const TargetType& target,
68  OutputType& output);
69 
71  OutputDataType& OutputParameter() const { return outputParameter; }
73  OutputDataType& OutputParameter() { return outputParameter; }
74 
76  bool Reduction() const { return reduction; }
78  bool& Reduction() { return reduction; }
79 
83  template<typename Archive>
84  void serialize(Archive& ar, const unsigned int /* version */);
85 
86  private:
88  OutputDataType outputParameter;
89 
91  bool reduction;
92 }; // class SoftMarginLoss
93 
94 } // namespace ann
95 } // namespace mlpack
96 
97 // include implementation.
98 #include "soft_margin_loss_impl.hpp"
99 
100 #endif
InputType::elem_type Forward(const InputType &input, const TargetType &target)
Computes the Soft Margin Loss function.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType & OutputParameter() const
Get the output parameter.
bool Reduction() const
Get the type of reduction used.
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
bool & Reduction()
Modify the type of reduction used.
SoftMarginLoss(const bool reduction=true)
Create the SoftMarginLoss object.
void Backward(const InputType &input, const TargetType &target, OutputType &output)
Ordinary feed backward pass of a neural network.
OutputDataType & OutputParameter()
Modify the output parameter.