snapshot_ensembles.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_CORE_OPTIMIZERS_SGDR_SNAPSHOT_ENSEMBLES_HPP
15 #define MLPACK_CORE_OPTIMIZERS_SGDR_SNAPSHOT_ENSEMBLES_HPP
16 
17 namespace mlpack {
18 namespace optimization {
19 
41 {
42  public:
58  SnapshotEnsembles(const size_t epochRestart,
59  const double multFactor,
60  const double stepSize,
61  const size_t maxIterations,
62  const size_t snapshots) :
63  epochRestart(epochRestart),
64  multFactor(multFactor),
65  constStepSize(stepSize),
66  nextRestart(epochRestart),
67  batchRestart(0),
68  epoch(0)
69  {
70  snapshotEpochs = 0;
71  for (size_t i = 0, er = epochRestart, nr = nextRestart;
72  i < maxIterations; ++i)
73  {
74  if (i > nr)
75  {
76  er *= multFactor;
77  nr += er;
78  snapshotEpochs++;
79  }
80  }
81 
82  snapshotEpochs = epochRestart * std::pow(multFactor,
83  snapshotEpochs - snapshots + 1);
84  }
85 
93  void Update(arma::mat& iterate,
94  double& stepSize,
95  const arma::mat& /* gradient */)
96  {
97  // Time to adjust the step size.
98  if (epoch >= epochRestart)
99  {
100  // n_t = n_min^i + 0.5(n_max^i - n_min^i)(1 + cos(T_cur/T_i * pi)).
101  stepSize = 0.5 * constStepSize * (1 + cos((batchRestart / epochBatches)
102  * M_PI));
103 
104  // Keep track of the number of batches since the last restart.
105  batchRestart++;
106  }
107 
108  // Time to restart.
109  if (epoch > nextRestart)
110  {
111  batchRestart = 0;
112 
113  // Adjust the period of restarts.
114  epochRestart *= multFactor;
115 
116  // Create a new snapshot.
117  if (epochRestart >= snapshotEpochs)
118  {
119  snapshots.push_back(iterate);
120  }
121 
122  // Update the time for the next restart.
123  nextRestart += epochRestart;
124  }
125 
126  epoch++;
127  }
128 
130  double StepSize() const { return constStepSize; }
132  double& StepSize() { return constStepSize; }
133 
135  double EpochBatches() const { return epochBatches; }
137  double& EpochBatches() { return epochBatches; }
138 
140  std::vector<arma::mat> Snapshots() const { return snapshots; }
142  std::vector<arma::mat>& Snapshots() { return snapshots; }
143 
144  private:
146  size_t epochRestart;
147 
149  double multFactor;
150 
152  double constStepSize;
153 
155  size_t nextRestart;
156 
158  size_t batchRestart;
159 
161  double epochBatches;
162 
164  size_t epoch;
165 
167  size_t snapshotEpochs;
168 
170  std::vector<arma::mat> snapshots;
171 };
172 
173 } // namespace optimization
174 } // namespace mlpack
175 
176 #endif // MLPACK_CORE_OPTIMIZERS_SGDR_SNAPSHOT_ENSEMBLES_HPP
std::vector< arma::mat > & Snapshots()
Modify the snapshots.
.hpp
Definition: add_to_po.hpp:21
Simulate a new warm-started run/restart once a number of epochs are performed.
double & StepSize()
Modify the step size.
#define M_PI
Definition: prereqs.hpp:39
std::vector< arma::mat > Snapshots() const
Get the snapshots.
double EpochBatches() const
Get the restart fraction.
double StepSize() const
Get the step size.
SnapshotEnsembles(const size_t epochRestart, const double multFactor, const double stepSize, const size_t maxIterations, const size_t snapshots)
Construct the SnapshotEnsembles technique, a restart method where the step size decays after each batch a...
double & EpochBatches()
Modify the restart fraction.
void Update(arma::mat &iterate, double &stepSize, const arma::mat &)
This function is called in each iteration after the policy update.