one_step_sarsa_worker.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP
14 #define MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP
15 
17 
18 namespace mlpack {
19 namespace rl {
20 
/**
 * One-step SARSA worker for asynchronous reinforcement learning (A3C-style
 * training): each worker interacts with its own copy of the environment and
 * applies asynchronous gradient updates to a shared learning network.
 *
 * @tparam EnvironmentType The type of the RL task.
 * @tparam NetworkType The type of the function approximator.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepSarsaWorker
{
 public:
  using StateType = typename EnvironmentType::State;
  using ActionType = typename EnvironmentType::Action;
  //! SARSA is on-policy, so a transition also stores the action actually
  //! taken in the next state: (state, action, reward, nextState, nextAction).
  using TransitionType = std::tuple<StateType, ActionType, double, StateType,
      ActionType>;
42 
      //! NOTE(review): the constructor's opening line is not visible in this
      //! extract; the parameter list below belongs to the constructor.
      //! The optimizer for the local network (copied, not shared).
      const UpdaterType& updater,
      //! The RL task this worker interacts with (copied, not shared).
      const EnvironmentType& environment,
      //! Hyperparameters used for training.
      const TrainingConfig& config,
      //! Whether the worker runs in deterministic (evaluation) mode.
      bool deterministic):
      updater(updater),
      environment(environment),
      config(config),
      deterministic(deterministic),
      // Pre-size the transition buffer to one update interval.
      pending(config.UpdateInterval())
  { Reset(); }
63 
68  void Initialize(NetworkType& learningNetwork)
69  {
70  updater.Initialize(learningNetwork.Parameters().n_rows,
71  learningNetwork.Parameters().n_cols);
72  // Build local network.
73  network = learningNetwork;
74  }
75 
87  bool Step(NetworkType& learningNetwork,
88  NetworkType& targetNetwork,
89  size_t& totalSteps,
90  PolicyType& policy,
91  double& totalReward)
92  {
93  // Interact with the environment.
94  if (action == ActionType::size)
95  {
96  // Invalid action means we are at the beginning of an episode.
97  arma::colvec actionValue;
98  network.Predict(state.Encode(), actionValue);
99  action = policy.Sample(actionValue, deterministic);
100  }
101  StateType nextState;
102  double reward = environment.Sample(state, action, nextState);
103  bool terminal = environment.IsTerminal(nextState);
104  arma::colvec actionValue;
105  network.Predict(nextState.Encode(), actionValue);
106  ActionType nextAction = policy.Sample(actionValue, deterministic);
107 
108  episodeReturn += reward;
109  steps++;
110 
111  terminal = terminal || steps >= config.StepLimit();
112  if (deterministic)
113  {
114  if (terminal)
115  {
116  totalReward = episodeReturn;
117  Reset();
118  // Sync with latest learning network.
119  network = learningNetwork;
120  return true;
121  }
122  state = nextState;
123  action = nextAction;
124  return false;
125  }
126 
127  #pragma omp atomic
128  totalSteps++;
129 
130  pending[pendingIndex++] =
131  std::make_tuple(state, action, reward, nextState, nextAction);
132 
133  if (terminal || pendingIndex >= config.UpdateInterval())
134  {
135  // Initialize the gradient storage.
136  arma::mat totalGradients(learningNetwork.Parameters().n_rows,
137  learningNetwork.Parameters().n_cols, arma::fill::zeros);
138  for (size_t i = 0; i < pending.size(); ++i)
139  {
140  TransitionType &transition = pending[i];
141 
142  // Compute the target state-action value.
143  arma::colvec actionValue;
144  #pragma omp critical
145  {
146  targetNetwork.Predict(
147  std::get<3>(transition).Encode(), actionValue);
148  };
149  double targetActionValue = 0;
150  if (!(terminal && i == pending.size() - 1))
151  targetActionValue = actionValue[std::get<4>(transition)];
152  targetActionValue = std::get<2>(transition) +
153  config.Discount() * targetActionValue;
154 
155  // Compute the training target for current state.
156  network.Forward(std::get<0>(transition).Encode(), actionValue);
157  actionValue[std::get<1>(transition)] = targetActionValue;
158 
159  // Compute gradient.
160  arma::mat gradients;
161  network.Backward(actionValue, gradients);
162 
163  // Accumulate gradients.
164  totalGradients += gradients;
165  }
166 
167  // Clamp the accumulated gradients.
168  totalGradients.transform(
169  [&](double gradient)
170  { return std::min(std::max(gradient, -config.GradientLimit()),
171  config.GradientLimit()); });
172 
173  // Perform async update of the global network.
174  updater.Update(learningNetwork.Parameters(),
175  config.StepSize(), totalGradients);
176 
177  // Sync the local network with the global network.
178  network = learningNetwork;
179 
180  pendingIndex = 0;
181  }
182 
183  // Update global target network.
184  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
185  {
186  #pragma omp critical
187  { targetNetwork = learningNetwork; }
188  }
189 
190  policy.Anneal();
191 
192  if (terminal)
193  {
194  totalReward = episodeReturn;
195  Reset();
196  return true;
197  }
198  state = nextState;
199  action = nextAction;
200  return false;
201  }
202 
203  private:
207  void Reset()
208  {
209  steps = 0;
210  episodeReturn = 0;
211  pendingIndex = 0;
212  state = environment.InitialSample();
213  action = ActionType::size;
214  }
215 
  //! Local copy of the optimizer.
  UpdaterType updater;

  //! Local copy of the environment.
  EnvironmentType environment;

  //! Hyperparameters used for training.
  TrainingConfig config;

  //! Whether this worker runs in deterministic (evaluation) mode.
  bool deterministic;

  //! Number of steps taken in the current episode.
  size_t steps;

  //! Reward accumulated over the current episode.
  double episodeReturn;

  //! Buffer of pending transitions, sized to config.UpdateInterval().
  std::vector<TransitionType> pending;

  //! Number of valid transitions currently stored in `pending`.
  size_t pendingIndex;

  //! Local copy of the learning network.
  NetworkType network;

  //! Current state of the environment.
  StateType state;

  //! Action taken at `state`; ActionType::size marks "not yet chosen".
  ActionType action;
};
249 
250 } // namespace rl
251 } // namespace mlpack
252 
253 #endif
bool Step(NetworkType &learningNetwork, NetworkType &targetNetwork, size_t &totalSteps, PolicyType &policy, double &totalReward)
The agent will execute one step.
.hpp
Definition: add_to_po.hpp:21
void Initialize(NetworkType &learningNetwork)
Initialize the worker.
size_t StepLimit() const
Get the maximum steps of each episode.
std::tuple< StateType, ActionType, double, StateType, ActionType > TransitionType
size_t TargetNetworkSyncInterval() const
Get the interval for syncing target network.
size_t UpdateInterval() const
Get the update interval.
double Discount() const
Get the discount rate for future reward.
typename EnvironmentType::State StateType
double GradientLimit() const
Get the limit of update gradient.
typename EnvironmentType::Action ActionType
double StepSize() const
Get the step size of the optimizer.
OneStepSarsaWorker(const UpdaterType &updater, const EnvironmentType &environment, const TrainingConfig &config, bool deterministic)
Construct one step sarsa worker with the given parameters and environment.