q_learning.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_Q_LEARNING_HPP
14 #define MLPACK_METHODS_RL_Q_LEARNING_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "replay/random_replay.hpp"
19 #include "training_config.hpp"
20 
21 namespace mlpack {
22 namespace rl {
23 
50 template <
51  typename EnvironmentType,
52  typename NetworkType,
53  typename UpdaterType,
54  typename PolicyType,
55  typename ReplayType = RandomReplay<EnvironmentType>
56 >
57 class QLearning
58 {
59  public:
61  using StateType = typename EnvironmentType::State;
62 
64  using ActionType = typename EnvironmentType::Action;
65 
80  NetworkType network,
81  PolicyType policy,
82  ReplayType replayMethod,
83  UpdaterType updater = UpdaterType(),
84  EnvironmentType environment = EnvironmentType());
85 
90  double Step();
91 
96  double Episode();
97 
101  const size_t& TotalSteps() const { return totalSteps; }
102 
104  bool& Deterministic() { return deterministic; }
106  const bool& Deterministic() const { return deterministic; }
107 
109  const NetworkType& Network() const { return learningNetwork; }
111  NetworkType& Network() { return learningNetwork; }
112 
113  private:
119  arma::Col<size_t> BestAction(const arma::mat& actionValues);
120 
122  TrainingConfig config;
123 
125  NetworkType learningNetwork;
126 
128  NetworkType targetNetwork;
129 
131  UpdaterType updater;
132 
134  PolicyType policy;
135 
137  ReplayType replayMethod;
138 
140  EnvironmentType environment;
141 
143  size_t totalSteps;
144 
146  StateType state;
147 
149  bool deterministic;
150 };
151 
152 } // namespace rl
153 } // namespace mlpack
154 
155 // Include implementation
156 #include "q_learning_impl.hpp"
157 #endif
NetworkType & Network()
Modify the learning network.
Definition: q_learning.hpp:111
mlpack
Definition: add_to_po.hpp:21
typename EnvironmentType::Action ActionType
Convenient typedef for action.
Definition: q_learning.hpp:64
QLearning(TrainingConfig config, NetworkType network, PolicyType policy, ReplayType replayMethod, UpdaterType updater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Create the QLearning object with given settings.
bool & Deterministic()
Modify the training mode / test mode indicator.
Definition: q_learning.hpp:104
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Episode()
Execute an episode.
typename EnvironmentType::State StateType
Convenient typedef for state.
Definition: q_learning.hpp:61
const NetworkType & Network() const
Return the learning network.
Definition: q_learning.hpp:109
const bool & Deterministic() const
Get the indicator of training mode / test mode.
Definition: q_learning.hpp:106
const size_t & TotalSteps() const
Definition: q_learning.hpp:101
Implementation of various Q-Learning algorithms, such as DQN, double DQN.
Definition: q_learning.hpp:57
double Step()
Execute a step in an episode.