momentum_update.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_OPTIMIZERS_SGD_MOMENTUM_UPDATE_HPP
13 #define MLPACK_CORE_OPTIMIZERS_SGD_MOMENTUM_UPDATE_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace optimization {
19 
64 class MomentumUpdate
65 {
66  public:
72  MomentumUpdate(const double momentum = 0.5) : momentum(momentum)
73  { /* Do nothing. */ };
74 
84  void Initialize(const size_t rows, const size_t cols)
85  {
86  // Initialize am empty velocity matrix.
87  velocity = arma::zeros<arma::mat>(rows, cols);
88  }
89 
99  void Update(arma::mat& iterate,
100  const double stepSize,
101  const arma::mat& gradient)
102  {
103  velocity = momentum * velocity - stepSize * gradient;
104  iterate += velocity;
105  }
106 
107  private:
108  // The momentum hyperparamter
109  double momentum;
110  // The velocity matrix.
111  arma::mat velocity;
112 };
113 
114 } // namespace optimization
115 } // namespace mlpack
116 
117 #endif
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.