one_step_q_learning_worker.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP
14 #define MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP
15 
17 
18 namespace mlpack {
19 namespace rl {
20 
29 template <
30  typename EnvironmentType,
31  typename NetworkType,
32  typename UpdaterType,
33  typename PolicyType
34 >
35 class OneStepQLearningWorker
36 {
37  public:
38  using StateType = typename EnvironmentType::State;
39  using ActionType = typename EnvironmentType::Action;
40  using TransitionType = std::tuple<StateType, ActionType, double, StateType>;
41 
52  const UpdaterType& updater,
53  const EnvironmentType& environment,
54  const TrainingConfig& config,
55  bool deterministic):
56  updater(updater),
57  #if ENS_VERSION_MAJOR >= 2
58  updatePolicy(NULL),
59  #endif
60  environment(environment),
61  config(config),
62  deterministic(deterministic),
63  pending(config.UpdateInterval())
64  { Reset(); }
65 
72  updater(other.updater),
73  #if ENS_VERSION_MAJOR >= 2
74  updatePolicy(NULL),
75  #endif
76  environment(other.environment),
77  config(other.config),
78  deterministic(other.deterministic),
79  steps(other.steps),
80  episodeReturn(other.episodeReturn),
81  pending(other.pending),
82  pendingIndex(other.pendingIndex),
83  network(other.network),
84  state(other.state)
85  {
86  #if ENS_VERSION_MAJOR >= 2
87  updatePolicy = new typename UpdaterType::template
88  Policy<arma::mat, arma::mat>(updater,
89  network.Parameters().n_rows,
90  network.Parameters().n_cols);
91  #endif
92 
93  Reset();
94  }
95 
102  updater(std::move(other.updater)),
103  #if ENS_VERSION_MAJOR >= 2
104  updatePolicy(NULL),
105  #endif
106  environment(std::move(other.environment)),
107  config(std::move(other.config)),
108  deterministic(std::move(other.deterministic)),
109  steps(std::move(other.steps)),
110  episodeReturn(std::move(other.episodeReturn)),
111  pending(std::move(other.pending)),
112  pendingIndex(std::move(other.pendingIndex)),
113  network(std::move(other.network)),
114  state(std::move(other.state))
115  {
116  #if ENS_VERSION_MAJOR >= 2
117  other.updatePolicy = NULL;
118 
119  updatePolicy = new typename UpdaterType::template
120  Policy<arma::mat, arma::mat>(updater,
121  network.Parameters().n_rows,
122  network.Parameters().n_cols);
123  #endif
124  }
125 
132  {
133  if (&other == this)
134  return *this;
135 
136  #if ENS_VERSION_MAJOR >= 2
137  delete updatePolicy;
138  #endif
139 
140  updater = other.updater;
141  environment = other.environment;
142  config = other.config;
143  deterministic = other.deterministic;
144  steps = other.steps;
145  episodeReturn = other.episodeReturn;
146  pending = other.pending;
147  pendingIndex = other.pendingIndex;
148  network = other.network;
149  state = other.state;
150 
151  #if ENS_VERSION_MAJOR >= 2
152  updatePolicy = new typename UpdaterType::template
153  Policy<arma::mat, arma::mat>(updater,
154  network.Parameters().n_rows,
155  network.Parameters().n_cols);
156  #endif
157 
158  Reset();
159 
160  return *this;
161  }
162 
169  {
170  if (&other == this)
171  return *this;
172 
173  #if ENS_VERSION_MAJOR >= 2
174  delete updatePolicy;
175  #endif
176 
177  updater = std::move(other.updater);
178  environment = std::move(other.environment);
179  config = std::move(other.config);
180  deterministic = std::move(other.deterministic);
181  steps = std::move(other.steps);
182  episodeReturn = std::move(other.episodeReturn);
183  pending = std::move(other.pending);
184  pendingIndex = std::move(other.pendingIndex);
185  network = std::move(other.network);
186  state = std::move(other.state);
187 
188  #if ENS_VERSION_MAJOR >= 2
189  other.updatePolicy = NULL;
190 
191  updatePolicy = new typename UpdaterType::template
192  Policy<arma::mat, arma::mat>(updater,
193  network.Parameters().n_rows,
194  network.Parameters().n_cols);
195  #endif
196 
197  return *this;
198  }
199 
204  {
205  #if ENS_VERSION_MAJOR >= 2
206  delete updatePolicy;
207  #endif
208  }
209 
214  void Initialize(NetworkType& learningNetwork)
215  {
// Set up optimizer state sized for the learning network's parameter
// matrix.  ens 1.x keeps this state inside the updater itself; ens >= 2
// wraps it in a heap-allocated Policy object owned by this worker.
216  #if ENS_VERSION_MAJOR == 1
217  updater.Initialize(learningNetwork.Parameters().n_rows,
218  learningNetwork.Parameters().n_cols);
219  #else
// Drop any previously created policy (e.g. from a constructor) so it
// can be rebuilt for the current parameter dimensions.
220  delete updatePolicy;
221 
222  updatePolicy = new typename UpdaterType::template
223  Policy<arma::mat, arma::mat>(updater,
224  learningNetwork.Parameters().n_rows,
225  learningNetwork.Parameters().n_cols);
226  #endif
227 
228  // Build local network.
229  network = learningNetwork;
230  }
231 
// Perform a single agent step.
//
// In deterministic (evaluation) mode the worker only interacts with the
// environment and reports the episode return; no learning happens.
// Otherwise the transition is buffered in `pending`, and once the buffer
// fills or the episode ends, one-step Q-learning targets are formed from
// the shared target network, the accumulated (clamped) gradient is applied
// asynchronously to the shared learning network, and the local copy is
// re-synced.
//
// @param learningNetwork The shared network that all workers update.
// @param targetNetwork The shared network used to bootstrap targets.
// @param totalSteps Global step counter shared across worker threads.
// @param policy Behavior policy for sampling actions (annealed each step).
// @param totalReward Set to the episode return when the episode ends.
// @return true if the episode terminated during this step.
243  bool Step(NetworkType& learningNetwork,
244  NetworkType& targetNetwork,
245  size_t& totalSteps,
246  PolicyType& policy,
247  double& totalReward)
248  {
249  // Interact with the environment.
250  arma::colvec actionValue;
251  network.Predict(state.Encode(), actionValue);
252  ActionType action = policy.Sample(actionValue, deterministic);
253  StateType nextState;
254  double reward = environment.Sample(state, action, nextState);
255  bool terminal = environment.IsTerminal(nextState);
256 
257  episodeReturn += reward;
258  steps++;
259 
// An episode is also cut off once it reaches the configured step limit.
260  terminal = terminal || steps >= config.StepLimit();
261  if (deterministic)
262  {
263  if (terminal)
264  {
265  totalReward = episodeReturn;
266  Reset();
267  // Sync with latest learning network.
268  network = learningNetwork;
269  return true;
270  }
271  state = nextState;
272  return false;
273  }
274 
// totalSteps is shared between worker threads, hence the atomic increment.
275  #pragma omp atomic
276  totalSteps++;
277 
278  pending[pendingIndex] = std::make_tuple(state, action, reward, nextState);
279  pendingIndex++;
280 
281  if (terminal || pendingIndex >= config.UpdateInterval())
282  {
283  // Initialize the gradient storage.
284  arma::mat totalGradients(learningNetwork.Parameters().n_rows,
285  learningNetwork.Parameters().n_cols, arma::fill::zeros);
// NOTE(review): the loop walks the whole buffer; when the episode ends
// before the buffer is full, entries at and past pendingIndex look like
// stale transitions from an earlier batch — confirm this is intended.
286  for (size_t i = 0; i < pending.size(); ++i)
287  {
288  TransitionType &transition = pending[i];
289 
290  // Compute the target state-action value.
291  arma::colvec actionValue;
// The target network is shared between workers, so reads are serialized.
292  #pragma omp critical
293  {
294  targetNetwork.Predict(
295  std::get<3>(transition).Encode(), actionValue);
296  };
297  double targetActionValue = actionValue.max();
// No bootstrap value for the final transition of a finished episode:
// its target is the raw reward alone.
298  if (terminal && i == pending.size() - 1)
299  targetActionValue = 0;
300  targetActionValue = std::get<2>(transition) +
301  config.Discount() * targetActionValue;
302 
303  // Compute the training target for current state.
304  arma::mat input = std::get<0>(transition).Encode();
305  network.Forward(input, actionValue);
306  actionValue[std::get<1>(transition).action] = targetActionValue;
307 
308  // Compute gradient.
309  arma::mat gradients;
310  network.Backward(input, actionValue, gradients);
311 
312  // Accumulate gradients.
313  totalGradients += gradients;
314  }
315 
316  // Clamp the accumulated gradients.
317  totalGradients.transform(
318  [&](double gradient)
319  { return std::min(std::max(gradient, -config.GradientLimit()),
320  config.GradientLimit()); });
321 
322  // Perform async update of the global network.
323  #if ENS_VERSION_MAJOR == 1
324  updater.Update(learningNetwork.Parameters(), config.StepSize(),
325  totalGradients);
326  #else
327  updatePolicy->Update(learningNetwork.Parameters(),
328  config.StepSize(), totalGradients);
329  #endif
330 
331  // Sync the local network with the global network.
332  network = learningNetwork;
333 
334  pendingIndex = 0;
335  }
336 
337  // Update global target network.
// NOTE(review): totalSteps is incremented concurrently by other workers,
// so a sync-interval boundary can be skipped — presumably acceptable.
338  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
339  {
340  #pragma omp critical
341  { targetNetwork = learningNetwork; }
342  }
343 
344  policy.Anneal();
345 
346  if (terminal)
347  {
348  totalReward = episodeReturn;
349  Reset();
350  return true;
351  }
352  state = nextState;
353  return false;
354  }
355 
356  private:
// Reset the worker for a new episode: zero the step count and return
// accumulator, rewind the pending-transition index (the buffer contents
// themselves are not cleared), and draw a fresh initial state.
360  void Reset()
361  {
362  steps = 0;
363  episodeReturn = 0;
364  pendingIndex = 0;
365  state = environment.InitialSample();
366  }
367 
// Locally-held copy of the optimizer used for the async gradient updates.
369  UpdaterType updater;
// ens >= 2 keeps per-matrix optimizer state in this heap-allocated policy
// object, owned by the worker (rebuilt in Initialize and on copy/move).
370  #if ENS_VERSION_MAJOR >= 2
371  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
372  #endif
373 
// Local instance of the environment this worker interacts with.
375  EnvironmentType environment;
376 
// Hyper-parameters (step limit, update interval, discount, etc.).
378  TrainingConfig config;
379 
// Whether this worker runs in evaluation mode (no learning updates).
381  bool deterministic;
382 
// Steps taken in the current episode.
384  size_t steps;
385 
// Accumulated reward of the current episode.
387  double episodeReturn;
388 
// Buffer of transitions awaiting a gradient update (UpdateInterval long).
390  std::vector<TransitionType> pending;
391 
// Next free slot in the pending buffer.
393  size_t pendingIndex;
394 
// Local copy of the shared learning network.
396  NetworkType network;
397 
// Current state of the environment.
399  StateType state;
400 };
401 
402 } // namespace rl
403 } // namespace mlpack
404 
405 #endif
void Initialize(NetworkType &learningNetwork)
Initialize the worker.
Linear algebra utility functions, generally performed on matrices or vectors.
if(NOT BUILD_GO_SHLIB) macro(add_go_binding name) endmacro() return() endif() endmacro() if(NOT BUILD_GO_BINDINGS) not_found_return("Not building Go bindings.") endif() if(BUILD_GO_BINDINGS) find_package(Go 1.11.0) if(NOT GO_FOUND) set(GO_NOT_FOUND_MSG "$
Definition: CMakeLists.txt:3
Definition: prereqs.hpp:67
std::tuple< StateType, ActionType, double, StateType > TransitionType
n Go endif() find_package(Gonum) if(NOT GONUM_FOUND) set(GO_NOT_FOUND_MSG "$
Definition: CMakeLists.txt:23
size_t StepLimit() const
Get the maximum steps of each episode.
size_t TargetNetworkSyncInterval() const
Get the interval for syncing target network.
OneStepQLearningWorker & operator=(OneStepQLearningWorker &&other)
Take ownership of another OneStepQLearningWorker.
OneStepQLearningWorker(const OneStepQLearningWorker &other)
Copy another OneStepQLearningWorker.
OneStepQLearningWorker(const UpdaterType &updater, const EnvironmentType &environment, const TrainingConfig &config, bool deterministic)
Construct one step Q-Learning worker with the given parameters and environment.
Forward declaration of OneStepQLearningWorker.
size_t UpdateInterval() const
Get the update interval.
double Discount() const
Get the discount rate for future reward.
typename EnvironmentType::Action ActionType
OneStepQLearningWorker(OneStepQLearningWorker &&other)
Take ownership of another OneStepQLearningWorker.
bool Step(NetworkType &learningNetwork, NetworkType &targetNetwork, size_t &totalSteps, PolicyType &policy, double &totalReward)
The agent will execute one step.
OneStepQLearningWorker & operator=(const OneStepQLearningWorker &other)
Copy another OneStepQLearningWorker.
double GradientLimit() const
Get the limit of update gradient.
double StepSize() const
Get the step size of the optimizer.
typename EnvironmentType::State StateType