gradient_clipping.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_OPTIMIZERS_SGD_GRADIENT_CLIPPING_HPP
13 #define MLPACK_CORE_OPTIMIZERS_SGD_GRADIENT_CLIPPING_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace optimization {
19 
20 /**
21  * Wrapper around an update policy (e.g., VanillaUpdate) that feeds it a
22  * clipped gradient instead of the raw one: every gradient element is clamped
23  * to the range [minGradient, maxGradient] before the wrapped policy's
24  * Update() is invoked.
25  *
26  * @tparam UpdatePolicyType Type of the update policy being wrapped.
27  */
28 template<typename UpdatePolicyType>
29 class GradientClipping
30 {
31  public:
32   /**
33    * Constructor for creating a GradientClipping instance.
34    *
35    * @param minGradient Lower bound applied to each gradient element.
36    * @param maxGradient Upper bound applied to each gradient element.
37    * @param updatePolicy Update policy to wrap; it is copied into this object,
38    *     so later changes to the caller's instance are not seen here.
39    */
40   GradientClipping(const double minGradient,
41                    const double maxGradient,
42                    UpdatePolicyType& updatePolicy) :
43       minGradient(minGradient),
44       maxGradient(maxGradient),
45       updatePolicy(updatePolicy)
46   {
47     // Nothing to do here
48   }
49 
50   /**
51    * Called by the SGD optimizer before the start of the iteration update
52    * process. The clipping wrapper keeps no state of its own to set up, so the
53    * call is forwarded unchanged to the wrapped update policy.
54    *
55    * @param rows Row count passed through to the wrapped policy.
56    * @param cols Column count passed through to the wrapped policy.
57    */
58   void Initialize(const size_t rows, const size_t cols)
59   {
60     updatePolicy.Initialize(rows, cols);
61   }
62 
63   /**
64    * Update step: clamp every element of the gradient to
65    * [minGradient, maxGradient], then let the wrapped policy do the update.
66    *
67    * @param iterate Matrix handed to the wrapped policy for updating.
68    * @param stepSize Step size handed to the wrapped policy.
69    * @param gradient Unclipped gradient; left unmodified.
70    */
71   void Update(arma::mat& iterate,
72               const double stepSize,
73               const arma::mat& gradient)
74   {
75     // First, clip the gradient.
76     arma::mat clippedGradient = arma::clamp(gradient, minGradient, maxGradient);
77     // And only then do the update.
78     updatePolicy.Update(iterate, stepSize, clippedGradient);
79   }
80 
81   //! Get the update policy (read-only: the member is const in this overload).
82   const UpdatePolicyType& UpdatePolicy() const { return updatePolicy; }
83   //! Modify the update policy.
84   UpdatePolicyType& UpdatePolicy() { return updatePolicy; }
85 
86   //! Get the minimum gradient value.
87   double MinGradient() const { return minGradient; }
88   //! Modify the minimum gradient value.
89   double& MinGradient() { return minGradient; }
90 
91   //! Get the maximum gradient value.
92   double MaxGradient() const { return maxGradient; }
93   //! Modify the maximum gradient value.
94   double& MaxGradient() { return maxGradient; }
95 
96  private:
97   //! Lower clamp bound for gradient elements.
98   double minGradient;
99 
100   //! Upper clamp bound for gradient elements.
101   double maxGradient;
102 
103   //! Wrapped update policy (owned copy) that performs the actual update.
104   UpdatePolicyType updatePolicy;
105 };
106 
107 } // namespace optimization
108 } // namespace mlpack
109 
110 #endif
void Initialize(const size_t rows, const size_t cols)
The Initialize method is called by SGD Optimizer method before the start of the iteration update proc...
.hpp
Definition: add_to_po.hpp:21
double & MaxGradient()
Modify the maximum gradient value.
GradientClipping(const double minGradient, const double maxGradient, UpdatePolicyType &updatePolicy)
Constructor for creating a GradientClipping instance.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Interface for wrapping around update policies (e.g., VanillaUpdate) and feeding a clipped gradient to...
double MinGradient() const
Get the minimum gradient value.
double MaxGradient() const
Get the maximum gradient value.
UpdatePolicyType & UpdatePolicy()
Modify the update policy.
double & MinGradient()
Modify the minimum gradient value.
UpdatePolicyType & UpdatePolicy() const
Get the update policy.
void Update(arma::mat &iterate, const double stepSize, const arma::mat &gradient)
Update step.