12 #ifndef MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP 13 #define MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP 42 template <
typename EnvironmentType>
60 const size_t capacity,
61 const size_t dimension = StateType::dimension) :
65 states(dimension, capacity),
68 nextStates(dimension, capacity),
88 states.col(position) = state.Encode();
89 actions(position) = action;
90 rewards(position) = reward;
91 nextStates.col(position) = nextState.Encode();
92 isTerminal(position) = isEnd;
94 if (position == capacity)
112 arma::icolvec& sampledActions,
113 arma::colvec& sampledRewards,
114 arma::mat& sampledNextStates,
115 arma::icolvec& isTerminal)
117 size_t upperBound = full ? capacity : position;
118 arma::uvec sampledIndices = arma::randi<arma::uvec>(
119 batchSize, arma::distr_param(0, upperBound - 1));
121 sampledStates = states.cols(sampledIndices);
122 sampledActions = actions.elem(sampledIndices);
123 sampledRewards = rewards.elem(sampledIndices);
124 sampledNextStates = nextStates.cols(sampledIndices);
125 isTerminal = this->isTerminal.elem(sampledIndices);
135 return full ? capacity : position;
152 arma::icolvec actions;
155 arma::colvec rewards;
158 arma::mat nextStates;
161 arma::icolvec isTerminal;
typename EnvironmentType::Action ActionType
Convenient typedef for action.
The core includes that mlpack expects; standard C++ includes and Armadillo.
typename EnvironmentType::State StateType
Convenient typedef for state.
void Store(const StateType &state, ActionType action, double reward, const StateType &nextState, bool isEnd)
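Store the given experience (encoded state, action, reward, encoded next state and terminal flag) at the current write position of the memory.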
void Sample(arma::mat &sampledStates, arma::icolvec &sampledActions, arma::colvec &sampledRewards, arma::mat &sampledNextStates, arma::icolvec &isTerminal)
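Sample a batch of stored experiences: batchSize indices are drawn uniformly at random (with replacement) from the transitions stored so far, and the corresponding states, actions, rewards, next states and terminal flags are copied into the output arguments.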
const size_t & Size()
Get the number of transitions currently stored in the memory (equal to the capacity once the memory is full).
Implementation of random experience replay.
RandomReplay(const size_t batchSize, const size_t capacity, const size_t dimension=StateType::dimension)
Construct an instance of random experience replay class.