12 #ifndef MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP 13 #define MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP 42 template <
typename EnvironmentType>
67 const size_t capacity,
68 const size_t dimension = StateType::dimension) :
72 states(dimension, capacity),
75 nextStates(dimension, capacity),
95 states.col(position) = state.Encode();
96 actions(position) = action;
97 rewards(position) = reward;
98 nextStates.col(position) = nextState.Encode();
99 isTerminal(position) = isEnd;
101 if (position == capacity)
119 arma::icolvec& sampledActions,
120 arma::colvec& sampledRewards,
121 arma::mat& sampledNextStates,
122 arma::icolvec& isTerminal)
124 size_t upperBound = full ? capacity : position;
125 arma::uvec sampledIndices = arma::randi<arma::uvec>(
126 batchSize, arma::distr_param(0, upperBound - 1));
128 sampledStates = states.cols(sampledIndices);
129 sampledActions = actions.elem(sampledIndices);
130 sampledRewards = rewards.elem(sampledIndices);
131 sampledNextStates = nextStates.cols(sampledIndices);
132 isTerminal = this->isTerminal.elem(sampledIndices);
142 return full ? capacity : position;
175 arma::icolvec actions;
178 arma::colvec rewards;
181 arma::mat nextStates;
184 arma::icolvec isTerminal;
typename EnvironmentType::Action ActionType
Convenience typedef for the environment's action type.
The core includes that mlpack expects; standard C++ includes and Armadillo.
typename EnvironmentType::State StateType
Convenience typedef for the environment's state type.
void Store(const StateType &state, ActionType action, double reward, const StateType &nextState, bool isEnd)
Store the given experience.
void Sample(arma::mat &sampledStates, arma::icolvec &sampledActions, arma::colvec &sampledRewards, arma::mat &sampledNextStates, arma::icolvec &isTerminal)
Sample some experiences.
const size_t & Size()
Get the number of transitions in the memory.
void Update(arma::mat, arma::icolvec, arma::mat, arma::mat &)
Update the priorities of transitions and update the gradients.
Implementation of random experience replay.
RandomReplay(const size_t batchSize, const size_t capacity, const size_t dimension=StateType::dimension)
Construct an instance of the random experience replay class.