12 #ifndef MLPACK_CORE_OPTIMIZERS_SARAH_SARAH_HPP 13 #define MLPACK_CORE_OPTIMIZERS_SARAH_SARAH_HPP 21 namespace optimization {
65 template<
typename UpdatePolicyType = SARAHUpdate>
91 const size_t batchSize = 32,
92 const size_t maxIterations = 1000,
93 const size_t innerIterations = 0,
94 const double tolerance = 1e-5,
95 const bool shuffle =
true,
96 const UpdatePolicyType& updatePolicy = UpdatePolicyType());
108 template<
typename DecomposableFunctionType>
109 double Optimize(DecomposableFunctionType&
function, arma::mat& iterate);
154 size_t maxIterations;
157 size_t innerIterations;
167 UpdatePolicyType updatePolicy;
186 #include "sarah_impl.hpp" size_t MaxIterations() const
Get the maximum number of iterations (0 indicates no limit).
double & Tolerance()
Modify the tolerance for termination.
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t & InnerIterations()
Modify the number of inner iterations (0 indicates the default, n / b, where b is the batch size).
size_t & MaxIterations()
Modify the maximum number of iterations (0 indicates no limit).
bool & Shuffle()
Modify whether or not the individual functions are shuffled.
bool Shuffle() const
Get whether or not the individual functions are shuffled.
UpdatePolicyType & UpdatePolicy()
Modify the update policy.
StochAstic Recursive gRadient algoritHm (SARAH).
size_t InnerIterations() const
Get the number of inner iterations (0 indicates the default, n / b, where b is the batch size).
double Optimize(DecomposableFunctionType &function, arma::mat &iterate)
Optimize the given function using SARAH.
double Tolerance() const
Get the tolerance for termination.
SARAHType(const double stepSize=0.01, const size_t batchSize=32, const size_t maxIterations=1000, const size_t innerIterations=0, const double tolerance=1e-5, const bool shuffle=true, const UpdatePolicyType &updatePolicy=UpdatePolicyType())
Construct the SARAH optimizer with the given parameters.
const UpdatePolicyType & UpdatePolicy() const
Get the update policy.
size_t & BatchSize()
Modify the batch size.
double & StepSize()
Modify the step size.
double StepSize() const
Get the step size.
size_t BatchSize() const
Get the batch size.