random_replay.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP
13 #define MLPACK_METHODS_RL_REPLAY_RANDOM_REPLAY_HPP
14 
#include <mlpack/prereqs.hpp>

#include <stdexcept>
16 
17 namespace mlpack {
18 namespace rl {
19 
42 template <typename EnvironmentType>
44 {
45  public:
47  using ActionType = typename EnvironmentType::Action;
48 
50  using StateType = typename EnvironmentType::State;
51 
59  RandomReplay(const size_t batchSize,
60  const size_t capacity,
61  const size_t dimension = StateType::dimension) :
62  batchSize(batchSize),
63  capacity(capacity),
64  position(0),
65  states(dimension, capacity),
66  actions(capacity),
67  rewards(capacity),
68  nextStates(dimension, capacity),
69  isTerminal(capacity),
70  full(false)
71  { /* Nothing to do here. */ }
72 
82  void Store(const StateType& state,
83  ActionType action,
84  double reward,
85  const StateType& nextState,
86  bool isEnd)
87  {
88  states.col(position) = state.Encode();
89  actions(position) = action;
90  rewards(position) = reward;
91  nextStates.col(position) = nextState.Encode();
92  isTerminal(position) = isEnd;
93  position++;
94  if (position == capacity)
95  {
96  full = true;
97  position = 0;
98  }
99  }
100 
111  void Sample(arma::mat& sampledStates,
112  arma::icolvec& sampledActions,
113  arma::colvec& sampledRewards,
114  arma::mat& sampledNextStates,
115  arma::icolvec& isTerminal)
116  {
117  size_t upperBound = full ? capacity : position;
118  arma::uvec sampledIndices = arma::randi<arma::uvec>(
119  batchSize, arma::distr_param(0, upperBound - 1));
120 
121  sampledStates = states.cols(sampledIndices);
122  sampledActions = actions.elem(sampledIndices);
123  sampledRewards = rewards.elem(sampledIndices);
124  sampledNextStates = nextStates.cols(sampledIndices);
125  isTerminal = this->isTerminal.elem(sampledIndices);
126  }
127 
133  const size_t& Size()
134  {
135  return full ? capacity : position;
136  }
137 
138  private:
140  size_t batchSize;
141 
143  size_t capacity;
144 
146  size_t position;
147 
149  arma::mat states;
150 
152  arma::icolvec actions;
153 
155  arma::colvec rewards;
156 
158  arma::mat nextStates;
159 
161  arma::icolvec isTerminal;
162 
164  bool full;
165 };
166 
167 } // namespace rl
168 } // namespace mlpack
169 
170 #endif
typename EnvironmentType::Action ActionType
Convenient typedef for action.
prereqs.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
typename EnvironmentType::State StateType
Convenient typedef for state.
void Store(const StateType &state, ActionType action, double reward, const StateType &nextState, bool isEnd)
Store the given experience.
void Sample(arma::mat &sampledStates, arma::icolvec &sampledActions, arma::colvec &sampledRewards, arma::mat &sampledNextStates, arma::icolvec &isTerminal)
Sample some experiences.
const size_t & Size()
Get the number of transitions in the memory.
Implementation of random experience replay.
RandomReplay(const size_t batchSize, const size_t capacity, const size_t dimension=StateType::dimension)
Construct an instance of random experience replay class.