random_replay.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP
13 #define MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace rl {
19 
42 template <typename EnvironmentType>
44 {
45  public:
47  using ActionType = typename EnvironmentType::Action;
48 
50  using StateType = typename EnvironmentType::State;
51 
53  batchSize(0),
54  capacity(0),
55  position(0),
56  full(false)
57  { /* Nothing to do here. */ }
58 
66  RandomReplay(const size_t batchSize,
67  const size_t capacity,
68  const size_t dimension = StateType::dimension) :
69  batchSize(batchSize),
70  capacity(capacity),
71  position(0),
72  states(dimension, capacity),
73  actions(capacity),
74  rewards(capacity),
75  nextStates(dimension, capacity),
76  isTerminal(capacity),
77  full(false)
78  { /* Nothing to do here. */ }
79 
89  void Store(const StateType& state,
90  ActionType action,
91  double reward,
92  const StateType& nextState,
93  bool isEnd)
94  {
95  states.col(position) = state.Encode();
96  actions(position) = action;
97  rewards(position) = reward;
98  nextStates.col(position) = nextState.Encode();
99  isTerminal(position) = isEnd;
100  position++;
101  if (position == capacity)
102  {
103  full = true;
104  position = 0;
105  }
106  }
107 
118  void Sample(arma::mat& sampledStates,
119  arma::icolvec& sampledActions,
120  arma::colvec& sampledRewards,
121  arma::mat& sampledNextStates,
122  arma::icolvec& isTerminal)
123  {
124  size_t upperBound = full ? capacity : position;
125  arma::uvec sampledIndices = arma::randi<arma::uvec>(
126  batchSize, arma::distr_param(0, upperBound - 1));
127 
128  sampledStates = states.cols(sampledIndices);
129  sampledActions = actions.elem(sampledIndices);
130  sampledRewards = rewards.elem(sampledIndices);
131  sampledNextStates = nextStates.cols(sampledIndices);
132  isTerminal = this->isTerminal.elem(sampledIndices);
133  }
134 
140  const size_t& Size()
141  {
142  return full ? capacity : position;
143  }
144 
153  void Update(arma::mat /* target */,
154  arma::icolvec /* sampledActions */,
155  arma::mat /* nextActionValues */,
156  arma::mat& /* gradients */)
157  {
158  /* Do nothing for random replay. */
159  }
160 
161  private:
163  size_t batchSize;
164 
166  size_t capacity;
167 
169  size_t position;
170 
172  arma::mat states;
173 
175  arma::icolvec actions;
176 
178  arma::colvec rewards;
179 
181  arma::mat nextStates;
182 
184  arma::icolvec isTerminal;
185 
187  bool full;
188 };
189 
190 } // namespace rl
191 } // namespace mlpack
192 
193 #endif
typename EnvironmentType::Action ActionType
Convenient typedef for action.
strip_type.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
typename EnvironmentType::State StateType
Convenient typedef for state.
void Store(const StateType &state, ActionType action, double reward, const StateType &nextState, bool isEnd)
Store the given experience.
void Sample(arma::mat &sampledStates, arma::icolvec &sampledActions, arma::colvec &sampledRewards, arma::mat &sampledNextStates, arma::icolvec &isTerminal)
Sample some experiences.
const size_t & Size()
Get the number of transitions in the memory.
void Update(arma::mat, arma::icolvec, arma::mat, arma::mat &)
Update the priorities of transitions and update the gradients (a no-op for random replay).
Implementation of random experience replay.
RandomReplay(const size_t batchSize, const size_t capacity, const size_t dimension=StateType::dimension)
Construct an instance of random experience replay class.