12 #ifndef MLPACK_METHODS_RL_PRIORITIZED_REPLAY_HPP 13 #define MLPACK_METHODS_RL_PRIORITIZED_REPLAY_HPP 38 template <
typename EnvironmentType>
63 const size_t capacity,
65 const size_t dimension = StateType::dimension) :
69 states(dimension, capacity),
72 nextStates(dimension, capacity),
78 replayBetaIters(10000)
81 while (size < capacity)
105 states.col(position) = state.Encode();
106 actions(position) = action;
107 rewards(position) = reward;
108 nextStates.col(position) = nextState.Encode();
109 isTerminal(position) = isEnd;
111 idxSum.
Set(position, maxPriority * alpha);
114 if (position == capacity)
128 arma::ucolvec idxes(batchSize);
129 double totalSum = idxSum.
Sum(0, (full ? capacity : position));
130 double sumPerRange = totalSum / batchSize;
131 for (
size_t bt = 0; bt < batchSize; bt++)
133 const double mass = arma::randu() * sumPerRange + bt * sumPerRange;
150 arma::icolvec& sampledActions,
151 arma::colvec& sampledRewards,
152 arma::mat& sampledNextStates,
153 arma::icolvec& isTerminal)
158 sampledStates = states.cols(sampledIndices);
159 sampledActions = actions.elem(sampledIndices);
160 sampledRewards = rewards.elem(sampledIndices);
161 sampledNextStates = nextStates.cols(sampledIndices);
162 isTerminal = this->isTerminal.elem(sampledIndices);
166 size_t numSample = full ? capacity : position;
167 weights = arma::rowvec(sampledIndices.n_rows);
169 for (
size_t i = 0; i < sampledIndices.n_rows; i++)
171 double p_sample = idxSum.
Get(sampledIndices(i)) / idxSum.
Sum();
172 weights(i) = pow(numSample * p_sample, -beta);
174 weights /= weights.max();
185 arma::colvec alphaPri = alpha * priorities;
186 maxPriority = std::max(maxPriority, arma::max(priorities));
197 return full ? capacity : position;
205 beta = beta + (1 - initialBeta) * 1.0 / replayBetaIters;
217 arma::icolvec sampledActions,
218 arma::mat nextActionValues,
219 arma::mat& gradients)
221 arma::colvec tdError(target.n_cols);
222 for (
size_t i = 0; i < target.n_cols; i ++)
224 tdError(i) = nextActionValues(sampledActions(i), i) -
225 target(sampledActions(i), i);
227 tdError = arma::abs(tdError);
231 gradients = arma::mean(weights) * gradients;
249 arma::icolvec actions;
252 arma::colvec rewards;
255 arma::mat nextStates;
258 arma::icolvec isTerminal;
277 size_t replayBetaIters;
283 arma::ucolvec sampledIndices;
286 arma::rowvec weights;
void BetaAnneal()
Anneal the beta parameter.
void Sample(arma::mat &sampledStates, arma::icolvec &sampledActions, arma::colvec &sampledRewards, arma::mat &sampledNextStates, arma::icolvec &isTerminal)
Sample a batch of experiences according to their priorities.
void Store(const StateType &state, ActionType action, double reward, const StateType &nextState, bool isEnd)
Store the given experience and set its priority.
The core includes that mlpack expects: standard C++ includes and Armadillo.
T Get(size_t idx)
Get the element of the data array at index idx.
void BatchUpdate(const arma::ucolvec &indices, const arma::Col< T > &data)
Update the data in a batch rather than looping over the indices with the Set method.
arma::ucolvec SampleProportional()
Sample transition indices in proportion to their priorities.
Implementation of prioritized experience replay.
size_t FindPrefixSum(T mass)
Find the highest index idx in the array such that sum(arr[0] + arr[1] + ...
T Sum(const size_t start, size_t end)
Calculate the sum of a contiguous subsequence of the array.
const size_t & Size()
Get the number of transitions in the memory.
PrioritizedReplay()
Default constructor.
typename EnvironmentType::Action ActionType
Convenient typedef for action.
void Set(size_t idx, const T value)
Set the element of the data array at index idx.
void UpdatePriorities(arma::ucolvec &indices, arma::colvec &priorities)
Update priorities of sampled transitions.
PrioritizedReplay(const size_t batchSize, const size_t capacity, const double alpha, const size_t dimension=StateType::dimension)
Construct an instance of prioritized experience replay class.
typename EnvironmentType::State StateType
Convenient typedef for state.
void Update(arma::mat target, arma::icolvec sampledActions, arma::mat nextActionValues, arma::mat &gradients)
Update the priorities of the sampled transitions and update the gradients.