prioritized_replay.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RL_PRIORITIZED_REPLAY_HPP
13 #define MLPACK_METHODS_RL_PRIORITIZED_REPLAY_HPP
14 
#include <cmath>

#include <mlpack/prereqs.hpp>

#include "sumtree.hpp"
17 
18 namespace mlpack {
19 namespace rl {
20 
38 template <typename EnvironmentType>
40 {
41  public:
43  using ActionType = typename EnvironmentType::Action;
44 
46  using StateType = typename EnvironmentType::State;
47 
52  { /* Nothing to do here. */ }
53 
62  PrioritizedReplay(const size_t batchSize,
63  const size_t capacity,
64  const double alpha,
65  const size_t dimension = StateType::dimension) :
66  batchSize(batchSize),
67  capacity(capacity),
68  position(0),
69  states(dimension, capacity),
70  actions(capacity),
71  rewards(capacity),
72  nextStates(dimension, capacity),
73  isTerminal(capacity),
74  full(false),
75  alpha(alpha),
76  maxPriority(1.0),
77  initialBeta(0.6),
78  replayBetaIters(10000)
79  {
80  size_t size = 1;
81  while (size < capacity)
82  {
83  size *= 2;
84  }
85 
86  beta = initialBeta;
87  idxSum = SumTree<double>(size);
88  }
89 
99  void Store(const StateType& state,
100  ActionType action,
101  double reward,
102  const StateType& nextState,
103  bool isEnd)
104  {
105  states.col(position) = state.Encode();
106  actions(position) = action;
107  rewards(position) = reward;
108  nextStates.col(position) = nextState.Encode();
109  isTerminal(position) = isEnd;
110 
111  idxSum.Set(position, maxPriority * alpha);
112 
113  position++;
114  if (position == capacity)
115  {
116  full = true;
117  position = 0;
118  }
119  }
120 
126  arma::ucolvec SampleProportional()
127  {
128  arma::ucolvec idxes(batchSize);
129  double totalSum = idxSum.Sum(0, (full ? capacity : position));
130  double sumPerRange = totalSum / batchSize;
131  for (size_t bt = 0; bt < batchSize; bt++)
132  {
133  const double mass = arma::randu() * sumPerRange + bt * sumPerRange;
134  idxes(bt) = idxSum.FindPrefixSum(mass);
135  }
136  return idxes;
137  }
138 
149  void Sample(arma::mat& sampledStates,
150  arma::icolvec& sampledActions,
151  arma::colvec& sampledRewards,
152  arma::mat& sampledNextStates,
153  arma::icolvec& isTerminal)
154  {
155  sampledIndices = SampleProportional();
156  BetaAnneal();
157 
158  sampledStates = states.cols(sampledIndices);
159  sampledActions = actions.elem(sampledIndices);
160  sampledRewards = rewards.elem(sampledIndices);
161  sampledNextStates = nextStates.cols(sampledIndices);
162  isTerminal = this->isTerminal.elem(sampledIndices);
163 
164  // Calculate the weights of sampled transitions.
165 
166  size_t numSample = full ? capacity : position;
167  weights = arma::rowvec(sampledIndices.n_rows);
168 
169  for (size_t i = 0; i < sampledIndices.n_rows; i++)
170  {
171  double p_sample = idxSum.Get(sampledIndices(i)) / idxSum.Sum();
172  weights(i) = pow(numSample * p_sample, -beta);
173  }
174  weights /= weights.max();
175  }
176 
183  void UpdatePriorities(arma::ucolvec& indices, arma::colvec& priorities)
184  {
185  arma::colvec alphaPri = alpha * priorities;
186  maxPriority = std::max(maxPriority, arma::max(priorities));
187  idxSum.BatchUpdate(indices, alphaPri);
188  }
189 
195  const size_t& Size()
196  {
197  return full ? capacity : position;
198  }
199 
203  void BetaAnneal()
204  {
205  beta = beta + (1 - initialBeta) * 1.0 / replayBetaIters;
206  }
207 
216  void Update(arma::mat target,
217  arma::icolvec sampledActions,
218  arma::mat nextActionValues,
219  arma::mat& gradients)
220  {
221  arma::colvec tdError(target.n_cols);
222  for (size_t i = 0; i < target.n_cols; i ++)
223  {
224  tdError(i) = nextActionValues(sampledActions(i), i) -
225  target(sampledActions(i), i);
226  }
227  tdError = arma::abs(tdError);
228  UpdatePriorities(sampledIndices, tdError);
229 
230  // Update the gradient
231  gradients = arma::mean(weights) * gradients;
232  }
233 
234 
235  private:
237  size_t batchSize;
238 
240  size_t capacity;
241 
243  size_t position;
244 
246  arma::mat states;
247 
249  arma::icolvec actions;
250 
252  arma::colvec rewards;
253 
255  arma::mat nextStates;
256 
258  arma::icolvec isTerminal;
259 
261  bool full;
262 
265  double alpha;
266 
268  double maxPriority;
269 
271  double initialBeta;
272 
274  double beta;
275 
277  size_t replayBetaIters;
278 
280  SumTree<double> idxSum;
281 
283  arma::ucolvec sampledIndices;
284 
286  arma::rowvec weights;
287 };
288 
289 } // namespace rl
290 } // namespace mlpack
291 
292 #endif
void BetaAnneal()
Annealing the beta.
strip_type.hpp
Definition: add_to_po.hpp:21
void Sample(arma::mat &sampledStates, arma::icolvec &sampledActions, arma::colvec &sampledRewards, arma::mat &sampledNextStates, arma::icolvec &isTerminal)
Sample some experience according to their priorities.
void Store(const StateType &state, ActionType action, double reward, const StateType &nextState, bool isEnd)
Store the given experience and set the priorities for the given experience.
The core includes that mlpack expects; standard C++ includes and Armadillo.
T Get(size_t idx)
Get the data array with idx.
Definition: sumtree.hpp:93
void BatchUpdate(const arma::ucolvec &indices, const arma::Col< T > &data)
Update the data in batch, rather than looping over the indices with the set method.
Definition: sumtree.hpp:75
arma::ucolvec SampleProportional()
Sample some experience according to their priorities.
Implementation of prioritized experience replay.
size_t FindPrefixSum(T mass)
Find the highest index idx in the array such that sum(arr[0] + arr[1] + ... + arr[idx - 1]) <= mass.
Definition: sumtree.hpp:163
T Sum(const size_t start, size_t end)
Calculate the sum of contiguous subsequence of the array.
Definition: sumtree.hpp:143
const size_t & Size()
Get the number of transitions in the memory.
PrioritizedReplay()
Default constructor.
typename EnvironmentType::Action ActionType
Convenient typedef for action.
void Set(size_t idx, const T value)
Set the data array with idx.
Definition: sumtree.hpp:57
void UpdatePriorities(arma::ucolvec &indices, arma::colvec &priorities)
Update priorities of sampled transitions.
PrioritizedReplay(const size_t batchSize, const size_t capacity, const double alpha, const size_t dimension=StateType::dimension)
Construct an instance of prioritized experience replay class.
typename EnvironmentType::State StateType
Convenient typedef for state.
void Update(arma::mat target, arma::icolvec sampledActions, arma::mat nextActionValues, arma::mat &gradients)
Update the priorities of transitions and Update the gradients.