sac.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_SAC_HPP
14 #define MLPACK_METHODS_RL_SAC_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "replay/random_replay.hpp"
22 #include "training_config.hpp"
23 
24 namespace mlpack {
25 namespace rl {
26 
56 template <
57  typename EnvironmentType,
58  typename QNetworkType,
59  typename PolicyNetworkType,
60  typename UpdaterType,
61  typename ReplayType = RandomReplay<EnvironmentType>
62 >
63 class SAC
64 {
65  public:
67  using StateType = typename EnvironmentType::State;
68 
70  using ActionType = typename EnvironmentType::Action;
71 
88  SAC(TrainingConfig& config,
89  QNetworkType& learningQ1Network,
90  PolicyNetworkType& policyNetwork,
91  ReplayType& replayMethod,
92  UpdaterType qNetworkUpdater = UpdaterType(),
93  UpdaterType policyNetworkUpdater = UpdaterType(),
94  EnvironmentType environment = EnvironmentType());
95 
99  ~SAC();
100 
107  void SoftUpdate(double rho);
108 
112  void Update();
113 
117  void SelectAction();
118 
123  double Episode();
124 
126  size_t& TotalSteps() { return totalSteps; }
128  const size_t& TotalSteps() const { return totalSteps; }
129 
131  StateType& State() { return state; }
133  const StateType& State() const { return state; }
134 
136  const ActionType& Action() const { return action; }
137 
139  bool& Deterministic() { return deterministic; }
141  const bool& Deterministic() const { return deterministic; }
142 
143 
144  private:
146  TrainingConfig& config;
147 
149  QNetworkType& learningQ1Network;
150  QNetworkType learningQ2Network;
151 
153  QNetworkType targetQ1Network;
154  QNetworkType targetQ2Network;
155 
157  PolicyNetworkType& policyNetwork;
158 
160  ReplayType& replayMethod;
161 
163  UpdaterType qNetworkUpdater;
164  #if ENS_VERSION_MAJOR >= 2
165  typename UpdaterType::template Policy<arma::mat, arma::mat>*
166  qNetworkUpdatePolicy;
167  #endif
168 
170  UpdaterType policyNetworkUpdater;
171  #if ENS_VERSION_MAJOR >= 2
172  typename UpdaterType::template Policy<arma::mat, arma::mat>*
173  policyNetworkUpdatePolicy;
174  #endif
175 
177  EnvironmentType environment;
178 
180  size_t totalSteps;
181 
183  StateType state;
184 
186  ActionType action;
187 
189  bool deterministic;
190 
192  mlpack::ann::MeanSquaredError<> lossFunction;
193 };
194 
195 } // namespace rl
196 } // namespace mlpack
197 
198 // Include implementation
199 #include "sac_impl.hpp"
200 #endif
~SAC()
Clean memory.
typename EnvironmentType::Action ActionType
Convenient typedef for action.
Definition: sac.hpp:70
void SelectAction()
Select an action, given an agent.
Linear algebra utility functions, generally performed on matrices or vectors.
Implementation of Soft Actor-Critic, a model-free off-policy actor-critic based deep reinforcement learning algorithm.
Definition: sac.hpp:63
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Episode()
Execute an episode.
SAC(TrainingConfig &config, QNetworkType &learningQ1Network, PolicyNetworkType &policyNetwork, ReplayType &replayMethod, UpdaterType qNetworkUpdater=UpdaterType(), UpdaterType policyNetworkUpdater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Create the SAC object with given settings.
void Update()
Update the Q and policy networks.
const StateType & State() const
Get the state of the agent.
Definition: sac.hpp:133
size_t & TotalSteps()
Modify total steps from beginning.
Definition: sac.hpp:126
void SoftUpdate(double rho)
Softly update the learning Q network parameters to the target Q network parameters.
bool & Deterministic()
Modify the training mode / test mode indicator.
Definition: sac.hpp:139
const size_t & TotalSteps() const
Get total steps from beginning.
Definition: sac.hpp:128
The mean squared error performance function measures the network's performance according to the mean of squared errors.
const ActionType & Action() const
Get the action of the agent.
Definition: sac.hpp:136
StateType & State()
Modify the state of the agent.
Definition: sac.hpp:131
const bool & Deterministic() const
Get the indicator of training mode / test mode.
Definition: sac.hpp:141
typename EnvironmentType::State StateType
Convenient typedef for state.
Definition: sac.hpp:67