13 #ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP 14 #define MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP 30 typename EnvironmentType,
35 class OneStepSarsaWorker
53 const UpdaterType& updater,
54 const EnvironmentType& environment,
58 environment(environment),
60 deterministic(deterministic),
61 pending(config.UpdateInterval())
70 updater.Initialize(learningNetwork.Parameters().n_rows,
71 learningNetwork.Parameters().n_cols);
73 network = learningNetwork;
87 bool Step(NetworkType& learningNetwork,
88 NetworkType& targetNetwork,
94 if (action == ActionType::size)
97 arma::colvec actionValue;
98 network.Predict(state.Encode(), actionValue);
99 action = policy.Sample(actionValue, deterministic);
102 double reward = environment.Sample(state, action, nextState);
103 bool terminal = environment.IsTerminal(nextState);
104 arma::colvec actionValue;
105 network.Predict(nextState.Encode(), actionValue);
106 ActionType nextAction = policy.Sample(actionValue, deterministic);
108 episodeReturn += reward;
111 terminal = terminal || steps >= config.
StepLimit();
116 totalReward = episodeReturn;
119 network = learningNetwork;
130 pending[pendingIndex++] =
131 std::make_tuple(state, action, reward, nextState, nextAction);
136 arma::mat totalGradients(learningNetwork.Parameters().n_rows,
137 learningNetwork.Parameters().n_cols, arma::fill::zeros);
138 for (
size_t i = 0; i < pending.size(); ++i)
143 arma::colvec actionValue;
146 targetNetwork.Predict(
147 std::get<3>(transition).Encode(), actionValue);
149 double targetActionValue = 0;
150 if (!(terminal && i == pending.size() - 1))
151 targetActionValue = actionValue[std::get<4>(transition)];
152 targetActionValue = std::get<2>(transition) +
153 config.
Discount() * targetActionValue;
156 network.Forward(std::get<0>(transition).Encode(), actionValue);
157 actionValue[std::get<1>(transition)] = targetActionValue;
161 network.Backward(actionValue, gradients);
164 totalGradients += gradients;
168 totalGradients.transform(
170 {
return std::min(std::max(gradient, -config.
GradientLimit()),
174 updater.Update(learningNetwork.Parameters(),
178 network = learningNetwork;
187 { targetNetwork = learningNetwork; }
194 totalReward = episodeReturn;
212 state = environment.InitialSample();
213 action = ActionType::size;
220 EnvironmentType environment;
232 double episodeReturn;
235 std::vector<TransitionType> pending;
bool Step(NetworkType &learningNetwork, NetworkType &targetNetwork, size_t &totalSteps, PolicyType &policy, double &totalReward)
The agent will execute one step.
void Initialize(NetworkType &learningNetwork)
Initialize the worker.
size_t StepLimit() const
Get the maximum number of steps allowed per episode.
std::tuple< StateType, ActionType, double, StateType, ActionType > TransitionType
size_t TargetNetworkSyncInterval() const
Get the interval for syncing the target network with the learning network.
size_t UpdateInterval() const
Get the update interval.
double Discount() const
Get the discount rate applied to future rewards.
typename EnvironmentType::State StateType
double GradientLimit() const
Get the limit used to clip the update gradient.
typename EnvironmentType::Action ActionType
double StepSize() const
Get the step size of the optimizer.
OneStepSarsaWorker(const UpdaterType &updater, const EnvironmentType &environment, const TrainingConfig &config, bool deterministic)
Construct a one-step SARSA worker with the given parameters and environment.