dueling_dqn.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RL_DUELING_DQN_HPP
13 #define MLPACK_METHODS_RL_DUELING_DQN_HPP
14 
15 #include <mlpack/prereqs.hpp>
21 
22 namespace mlpack {
23 namespace rl {
24 
25 using namespace mlpack::ann;
26 
46 template <
47  typename CompleteNetworkType = FFN<EmptyLoss<>, GaussianInitialization>,
48  typename FeatureNetworkType = Sequential<>,
49  typename AdvantageNetworkType = Sequential<>,
50  typename ValueNetworkType = Sequential<>
51 >
53 {
54  public:
56  DuelingDQN() : isNoisy(false)
57  {
58  featureNetwork = new Sequential<>();
59  valueNetwork = new Sequential<>();
60  advantageNetwork = new Sequential<>();
61  concat = new Concat<>(true);
62 
63  concat->Add(valueNetwork);
64  concat->Add(advantageNetwork);
65  completeNetwork.Add(new IdentityLayer<>());
66  completeNetwork.Add(featureNetwork);
67  completeNetwork.Add(concat);
68  }
69 
79  DuelingDQN(const int inputDim,
80  const int h1,
81  const int h2,
82  const int outputDim,
83  const bool isNoisy = false):
84  completeNetwork(EmptyLoss<>(), GaussianInitialization(0, 0.001)),
85  isNoisy(isNoisy)
86  {
87  featureNetwork = new Sequential<>();
88  featureNetwork->Add(new Linear<>(inputDim, h1));
89  featureNetwork->Add(new ReLULayer<>());
90 
91  valueNetwork = new Sequential<>();
92  advantageNetwork = new Sequential<>();
93 
94  if (isNoisy)
95  {
96  noisyLayerIndex.push_back(valueNetwork->Model().size());
97  valueNetwork->Add(new NoisyLinear<>(h1, h2));
98  advantageNetwork->Add(new NoisyLinear<>(h1, h2));
99 
100  valueNetwork->Add(new ReLULayer<>());
101  advantageNetwork->Add(new ReLULayer<>());
102 
103  noisyLayerIndex.push_back(valueNetwork->Model().size());
104  valueNetwork->Add(new NoisyLinear<>(h2, 1));
105  advantageNetwork->Add(new NoisyLinear<>(h2, outputDim));
106  }
107  else
108  {
109  valueNetwork->Add(new Linear<>(h1, h2));
110  valueNetwork->Add(new ReLULayer<>());
111  valueNetwork->Add(new Linear<>(h2, 1));
112 
113  advantageNetwork->Add(new Linear<>(h1, h2));
114  advantageNetwork->Add(new ReLULayer<>());
115  advantageNetwork->Add(new Linear<>(h2, outputDim));
116  }
117 
118  concat = new Concat<>(true);
119  concat->Add(valueNetwork);
120  concat->Add(advantageNetwork);
121 
122  completeNetwork.Add(new IdentityLayer<>());
123  completeNetwork.Add(featureNetwork);
124  completeNetwork.Add(concat);
125  this->ResetParameters();
126  }
127 
128  DuelingDQN(FeatureNetworkType featureNetwork,
129  AdvantageNetworkType advantageNetwork,
130  ValueNetworkType valueNetwork,
131  const bool isNoisy = false):
132  featureNetwork(std::move(featureNetwork)),
133  advantageNetwork(std::move(advantageNetwork)),
134  valueNetwork(std::move(valueNetwork)),
135  isNoisy(isNoisy)
136  {
137  concat = new Concat<>(true);
138  concat->Add(valueNetwork);
139  concat->Add(advantageNetwork);
140  completeNetwork.Add(new IdentityLayer<>());
141  completeNetwork.Add(featureNetwork);
142  completeNetwork.Add(concat);
143  this->ResetParameters();
144  }
145 
147  DuelingDQN(const DuelingDQN& model) : isNoisy(false)
148  { /* Nothing to do here. */ }
149 
151  void operator = (const DuelingDQN& model)
152  {
153  *valueNetwork = *model.valueNetwork;
154  *advantageNetwork = *model.advantageNetwork;
155  *featureNetwork = *model.featureNetwork;
156  isNoisy = model.isNoisy;
157  noisyLayerIndex = model.noisyLayerIndex;
158  }
159 
171  void Predict(const arma::mat state, arma::mat& actionValue)
172  {
173  arma::mat advantage, value, networkOutput;
174  completeNetwork.Predict(state, networkOutput);
175  value = networkOutput.row(0);
176  advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
177  actionValue = advantage.each_row() +
178  (value - arma::mean(advantage));
179  }
180 
187  void Forward(const arma::mat state, arma::mat& actionValue)
188  {
189  arma::mat advantage, value, networkOutput;
190  completeNetwork.Forward(state, networkOutput);
191  value = networkOutput.row(0);
192  advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
193  actionValue = advantage.each_row() +
194  (value - arma::mean(advantage));
195  this->actionValues = actionValue;
196  }
197 
205  void Backward(const arma::mat state, arma::mat& target, arma::mat& gradient)
206  {
207  arma::mat gradLoss;
208  lossFunction.Backward(this->actionValues, target, gradLoss);
209 
210  arma::mat gradValue = arma::sum(gradLoss);
211  arma::mat gradAdvantage = gradLoss.each_row() - arma::mean(gradLoss);
212 
213  arma::mat grad = arma::join_cols(gradValue, gradAdvantage);
214  completeNetwork.Backward(state, grad, gradient);
215  }
216 
221  {
222  completeNetwork.ResetParameters();
223  }
224 
228  void ResetNoise()
229  {
230  for (size_t i = 0; i < noisyLayerIndex.size(); i++)
231  {
232  boost::get<NoisyLinear<>*>
233  (valueNetwork->Model()[noisyLayerIndex[i]])->ResetNoise();
234  boost::get<NoisyLinear<>*>
235  (advantageNetwork->Model()[noisyLayerIndex[i]])->ResetNoise();
236  }
237  }
238 
240  const arma::mat& Parameters() const { return completeNetwork.Parameters(); }
242  arma::mat& Parameters() { return completeNetwork.Parameters(); }
243 
244  private:
246  CompleteNetworkType completeNetwork;
247 
249  Concat<>* concat;
250 
252  FeatureNetworkType* featureNetwork;
253 
255  AdvantageNetworkType* advantageNetwork;
256 
258  ValueNetworkType* valueNetwork;
259 
261  bool isNoisy;
262 
264  std::vector<size_t> noisyLayerIndex;
265 
267  arma::mat actionValues;
268 
270  MeanSquaredError<> lossFunction;
271 };
272 
273 } // namespace rl
274 } // namespace mlpack
275 
276 #endif
Artificial Neural Network.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: add_to_po.hpp:21
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
DuelingDQN(const int inputDim, const int h1, const int h2, const int outputDim, const bool isNoisy=false)
Construct an instance of DuelingDQN class.
Definition: dueling_dqn.hpp:79
void Forward(const arma::mat state, arma::mat &actionValue)
Perform the forward pass of the states in real batch mode.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of the Linear layer class.
Definition: layer_types.hpp:82
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
Definition: prereqs.hpp:55
Implementation of the Dueling Deep Q-Learning network.
Definition: dueling_dqn.hpp:52
The empty loss does nothing, letting the user calculate the loss outside the model.
Definition: empty_loss.hpp:35
Implementation of the base layer.
Definition: base_layer.hpp:65
DuelingDQN()
Default constructor.
Definition: dueling_dqn.hpp:56
DuelingDQN(const DuelingDQN &model)
Copy constructor.
DuelingDQN(FeatureNetworkType featureNetwork, AdvantageNetworkType advantageNetwork, ValueNetworkType valueNetwork, const bool isNoisy=false)
Implementation of the Concat class.
Definition: concat.hpp:45
Implementation of the NoisyLinear layer class.
Definition: layer_types.hpp:96
void ResetNoise()
Resets noise of the network, if the network is of type noisy.
The mean squared error performance function measures the network's performance according to the mean ...
void ResetParameters()
Resets the parameters of the network.
arma::mat & Parameters()
Modify the Parameters.
Implementation of a standard feed forward network.
Definition: ffn.hpp:52
const arma::mat & Parameters() const
Return the Parameters.
void Add(Args... args)
Definition: sequential.hpp:142
Implementation of the Sequential class.
void Add(Args... args)
Definition: concat.hpp:147
This class is used to initialize the weight matrix with a Gaussian distribution.