q_learning.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_Q_LEARNING_HPP
14 #define MLPACK_METHODS_RL_Q_LEARNING_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "replay/random_replay.hpp"
20 #include "training_config.hpp"
21 
22 namespace mlpack {
23 namespace rl {
24 
51 template <
52  typename EnvironmentType,
53  typename NetworkType,
54  typename UpdaterType,
55  typename PolicyType,
56  typename ReplayType = RandomReplay<EnvironmentType>
57 >
58 class QLearning
59 {
60  public:
62  using StateType = typename EnvironmentType::State;
63 
65  using ActionType = typename EnvironmentType::Action;
66 
80  QLearning(TrainingConfig& config,
81  NetworkType& network,
82  PolicyType& policy,
83  ReplayType& replayMethod,
84  UpdaterType updater = UpdaterType(),
85  EnvironmentType environment = EnvironmentType());
86 
90  ~QLearning();
91 
96  double Step();
97 
102  double Episode();
103 
107  const size_t& TotalSteps() const { return totalSteps; }
108 
110  StateType& State() { return state; }
112  const StateType& State() const { return state; }
113 
115  const ActionType& Action() const { return action; }
116 
118  EnvironmentType& Environment() { return environment; }
120  const EnvironmentType& Environment() const { return environment; }
121 
123  bool& Deterministic() { return deterministic; }
125  const bool& Deterministic() const { return deterministic; }
126 
128  const NetworkType& Network() const { return learningNetwork; }
130  NetworkType& Network() { return learningNetwork; }
131 
132  private:
138  arma::Col<size_t> BestAction(const arma::mat& actionValues);
139 
141  TrainingConfig& config;
142 
144  NetworkType& learningNetwork;
145 
147  NetworkType targetNetwork;
148 
150  PolicyType& policy;
151 
153  ReplayType& replayMethod;
154 
156  UpdaterType updater;
157  #if ENS_VERSION_MAJOR >= 2
158  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
159  #endif
160 
162  EnvironmentType environment;
163 
165  size_t totalSteps;
166 
168  StateType state;
169 
171  ActionType action;
172 
174  bool deterministic;
175 };
176 
177 } // namespace rl
178 } // namespace mlpack
179 
180 // Include implementation
181 #include "q_learning_impl.hpp"
182 #endif
NetworkType & Network()
Modify the learning network.
Definition: q_learning.hpp:130
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: add_to_po.hpp:21
typename EnvironmentType::Action ActionType
Convenient typedef for action.
Definition: q_learning.hpp:65
bool & Deterministic()
Modify the training mode / test mode indicator.
Definition: q_learning.hpp:123
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Episode()
Execute an episode.
EnvironmentType & Environment()
Modify the environment in which the agent is.
Definition: q_learning.hpp:118
typename EnvironmentType::State StateType
Convenient typedef for state.
Definition: q_learning.hpp:62
const EnvironmentType & Environment() const
Get the environment in which the agent is.
Definition: q_learning.hpp:120
const NetworkType & Network() const
Return the learning network.
Definition: q_learning.hpp:128
const bool & Deterministic() const
Get the indicator of training mode / test mode.
Definition: q_learning.hpp:125
const size_t & TotalSteps() const
Get the total steps of training.
Definition: q_learning.hpp:107
Implementation of various Q-Learning algorithms, such as DQN, double DQN.
Definition: q_learning.hpp:58
StateType & State()
Modify the state of the agent.
Definition: q_learning.hpp:110
const ActionType & Action() const
Get the action of the agent.
Definition: q_learning.hpp:115
~QLearning()
Clean memory.
QLearning(TrainingConfig &config, NetworkType &network, PolicyType &policy, ReplayType &replayMethod, UpdaterType updater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Create the QLearning object with given settings.
const StateType & State() const
Get the state of the agent.
Definition: q_learning.hpp:112
double Step()
Execute a step in an episode.