snapshot_sgdr.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_CORE_OPTIMIZERS_SGDR_SNAPSHOT_SGDR_HPP
16 #define MLPACK_CORE_OPTIMIZERS_SGDR_SNAPSHOT_SGDR_HPP
17 
#include <mlpack/prereqs.hpp>

#include <mlpack/core/optimizers/sgd/sgd.hpp>
#include <mlpack/core/optimizers/sgd/update_policies/momentum_update.hpp>
#include "snapshot_ensembles.hpp"
23 
24 namespace mlpack {
25 namespace optimization {
26 
60 template<typename UpdatePolicyType = MomentumUpdate>
62 {
63  public:
66 
87  SnapshotSGDR(const size_t epochRestart = 50,
88  const double multFactor = 2.0,
89  const size_t batchSize = 1000,
90  const double stepSize = 0.01,
91  const size_t maxIterations = 100000,
92  const double tolerance = 1e-5,
93  const bool shuffle = true,
94  const size_t snapshots = 5,
95  const bool accumulate = true,
96  const UpdatePolicyType& updatePolicy = UpdatePolicyType());
97 
107  template<typename DecomposableFunctionType>
108  double Optimize(DecomposableFunctionType& function, arma::mat& iterate);
109 
111  size_t BatchSize() const { return optimizer.BatchSize(); }
113  size_t& BatchSize() { return optimizer.BatchSize(); }
114 
116  double StepSize() const { return optimizer.StepSize(); }
118  double& StepSize() { return optimizer.StepSize(); }
119 
121  size_t MaxIterations() const { return optimizer.MaxIterations(); }
123  size_t& MaxIterations() { return optimizer.MaxIterations(); }
124 
126  double Tolerance() const { return optimizer.Tolerance(); }
128  double& Tolerance() { return optimizer.Tolerance(); }
129 
131  bool Shuffle() const { return optimizer.Shuffle(); }
133  bool& Shuffle() { return optimizer.Shuffle(); }
134 
136  std::vector<arma::mat> Snapshots() const
137  {
138  return optimizer.DecayPolicy().Snapshots();
139  }
141  std::vector<arma::mat>& Snapshots()
142  {
143  return optimizer.DecayPolicy().Snapshots();
144  }
145 
147  const UpdatePolicyType& UpdatePolicy() const
148  {
149  return optimizer.UpdatePolicy();
150  }
152  UpdatePolicyType& UpdatePolicy()
153  {
154  return optimizer.UpdatePolicy();
155  }
156 
157  private:
159  size_t batchSize;
160 
162  bool accumulate;
163 
165  OptimizerType optimizer;
166 };
167 
168 } // namespace optimization
169 } // namespace mlpack
170 
171 // Include implementation.
172 #include "snapshot_sgdr_impl.hpp"
173 
174 #endif
double Optimize(DecomposableFunctionType &function, arma::mat &iterate)
Optimize the given function using SGDR.
const DecayPolicyType & DecayPolicy() const
Get the step size decay policy.
Definition: sgd.hpp:171
UpdatePolicyType & UpdatePolicy()
Modify the update policy.
double StepSize() const
Get the step size.
bool Shuffle() const
Get whether or not the individual functions are shuffled.
Definition: sgd.hpp:154
const UpdatePolicyType & UpdatePolicy() const
Get the update policy.
Definition: sgd.hpp:166
SnapshotSGDR(const size_t epochRestart=50, const double multFactor=2.0, const size_t batchSize=1000, const double stepSize=0.01, const size_t maxIterations=100000, const double tolerance=1e-5, const bool shuffle=true, const size_t snapshots=5, const bool accumulate=true, const UpdatePolicyType &updatePolicy=UpdatePolicyType())
Construct the SnapshotSGDR optimizer (SGDR with snapshot ensembles) using the given parameters...
.hpp
Definition: add_to_po.hpp:21
double Tolerance() const
Get the tolerance for termination.
size_t BatchSize() const
Get the batch size.
Definition: sgd.hpp:139
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Tolerance() const
Get the tolerance for termination.
Definition: sgd.hpp:149
bool & Shuffle()
Modify whether or not the individual functions are shuffled.
This class is based on the mini-batch Stochastic Gradient Descent class and simulates a new warm-started ...
size_t & MaxIterations()
Modify the maximum number of iterations (0 indicates no limit).
std::vector< arma::mat > Snapshots() const
Get the snapshots.
size_t MaxIterations() const
Get the maximum number of iterations (0 indicates no limit).
Definition: sgd.hpp:144
double StepSize() const
Get the step size.
Definition: sgd.hpp:134
std::vector< arma::mat > & Snapshots()
Modify the snapshots.
double & Tolerance()
Modify the tolerance for termination.
double & StepSize()
Modify the step size.
const UpdatePolicyType & UpdatePolicy() const
Get the update policy.
bool Shuffle() const
Get whether or not the individual functions are shuffled.
size_t BatchSize() const
Get the batch size.
size_t MaxIterations() const
Get the maximum number of iterations (0 indicates no limit).
size_t & BatchSize()
Modify the batch size.
std::vector< arma::mat > Snapshots() const
Get the snapshots.