sarah_plus_update.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_CORE_OPTIMIZERS_SARAH_SARAH_PLUS_UPDATE_HPP
14 #define MLPACK_CORE_OPTIMIZERS_SARAH_SARAH_PLUS_UPDATE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace optimization {
20 
25 {
26  public:
27  /*
28  * Construct the SARAH+ update policy.
29  *
30  * @param gamma Adaptive parameter for the inner loop.
31  */
32  SARAHPlusUpdate(const double gamma = 0.125) : gamma(gamma)
33  {
34  /* Nothing to do here. */
35  }
36 
49  bool Update(arma::mat& iterate,
50  arma::mat& v,
51  const arma::mat& gradient,
52  const arma::mat& gradient0,
53  const size_t batchSize,
54  const double stepSize,
55  const double vNorm)
56  {
57  v += (gradient - gradient0) / (double) batchSize;
58  iterate -= stepSize * v;
59 
60  if (arma::norm(v) <= gamma * vNorm)
61  return true;
62 
63  return false;
64  }
65 
66  private:
68  double gamma;
69 };
70 
71 } // namespace optimization
72 } // namespace mlpack
73 
74 #endif
bool Update(arma::mat &iterate, arma::mat &v, const arma::mat &gradient, const arma::mat &gradient0, const size_t batchSize, const double stepSize, const double vNorm)
Update step for SARAH+.
.hpp
Definition: add_to_po.hpp:21
SARAH+ provides an automatic and adaptive choice of the inner loop size.
The core includes that mlpack expects; standard C++ includes and Armadillo.
SARAHPlusUpdate(const double gamma=0.125)