one_step_q_learning_worker.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP
14 #define MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP
15 
17 
18 namespace mlpack {
19 namespace rl {
20 
29 template <
30  typename EnvironmentType,
31  typename NetworkType,
32  typename UpdaterType,
33  typename PolicyType
34 >
35 class OneStepQLearningWorker
36 {
37  public:
  //! Convenient aliases for the environment's state and action types.
  using StateType = typename EnvironmentType::State;
  using ActionType = typename EnvironmentType::Action;
  //! One transition: (state, action, reward, next state).
  using TransitionType = std::tuple<StateType, ActionType, double, StateType>;
41 
52  const UpdaterType& updater,
53  const EnvironmentType& environment,
54  const TrainingConfig& config,
55  bool deterministic):
56  updater(updater),
57  environment(environment),
58  config(config),
59  deterministic(deterministic),
60  pending(config.UpdateInterval())
61  { Reset(); }
62 
67  void Initialize(NetworkType& learningNetwork)
68  {
69  updater.Initialize(learningNetwork.Parameters().n_rows,
70  learningNetwork.Parameters().n_cols);
71  // Build local network.
72  network = learningNetwork;
73  }
74 
86  bool Step(NetworkType& learningNetwork,
87  NetworkType& targetNetwork,
88  size_t& totalSteps,
89  PolicyType& policy,
90  double& totalReward)
91  {
92  // Interact with the environment.
93  arma::colvec actionValue;
94  network.Predict(state.Encode(), actionValue);
95  ActionType action = policy.Sample(actionValue, deterministic);
96  StateType nextState;
97  double reward = environment.Sample(state, action, nextState);
98  bool terminal = environment.IsTerminal(nextState);
99 
100  episodeReturn += reward;
101  steps++;
102 
103  terminal = terminal || steps >= config.StepLimit();
104  if (deterministic)
105  {
106  if (terminal)
107  {
108  totalReward = episodeReturn;
109  Reset();
110  // Sync with latest learning network.
111  network = learningNetwork;
112  return true;
113  }
114  state = nextState;
115  return false;
116  }
117 
118  #pragma omp atomic
119  totalSteps++;
120 
121  pending[pendingIndex] = std::make_tuple(state, action, reward, nextState);
122  pendingIndex++;
123 
124  if (terminal || pendingIndex >= config.UpdateInterval())
125  {
126  // Initialize the gradient storage.
127  arma::mat totalGradients(learningNetwork.Parameters().n_rows,
128  learningNetwork.Parameters().n_cols, arma::fill::zeros);
129  for (size_t i = 0; i < pending.size(); ++i)
130  {
131  TransitionType &transition = pending[i];
132 
133  // Compute the target state-action value.
134  arma::colvec actionValue;
135  #pragma omp critical
136  {
137  targetNetwork.Predict(
138  std::get<3>(transition).Encode(), actionValue);
139  };
140  double targetActionValue = actionValue.max();
141  if (terminal && i == pending.size() - 1)
142  targetActionValue = 0;
143  targetActionValue = std::get<2>(transition) +
144  config.Discount() * targetActionValue;
145 
146  // Compute the training target for current state.
147  network.Forward(std::get<0>(transition).Encode(), actionValue);
148  actionValue[std::get<1>(transition)] = targetActionValue;
149 
150  // Compute gradient.
151  arma::mat gradients;
152  network.Backward(actionValue, gradients);
153 
154  // Accumulate gradients.
155  totalGradients += gradients;
156  }
157 
158  // Clamp the accumulated gradients.
159  totalGradients.transform(
160  [&](double gradient)
161  { return std::min(std::max(gradient, -config.GradientLimit()),
162  config.GradientLimit()); });
163 
164  // Perform async update of the global network.
165  updater.Update(learningNetwork.Parameters(),
166  config.StepSize(), totalGradients);
167 
168  // Sync the local network with the global network.
169  network = learningNetwork;
170 
171  pendingIndex = 0;
172  }
173 
174  // Update global target network.
175  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
176  {
177  #pragma omp critical
178  { targetNetwork = learningNetwork; }
179  }
180 
181  policy.Anneal();
182 
183  if (terminal)
184  {
185  totalReward = episodeReturn;
186  Reset();
187  return true;
188  }
189  state = nextState;
190  return false;
191  }
192 
193  private:
197  void Reset()
198  {
199  steps = 0;
200  episodeReturn = 0;
201  pendingIndex = 0;
202  state = environment.InitialSample();
203  }
204 
  //! Locally-stored copy of the optimizer.
  UpdaterType updater;

  //! Locally-stored copy of the task environment.
  EnvironmentType environment;

  //! Locally-stored hyper-parameter configuration.
  TrainingConfig config;

  //! Whether this worker runs in deterministic (evaluation) mode.
  bool deterministic;

  //! Steps taken so far in the current episode.
  size_t steps;

  //! Accumulated reward of the current episode.
  double episodeReturn;

  //! Buffer of transitions, sized to config.UpdateInterval().
  std::vector<TransitionType> pending;

  //! Number of entries of `pending` filled since the last update.
  size_t pendingIndex;

  //! Worker-local copy of the learning network.
  NetworkType network;

  //! Current state of the environment.
  StateType state;
234 };
235 
236 } // namespace rl
237 } // namespace mlpack
238 
239 #endif
void Initialize(NetworkType &learningNetwork)
Initialize the worker.
.hpp
Definition: add_to_po.hpp:21
std::tuple< StateType, ActionType, double, StateType > TransitionType
size_t StepLimit() const
Get the maximum steps of each episode.
size_t TargetNetworkSyncInterval() const
Get the interval for syncing target network.
OneStepQLearningWorker(const UpdaterType &updater, const EnvironmentType &environment, const TrainingConfig &config, bool deterministic)
Construct one step Q-Learning worker with the given parameters and environment.
size_t UpdateInterval() const
Get the update interval.
double Discount() const
Get the discount rate for future reward.
typename EnvironmentType::Action ActionType
bool Step(NetworkType &learningNetwork, NetworkType &targetNetwork, size_t &totalSteps, PolicyType &policy, double &totalReward)
The agent will execute one step.
double GradientLimit() const
Get the limit of update gradient.
double StepSize() const
Get the step size of the optimizer.
typename EnvironmentType::State StateType