q_learning.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_Q_LEARNING_HPP
14 #define MLPACK_METHODS_RL_Q_LEARNING_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "replay/random_replay.hpp"
20 #include "training_config.hpp"
21 
22 namespace mlpack {
23 namespace rl {
24 
/**
 * Implementation of various Q-Learning algorithms, such as DQN, double DQN.
 *
 * NOTE(review): only declarations are visible here; the member definitions
 * live in q_learning_impl.hpp (included at the end of this header), so any
 * remark below marked "presumably" should be confirmed against that file.
 *
 * @tparam EnvironmentType Environment type; must expose nested `State` and
 *     `Action` types (they are aliased as StateType / ActionType below).
 * @tparam NetworkType Type of both the learning network and the target
 *     network.
 * @tparam UpdaterType Type of the updater; when ENS_VERSION_MAJOR >= 2 it
 *     must also expose a nested `Policy<...>` template (see `updatePolicy`).
 * @tparam PolicyType Type of the policy object.
 * @tparam ReplayType Type of the replay method; defaults to
 *     RandomReplay<EnvironmentType>.
 */
51 template <
52  typename EnvironmentType,
53  typename NetworkType,
54  typename UpdaterType,
55  typename PolicyType,
56  typename ReplayType = RandomReplay<EnvironmentType>
57 >
58 class QLearning
59 {
60  public:
    //! Convenient typedef for state.
62  using StateType = typename EnvironmentType::State;
63 
    //! Convenient typedef for action.
65  using ActionType = typename EnvironmentType::Action;
66 
    /**
     * Create the QLearning object with given settings.
     *
     * The first four arguments are taken by non-const reference, and the
     * class stores members of the same reference types, so presumably the
     * caller must keep those objects alive for the lifetime of this
     * instance (confirm in q_learning_impl.hpp).
     *
     * @param config Training configuration.
     * @param network Network to be trained.
     * @param policy Policy object.
     * @param replayMethod Replay method object.
     * @param updater Updater, taken by value; defaults to a
     *     default-constructed UpdaterType.
     * @param environment Environment, taken by value; defaults to a
     *     default-constructed EnvironmentType.
     */
80  QLearning(TrainingConfig& config,
81  NetworkType& network,
82  PolicyType& policy,
83  ReplayType& replayMethod,
84  UpdaterType updater = UpdaterType(),
85  EnvironmentType environment = EnvironmentType());
86 
    /**
     * Clean memory.
     *
     * NOTE(review): the class declares a destructor and (for
     * ENS_VERSION_MAJOR >= 2) a raw `updatePolicy` pointer, but no copy/move
     * operations are declared here. If the destructor deletes that pointer,
     * copying a QLearning object would be unsafe -- confirm and consider
     * deleting the copy operations.
     */
90  ~QLearning();
91 
    //! Trains the DQN agent (non-categorical).
95  void TrainAgent();
96 
    //! Trains the DQN agent of categorical type.
100  void TrainCategoricalAgent();
101 
    //! Select an action, given an agent.
    //! (Returns nothing; presumably the result is stored in `action`.)
105  void SelectAction();
106 
    /**
     * Execute an episode.
     *
     * @return A double; presumably the return (total reward) of the episode
     *     -- TODO confirm against the implementation.
     */
111  double Episode();
112 
    //! Modify total steps from beginning.
114  size_t& TotalSteps() { return totalSteps; }
    //! Get total steps from beginning.
116  const size_t& TotalSteps() const { return totalSteps; }
117 
    //! Modify the state of the agent.
119  StateType& State() { return state; }
    //! Get the state of the agent.
121  const StateType& State() const { return state; }
122 
    //! Get the action of the agent. (Read-only: no mutable overload exists.)
124  const ActionType& Action() const { return action; }
125 
    //! Modify the environment in which the agent is.
127  EnvironmentType& Environment() { return environment; }
    //! Get the environment in which the agent is.
129  const EnvironmentType& Environment() const { return environment; }
130 
    //! Modify the training mode / test mode indicator.
132  bool& Deterministic() { return deterministic; }
    //! Get the indicator of training mode / test mode.
134  const bool& Deterministic() const { return deterministic; }
135 
    //! Return the learning network.
137  const NetworkType& Network() const { return learningNetwork; }
    //! Modify the learning network.
139  NetworkType& Network() { return learningNetwork; }
140 
141  private:
    /**
     * Map a matrix of action values to a column vector of indices.
     *
     * NOTE(review): presumably returns, per column of actionValues, the index
     * of the highest-valued action -- confirm in q_learning_impl.hpp.
     *
     * @param actionValues Matrix of action values.
     * @return Column vector of size_t.
     */
147  arma::Col<size_t> BestAction(const arma::mat& actionValues);
148 
    //! Training configuration (held by reference, not owned).
150  TrainingConfig& config;
151 
    //! Learning network (held by reference, not owned); exposed by Network().
153  NetworkType& learningNetwork;
154 
    //! Target network (a separate NetworkType object stored by value).
156  NetworkType targetNetwork;
157 
    //! Policy object (held by reference, not owned).
159  PolicyType& policy;
160 
    //! Replay method object (held by reference, not owned).
162  ReplayType& replayMethod;
163 
    //! Updater (stored by value).
165  UpdaterType updater;
    // Only present when building against ensmallen major version >= 2:
    // pointer to the updater's nested Policy instantiated for arma::mat.
    // NOTE(review): raw pointer; ownership is not visible here -- presumably
    // allocated in the constructor and released in the destructor.
166  #if ENS_VERSION_MAJOR >= 2
167  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
168  #endif
169 
    //! Environment in which the agent is (stored by value).
171  EnvironmentType environment;
172 
    //! Total steps from beginning.
174  size_t totalSteps;
175 
    //! State of the agent.
177  StateType state;
178 
    //! Action of the agent.
180  ActionType action;
181 
    //! Training mode / test mode indicator.
183  bool deterministic;
184 };
185 
186 } // namespace rl
187 } // namespace mlpack
188 
189 // Include implementation
190 #include "q_learning_impl.hpp"
191 #endif
NetworkType & Network()
Modify the learning network.
Definition: q_learning.hpp:139
Linear algebra utility functions, generally performed on matrices or vectors.
void SelectAction()
Select an action, given an agent.
typename EnvironmentType::Action ActionType
Convenient typedef for action.
Definition: q_learning.hpp:65
void TrainCategoricalAgent()
Trains the DQN agent of categorical type.
void TrainAgent()
Trains the DQN agent (non-categorical).
bool & Deterministic()
Modify the training mode / test mode indicator.
Definition: q_learning.hpp:132
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Episode()
Execute an episode.
size_t & TotalSteps()
Modify the total number of steps taken from the beginning.
Definition: q_learning.hpp:114
EnvironmentType & Environment()
Modify the environment in which the agent is.
Definition: q_learning.hpp:127
typename EnvironmentType::State StateType
Convenient typedef for state.
Definition: q_learning.hpp:62
const EnvironmentType & Environment() const
Get the environment in which the agent is.
Definition: q_learning.hpp:129
const NetworkType & Network() const
Return the learning network.
Definition: q_learning.hpp:137
const bool & Deterministic() const
Get the indicator of training mode / test mode.
Definition: q_learning.hpp:134
const size_t & TotalSteps() const
Get the total number of steps taken from the beginning.
Definition: q_learning.hpp:116
Implementation of various Q-Learning algorithms, such as DQN, double DQN.
Definition: q_learning.hpp:58
StateType & State()
Modify the state of the agent.
Definition: q_learning.hpp:119
const ActionType & Action() const
Get the action of the agent.
Definition: q_learning.hpp:124
~QLearning()
Clean memory.
QLearning(TrainingConfig &config, NetworkType &network, PolicyType &policy, ReplayType &replayMethod, UpdaterType updater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Create the QLearning object with given settings.
const StateType & State() const
Get the state of the agent.
Definition: q_learning.hpp:121