one_step_sarsa_worker.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP
14 #define MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP
15 
17 
18 namespace mlpack {
19 namespace rl {
20 
29 template <
30  typename EnvironmentType,
31  typename NetworkType,
32  typename UpdaterType,
33  typename PolicyType
34 >
35 class OneStepSarsaWorker
36 {
37  public:
38  using StateType = typename EnvironmentType::State;
39  using ActionType = typename EnvironmentType::Action;
40  using TransitionType = std::tuple<StateType, ActionType, double, StateType,
41  ActionType>;
42 
53  const UpdaterType& updater,
54  const EnvironmentType& environment,
55  const TrainingConfig& config,
56  bool deterministic):
57  updater(updater),
58  #if ENS_VERSION_MAJOR >= 2
59  updatePolicy(NULL),
60  #endif
61  environment(environment),
62  config(config),
63  deterministic(deterministic),
64  pending(config.UpdateInterval())
65  { Reset(); }
66 
73  updater(other.updater),
74  #if ENS_VERSION_MAJOR >= 2
75  updatePolicy(NULL),
76  #endif
77  environment(other.environment),
78  config(other.config),
79  deterministic(other.deterministic),
80  steps(other.steps),
81  episodeReturn(other.episodeReturn),
82  pending(other.pending),
83  pendingIndex(other.pendingIndex),
84  network(other.network),
85  state(other.state),
86  action(other.action)
87  {
88  Reset();
89 
90  #if ENS_VERSION_MAJOR >= 2
91  updatePolicy = new typename UpdaterType::template
92  Policy<arma::mat, arma::mat>(updater,
93  network.Parameters().n_rows,
94  network.Parameters().n_cols);
95  #endif
96  }
97 
104  updater(std::move(other.updater)),
105  #if ENS_VERSION_MAJOR >= 2
106  updatePolicy(NULL),
107  #endif
108  environment(std::move(other.environment)),
109  config(std::move(other.config)),
110  deterministic(std::move(other.deterministic)),
111  steps(std::move(other.steps)),
112  episodeReturn(std::move(other.episodeReturn)),
113  pending(std::move(other.pending)),
114  pendingIndex(std::move(other.pendingIndex)),
115  network(std::move(other.network)),
116  state(std::move(other.state)),
117  action(std::move(other.action))
118  {
119  #if ENS_VERSION_MAJOR >= 2
120  other.updatePolicy = NULL;
121 
122  updatePolicy = new typename UpdaterType::template
123  Policy<arma::mat, arma::mat>(updater,
124  network.Parameters().n_rows,
125  network.Parameters().n_cols);
126  #endif
127  }
128 
135  {
136  if (&other == this)
137  return *this;
138 
139  #if ENS_VERSION_MAJOR >= 2
140  delete updatePolicy;
141  #endif
142 
143  updater = other.updater;
144  environment = other.environment;
145  config = other.config;
146  deterministic = other.deterministic;
147  steps = other.steps;
148  episodeReturn = other.episodeReturn;
149  pending = other.pending;
150  pendingIndex = other.pendingIndex;
151  network = other.network;
152  state = other.state;
153  action = other.action;
154 
155  #if ENS_VERSION_MAJOR >= 2
156  updatePolicy = new typename UpdaterType::template
157  Policy<arma::mat, arma::mat>(updater,
158  network.Parameters().n_rows,
159  network.Parameters().n_cols);
160  #endif
161 
162  Reset();
163 
164  return *this;
165  }
166 
173  {
174  if (&other == this)
175  return *this;
176 
177  #if ENS_VERSION_MAJOR >= 2
178  delete updatePolicy;
179  #endif
180 
181  updater = std::move(other.updater);
182  environment = std::move(other.environment);
183  config = std::move(other.config);
184  deterministic = std::move(other.deterministic);
185  steps = std::move(other.steps);
186  episodeReturn = std::move(other.episodeReturn);
187  pending = std::move(other.pending);
188  pendingIndex = std::move(other.pendingIndex);
189  network = std::move(other.network);
190  state = std::move(other.state);
191  action = std::move(other.action);
192 
193  #if ENS_VERSION_MAJOR >= 2
194  other.updatePolicy = NULL;
195 
196  updatePolicy = new typename UpdaterType::template
197  Policy<arma::mat, arma::mat>(updater,
198  network.Parameters().n_rows,
199  network.Parameters().n_cols);
200  #endif
201 
202  return *this;
203  }
204 
209  {
210  #if ENS_VERSION_MAJOR >= 2
211  delete updatePolicy;
212  #endif
213  }
214 
219  void Initialize(NetworkType& learningNetwork)
220  {
221  #if ENS_VERSION_MAJOR == 1
222  updater.Initialize(learningNetwork.Parameters().n_rows,
223  learningNetwork.Parameters().n_cols);
224  #else
225  delete updatePolicy;
226 
227  updatePolicy = new typename UpdaterType::template
228  Policy<arma::mat, arma::mat>(updater,
229  learningNetwork.Parameters().n_rows,
230  learningNetwork.Parameters().n_cols);
231  #endif
232 
233  // Build local network.
234  network = learningNetwork;
235  }
236 
248  bool Step(NetworkType& learningNetwork,
249  NetworkType& targetNetwork,
250  size_t& totalSteps,
251  PolicyType& policy,
252  double& totalReward)
253  {
254  // Interact with the environment.
255  if (action == ActionType::size)
256  {
257  // Invalid action means we are at the beginning of an episode.
258  arma::colvec actionValue;
259  network.Predict(state.Encode(), actionValue);
260  action = policy.Sample(actionValue, deterministic);
261  }
262  StateType nextState;
263  double reward = environment.Sample(state, action, nextState);
264  bool terminal = environment.IsTerminal(nextState);
265  arma::colvec actionValue;
266  network.Predict(nextState.Encode(), actionValue);
267  ActionType nextAction = policy.Sample(actionValue, deterministic);
268 
269  episodeReturn += reward;
270  steps++;
271 
272  terminal = terminal || steps >= config.StepLimit();
273  if (deterministic)
274  {
275  if (terminal)
276  {
277  totalReward = episodeReturn;
278  Reset();
279  // Sync with latest learning network.
280  network = learningNetwork;
281  return true;
282  }
283  state = nextState;
284  action = nextAction;
285  return false;
286  }
287 
288  #pragma omp atomic
289  totalSteps++;
290 
291  pending[pendingIndex++] =
292  std::make_tuple(state, action, reward, nextState, nextAction);
293 
294  if (terminal || pendingIndex >= config.UpdateInterval())
295  {
296  // Initialize the gradient storage.
297  arma::mat totalGradients(learningNetwork.Parameters().n_rows,
298  learningNetwork.Parameters().n_cols, arma::fill::zeros);
299  for (size_t i = 0; i < pending.size(); ++i)
300  {
301  TransitionType &transition = pending[i];
302 
303  // Compute the target state-action value.
304  arma::colvec actionValue;
305  #pragma omp critical
306  {
307  targetNetwork.Predict(
308  std::get<3>(transition).Encode(), actionValue);
309  };
310  double targetActionValue = 0;
311  if (!(terminal && i == pending.size() - 1))
312  targetActionValue = actionValue[std::get<4>(transition)];
313  targetActionValue = std::get<2>(transition) +
314  config.Discount() * targetActionValue;
315 
316  // Compute the training target for current state.
317  arma::mat input = std::get<0>(transition).Encode();
318  network.Forward(input, actionValue);
319  actionValue[std::get<1>(transition)] = targetActionValue;
320 
321  // Compute gradient.
322  arma::mat gradients;
323  network.Backward(input, actionValue, gradients);
324 
325  // Accumulate gradients.
326  totalGradients += gradients;
327  }
328 
329  // Clamp the accumulated gradients.
330  totalGradients.transform(
331  [&](double gradient)
332  { return std::min(std::max(gradient, -config.GradientLimit()),
333  config.GradientLimit()); });
334 
335  // Perform async update of the global network.
336  #if ENS_VERSION_MAJOR == 1
337  updater.Update(learningNetwork.Parameters(), config.StepSize(),
338  totalGradients);
339  #else
340  updatePolicy->Update(learningNetwork.Parameters(),
341  config.StepSize(), totalGradients);
342  #endif
343 
344  // Sync the local network with the global network.
345  network = learningNetwork;
346 
347  pendingIndex = 0;
348  }
349 
350  // Update global target network.
351  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
352  {
353  #pragma omp critical
354  { targetNetwork = learningNetwork; }
355  }
356 
357  policy.Anneal();
358 
359  if (terminal)
360  {
361  totalReward = episodeReturn;
362  Reset();
363  return true;
364  }
365  state = nextState;
366  action = nextAction;
367  return false;
368  }
369 
370  private:
374  void Reset()
375  {
376  steps = 0;
377  episodeReturn = 0;
378  pendingIndex = 0;
379  state = environment.InitialSample();
380  action = ActionType::size;
381  }
382 
384  UpdaterType updater;
385  #if ENS_VERSION_MAJOR >= 2
386  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
387  #endif
388 
390  EnvironmentType environment;
391 
393  TrainingConfig config;
394 
396  bool deterministic;
397 
399  size_t steps;
400 
402  double episodeReturn;
403 
405  std::vector<TransitionType> pending;
406 
408  size_t pendingIndex;
409 
411  NetworkType network;
412 
414  StateType state;
415 
417  ActionType action;
418 };
419 
420 } // namespace rl
421 } // namespace mlpack
422 
423 #endif
bool Step(NetworkType &learningNetwork, NetworkType &targetNetwork, size_t &totalSteps, PolicyType &policy, double &totalReward)
The agent will execute one step.
OneStepSarsaWorker(const OneStepSarsaWorker &other)
Copy another OneStepSarsaWorker.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: add_to_po.hpp:21
void Initialize(NetworkType &learningNetwork)
Initialize the worker.
Definition: prereqs.hpp:55
size_t StepLimit() const
Get the maximum steps of each episode.
if(NOT BUILD_GO_SHLIB) macro(add_go_binding name) endmacro() return() endif() endmacro() if(NOT BUILD_GO_BINDINGS) not_found_return("Not building Go bindings.") endif() if(BUILD_GO_BINDINGS) if(FORCE_BUILD_GO_BINDINGS) find_package(Go 1.11.0) find_package(Gonum) if(NOT GO_FOUND OR NOT GONUM_FOUND) unset(BUILD_GO_BINDINGS CACHE) message(FATAL_ERROR "Go or Gonum not found
std::tuple< StateType, ActionType, double, StateType, ActionType > TransitionType
size_t TargetNetworkSyncInterval() const
Get the interval for syncing target network.
OneStepSarsaWorker & operator=(const OneStepSarsaWorker &other)
Copy another OneStepSarsaWorker.
OneStepSarsaWorker & operator=(OneStepSarsaWorker &&other)
Take ownership of another OneStepSarsaWorker.
unable to build Go bindings endif() else() find_package(Go 1.11.0) find_package(Gonum) if(NOT GO_FOUND OR NOT GONUM_FOUND) unset(BUILD_GO_BINDINGS CACHE) endif() endif() if(NOT GO_FOUND) not_found_return("Go not found
Definition: CMakeLists.txt:43
OneStepSarsaWorker(OneStepSarsaWorker &&other)
Take ownership of another OneStepSarsaWorker.
size_t UpdateInterval() const
Get the update interval.
double Discount() const
Get the discount rate for future reward.
typename EnvironmentType::State StateType
double GradientLimit() const
Get the limit of update gradient.
typename EnvironmentType::Action ActionType
Forward declaration of OneStepSarsaWorker.
double StepSize() const
Get the step size of the optimizer.
OneStepSarsaWorker(const UpdaterType &updater, const EnvironmentType &environment, const TrainingConfig &config, bool deterministic)
Construct a one-step SARSA worker with the given parameters and environment.