q_learning.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_Q_LEARNING_HPP
14 #define MLPACK_METHODS_RL_Q_LEARNING_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "replay/random_replay.hpp"
19 #include "training_config.hpp"
20 
21 namespace mlpack {
22 namespace rl {
23 
50 template <
51  typename EnvironmentType,
52  typename NetworkType,
53  typename UpdaterType,
54  typename PolicyType,
55  typename ReplayType = RandomReplay<EnvironmentType>
56 >
57 class QLearning
58 {
59  public:
61  using StateType = typename EnvironmentType::State;
62 
64  using ActionType = typename EnvironmentType::Action;
65 
80  NetworkType network,
81  PolicyType policy,
82  ReplayType replayMethod,
83  UpdaterType updater = UpdaterType(),
84  EnvironmentType environment = EnvironmentType());
85 
90  double Step();
91 
96  double Episode();
97 
101  const size_t& TotalSteps() const { return totalSteps; }
102 
104  bool& Deterministic() { return deterministic; }
106  const bool& Deterministic() const { return deterministic; }
107 
108  private:
114  arma::Col<size_t> BestAction(const arma::mat& actionValues);
115 
117  TrainingConfig config;
118 
120  NetworkType learningNetwork;
121 
123  NetworkType targetNetwork;
124 
126  UpdaterType updater;
127 
129  PolicyType policy;
130 
132  ReplayType replayMethod;
133 
135  EnvironmentType environment;
136 
138  size_t totalSteps;
139 
141  StateType state;
142 
144  bool deterministic;
145 };
146 
147 } // namespace rl
148 } // namespace mlpack
149 
150 // Include implementation
151 #include "q_learning_impl.hpp"
152 #endif
.hpp
Definition: add_to_po.hpp:21
typename EnvironmentType::Action ActionType
Convenient typedef for action.
Definition: q_learning.hpp:64
QLearning(TrainingConfig config, NetworkType network, PolicyType policy, ReplayType replayMethod, UpdaterType updater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Create the QLearning object with given settings.
bool & Deterministic()
Modify the training mode / test mode indicator.
Definition: q_learning.hpp:104
prereqs.hpp
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Episode()
Execute an episode.
typename EnvironmentType::State StateType
Convenient typedef for state.
Definition: q_learning.hpp:61
const bool & Deterministic() const
Get the indicator of training mode / test mode.
Definition: q_learning.hpp:106
const size_t & TotalSteps() const
Definition: q_learning.hpp:101
mlpack::rl::QLearning
Implementation of various Q-Learning algorithms, such as DQN, double DQN.
Definition: q_learning.hpp:57
double Step()
Execute a step in an episode.