q_learning.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_Q_LEARNING_HPP
14 #define MLPACK_METHODS_RL_Q_LEARNING_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "replay/random_replay.hpp"
20 #include "training_config.hpp"
21 
22 namespace mlpack {
23 namespace rl {
24 
51 template <
52  typename EnvironmentType,
53  typename NetworkType,
54  typename UpdaterType,
55  typename PolicyType,
56  typename ReplayType = RandomReplay<EnvironmentType>
57 >
58 class QLearning
59 {
60  public:
62  using StateType = typename EnvironmentType::State;
63 
65  using ActionType = typename EnvironmentType::Action;
66 
81  NetworkType network,
82  PolicyType policy,
83  ReplayType replayMethod,
84  UpdaterType updater = UpdaterType(),
85  EnvironmentType environment = EnvironmentType());
86 
90  ~QLearning();
91 
96  double Step();
97 
102  double Episode();
103 
107  const size_t& TotalSteps() const { return totalSteps; }
108 
110  StateType& State() { return state; }
112  const StateType& State() const { return state; }
113 
115  EnvironmentType& Environment() { return environment; }
117  const EnvironmentType& Environment() const { return environment; }
118 
120  bool& Deterministic() { return deterministic; }
122  const bool& Deterministic() const { return deterministic; }
123 
125  const NetworkType& Network() const { return learningNetwork; }
127  NetworkType& Network() { return learningNetwork; }
128 
129  private:
135  arma::Col<size_t> BestAction(const arma::mat& actionValues);
136 
138  TrainingConfig config;
139 
141  NetworkType learningNetwork;
142 
144  NetworkType targetNetwork;
145 
147  UpdaterType updater;
148  #if ENS_VERSION_MAJOR >= 2
149  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
150  #endif
151 
153  PolicyType policy;
154 
156  ReplayType replayMethod;
157 
159  EnvironmentType environment;
160 
162  size_t totalSteps;
163 
165  StateType state;
166 
168  bool deterministic;
169 };
170 
171 } // namespace rl
172 } // namespace mlpack
173 
174 // Include implementation
175 #include "q_learning_impl.hpp"
176 #endif
NetworkType & Network()
Modify the learning network.
Definition: q_learning.hpp:127
strip_type.hpp
Definition: add_to_po.hpp:21
typename EnvironmentType::Action ActionType
Convenient typedef for action.
Definition: q_learning.hpp:65
QLearning(TrainingConfig config, NetworkType network, PolicyType policy, ReplayType replayMethod, UpdaterType updater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Create the QLearning object with given settings.
bool & Deterministic()
Modify the training mode / test mode indicator.
Definition: q_learning.hpp:120
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Episode()
Execute an episode.
EnvironmentType & Environment()
Modify the environment in which the agent is.
Definition: q_learning.hpp:115
typename EnvironmentType::State StateType
Convenient typedef for state.
Definition: q_learning.hpp:62
const EnvironmentType & Environment() const
Get the environment in which the agent is.
Definition: q_learning.hpp:117
const NetworkType & Network() const
Return the learning network.
Definition: q_learning.hpp:125
const bool & Deterministic() const
Get the indicator of training mode / test mode.
Definition: q_learning.hpp:122
const size_t & TotalSteps() const
Get the total step count.
Definition: q_learning.hpp:107
Implementation of various Q-Learning algorithms, such as DQN, double DQN.
Definition: q_learning.hpp:58
StateType & State()
Modify the state of the agent.
Definition: q_learning.hpp:110
~QLearning()
Clean memory.
const StateType & State() const
Get the state of the agent.
Definition: q_learning.hpp:112
double Step()
Execute a step in an episode.