backtracking_line_search.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_CORE_OPTIMIZERS_BIGBATCH_SGD_BACKTRACKING_LINE_SEARCH_HPP
16 #define MLPACK_CORE_OPTIMIZERS_BIGBATCH_SGD_BACKTRACKING_LINE_SEARCH_HPP
17 
18 namespace mlpack {
19 namespace optimization {
20 
41 {
42  public:
51  BacktrackingLineSearch(const double searchParameter = 0.1) :
52  searchParameter(searchParameter)
53  { /* Nothing to do here. */ }
54 
70  template<typename DecomposableFunctionType>
71  void Update(DecomposableFunctionType& function,
72  double& stepSize,
73  arma::mat& iterate,
74  const arma::mat& gradient,
75  const double gradientNorm,
76  const double /* sampleVariance */,
77  const size_t offset,
78  const size_t /* batchSize */,
79  const size_t backtrackingBatchSize,
80  const bool reset)
81  {
82  if (reset)
83  stepSize *= 2;
84 
85  double overallObjective = function.Evaluate(iterate, offset,
86  backtrackingBatchSize);
87 
88  arma::mat iterateUpdate = iterate - (stepSize * gradient);
89  double overallObjectiveUpdate = function.Evaluate(iterateUpdate,
90  offset, backtrackingBatchSize);
91 
92  while (overallObjectiveUpdate >
93  (overallObjective + searchParameter * stepSize * gradientNorm))
94  {
95  stepSize /= 2;
96 
97  iterateUpdate = iterate - (stepSize * gradient);
98  overallObjectiveUpdate = function.Evaluate(iterateUpdate,
99  offset, backtrackingBatchSize);
100  }
101  }
102 
103  private:
105  double searchParameter;
106 };
107 
108 } // namespace optimization
109 } // namespace mlpack
110 
111 #endif // MLPACK_CORE_OPTIMIZERS_BIGBATCH_SGD_BACKTRACKING_LINE_SEARCH_HPP
.hpp
Definition: add_to_po.hpp:21
void Update(DecomposableFunctionType &function, double &stepSize, arma::mat &iterate, const arma::mat &gradient, const double gradientNorm, const double, const size_t offset, const size_t, const size_t backtrackingBatchSize, const bool reset)
This function is called in each iteration.
Definition of the backtracking line search algorithm based on the Armijo–Goldstein condition to determine the maximum amount to move along the given search direction.
BacktrackingLineSearch(const double searchParameter=0.1)
Construct the BacktrackingLineSearch object with the given search parameter.