13 #ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP 14 #define MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP 30 typename EnvironmentType,
35 class OneStepQLearningWorker
52 const UpdaterType& updater,
53 const EnvironmentType& environment,
57 environment(environment),
59 deterministic(deterministic),
60 pending(config.UpdateInterval())
69 updater.Initialize(learningNetwork.Parameters().n_rows,
70 learningNetwork.Parameters().n_cols);
72 network = learningNetwork;
86 bool Step(NetworkType& learningNetwork,
87 NetworkType& targetNetwork,
93 arma::colvec actionValue;
94 network.Predict(state.Encode(), actionValue);
95 ActionType action = policy.Sample(actionValue, deterministic);
97 double reward = environment.Sample(state, action, nextState);
98 bool terminal = environment.IsTerminal(nextState);
100 episodeReturn += reward;
103 terminal = terminal || steps >= config.
StepLimit();
108 totalReward = episodeReturn;
111 network = learningNetwork;
121 pending[pendingIndex] = std::make_tuple(state, action, reward, nextState);
127 arma::mat totalGradients(learningNetwork.Parameters().n_rows,
128 learningNetwork.Parameters().n_cols, arma::fill::zeros);
129 for (
size_t i = 0; i < pending.size(); ++i)
134 arma::colvec actionValue;
137 targetNetwork.Predict(
138 std::get<3>(transition).Encode(), actionValue);
140 double targetActionValue = actionValue.max();
141 if (terminal && i == pending.size() - 1)
142 targetActionValue = 0;
143 targetActionValue = std::get<2>(transition) +
144 config.
Discount() * targetActionValue;
147 network.Forward(std::get<0>(transition).Encode(), actionValue);
148 actionValue[std::get<1>(transition)] = targetActionValue;
152 network.Backward(actionValue, gradients);
155 totalGradients += gradients;
159 totalGradients.transform(
161 {
return std::min(std::max(gradient, -config.
GradientLimit()),
165 updater.Update(learningNetwork.Parameters(),
169 network = learningNetwork;
178 { targetNetwork = learningNetwork; }
185 totalReward = episodeReturn;
202 state = environment.InitialSample();
209 EnvironmentType environment;
221 double episodeReturn;
224 std::vector<TransitionType> pending;
void Initialize(NetworkType &learningNetwork)
Initialize the worker.
std::tuple< StateType, ActionType, double, StateType > TransitionType
size_t StepLimit() const
Get the maximum number of steps allowed in each episode.
size_t TargetNetworkSyncInterval() const
Get the interval for syncing the target network with the learning network.
OneStepQLearningWorker(const UpdaterType &updater, const EnvironmentType &environment, const TrainingConfig &config, bool deterministic)
Construct a one-step Q-Learning worker with the given parameters and environment.
size_t UpdateInterval() const
Get the update interval.
double Discount() const
Get the discount rate for future reward.
typename EnvironmentType::Action ActionType
bool Step(NetworkType &learningNetwork, NetworkType &targetNetwork, size_t &totalSteps, PolicyType &policy, double &totalReward)
The agent will execute one step.
double GradientLimit() const
Get the clipping limit applied to the update gradient.
double StepSize() const
Get the step size of the optimizer.
typename EnvironmentType::State StateType