q_learning.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_Q_LEARNING_HPP
14 #define MLPACK_METHODS_RL_Q_LEARNING_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "replay/random_replay.hpp"
19 #include "training_config.hpp"
20 
21 namespace mlpack {
22 namespace rl {
23 
50 template <
51  typename EnvironmentType,
52  typename NetworkType,
53  typename UpdaterType,
54  typename PolicyType,
55  typename ReplayType = RandomReplay<EnvironmentType>
56 >
57 class QLearning
58 {
59  public:
61  using StateType = typename EnvironmentType::State;
62 
64  using ActionType = typename EnvironmentType::Action;
65 
80  NetworkType network,
81  PolicyType policy,
82  ReplayType replayMethod,
83  UpdaterType updater = UpdaterType(),
84  EnvironmentType environment = EnvironmentType());
85 
90  double Step();
91 
96  double Episode();
97 
101  const size_t& TotalSteps() const { return totalSteps; }
102 
104  StateType& State() { return state; }
106  const StateType& State() const { return state; }
107 
109  EnvironmentType& Environment() { return environment; }
111  const EnvironmentType& Environment() const { return environment; }
112 
114  bool& Deterministic() { return deterministic; }
116  const bool& Deterministic() const { return deterministic; }
117 
119  const NetworkType& Network() const { return learningNetwork; }
121  NetworkType& Network() { return learningNetwork; }
122 
123  private:
129  arma::Col<size_t> BestAction(const arma::mat& actionValues);
130 
132  TrainingConfig config;
133 
135  NetworkType learningNetwork;
136 
138  NetworkType targetNetwork;
139 
141  UpdaterType updater;
142 
144  PolicyType policy;
145 
147  ReplayType replayMethod;
148 
150  EnvironmentType environment;
151 
153  size_t totalSteps;
154 
156  StateType state;
157 
159  bool deterministic;
160 };
161 
162 } // namespace rl
163 } // namespace mlpack
164 
165 // Include implementation
166 #include "q_learning_impl.hpp"
167 #endif
NetworkType & Network()
Modify the learning network.
Definition: q_learning.hpp:121
.hpp
Definition: add_to_po.hpp:21
typename EnvironmentType::Action ActionType
Convenient typedef for action.
Definition: q_learning.hpp:64
QLearning(TrainingConfig config, NetworkType network, PolicyType policy, ReplayType replayMethod, UpdaterType updater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Create the QLearning object with given settings.
bool & Deterministic()
Modify the training mode / test mode indicator.
Definition: q_learning.hpp:114
prereqs.hpp
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Episode()
Execute an episode.
EnvironmentType & Environment()
Modify the environment in which the agent is.
Definition: q_learning.hpp:109
typename EnvironmentType::State StateType
Convenient typedef for state.
Definition: q_learning.hpp:61
const EnvironmentType & Environment() const
Get the environment in which the agent is.
Definition: q_learning.hpp:111
const NetworkType & Network() const
Return the learning network.
Definition: q_learning.hpp:119
const bool & Deterministic() const
Get the indicator of training mode / test mode.
Definition: q_learning.hpp:116
const size_t & TotalSteps() const
Definition: q_learning.hpp:101
mlpack::rl::QLearning
Implementation of various Q-Learning algorithms, such as DQN, double DQN.
Definition: q_learning.hpp:57
StateType & State()
Modify the state of the agent.
Definition: q_learning.hpp:104
const StateType & State() const
Get the state of the agent.
Definition: q_learning.hpp:106
double Step()
Execute a step in an episode.