12 #ifndef MLPACK_METHODS_RL_DUELING_DQN_HPP 13 #define MLPACK_METHODS_RL_DUELING_DQN_HPP 63 concat->
Add(valueNetwork);
64 concat->
Add(advantageNetwork);
66 completeNetwork.Add(featureNetwork);
67 completeNetwork.Add(concat);
83 const bool isNoisy =
false):
84 completeNetwork(
EmptyLoss<>(), GaussianInitialization(0, 0.001)),
96 noisyLayerIndex.push_back(valueNetwork->Model().size());
103 noisyLayerIndex.push_back(valueNetwork->Model().size());
109 valueNetwork->Add(
new Linear<>(h1, h2));
111 valueNetwork->Add(
new Linear<>(h2, 1));
115 advantageNetwork->
Add(
new Linear<>(h2, outputDim));
119 concat->
Add(valueNetwork);
120 concat->Add(advantageNetwork);
123 completeNetwork.Add(featureNetwork);
124 completeNetwork.Add(concat);
125 this->ResetParameters();
129 AdvantageNetworkType advantageNetwork,
130 ValueNetworkType valueNetwork,
131 const bool isNoisy =
false):
132 featureNetwork(
std::move(featureNetwork)),
133 advantageNetwork(
std::move(advantageNetwork)),
134 valueNetwork(
std::move(valueNetwork)),
138 concat->
Add(valueNetwork);
139 concat->Add(advantageNetwork);
141 completeNetwork.Add(featureNetwork);
142 completeNetwork.Add(concat);
143 this->ResetParameters();
153 *valueNetwork = *model.valueNetwork;
154 *advantageNetwork = *model.advantageNetwork;
155 *featureNetwork = *model.featureNetwork;
156 isNoisy = model.isNoisy;
157 noisyLayerIndex = model.noisyLayerIndex;
171 void Predict(
const arma::mat state, arma::mat& actionValue)
173 arma::mat advantage, value, networkOutput;
174 completeNetwork.Predict(state, networkOutput);
175 value = networkOutput.row(0);
176 advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
177 actionValue = advantage.each_row() +
178 (value - arma::mean(advantage));
187 void Forward(
const arma::mat state, arma::mat& actionValue)
189 arma::mat advantage, value, networkOutput;
190 completeNetwork.Forward(state, networkOutput);
191 value = networkOutput.row(0);
192 advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
193 actionValue = advantage.each_row() +
194 (value - arma::mean(advantage));
195 this->actionValues = actionValue;
205 void Backward(
const arma::mat state, arma::mat& target, arma::mat& gradient)
208 lossFunction.Backward(this->actionValues, target, gradLoss);
210 arma::mat gradValue = arma::sum(gradLoss);
211 arma::mat gradAdvantage = gradLoss.each_row() - arma::mean(gradLoss);
213 arma::mat grad = arma::join_cols(gradValue, gradAdvantage);
214 completeNetwork.Backward(state, grad, gradient);
// Delegates to completeNetwork, which holds the feature network followed by
// the concatenated value and advantage streams.
222 completeNetwork.ResetParameters();
// Reset the noise of every NoisyLinear layer whose position was recorded in
// noisyLayerIndex, in both the value and the advantage streams.
// NOTE(review): the indices are recorded from valueNetwork->Model().size()
// during construction and reused for advantageNetwork; this assumes both
// streams place their NoisyLinear layers at the same positions -- confirm.
// boost::get on a variant reference throws boost::bad_get if the layer at an
// index is not a NoisyLinear<>* (assuming Model() returns layer variants).
230 for (
size_t i = 0; i < noisyLayerIndex.size(); i++)
232 boost::get<NoisyLinear<>*>
233 (valueNetwork->Model()[noisyLayerIndex[i]])->ResetNoise();
234 boost::get<NoisyLinear<>*>
235 (advantageNetwork->Model()[noisyLayerIndex[i]])->ResetNoise();
240 const arma::mat&
Parameters()
const {
return completeNetwork.Parameters(); }
242 arma::mat&
Parameters() {
return completeNetwork.Parameters(); }
// Full model: featureNetwork followed by a Concat of the value and advantage
// streams; its output is the value row stacked on the advantage rows.
246 CompleteNetworkType completeNetwork;
// Shared feature extractor whose output feeds both streams.
252 FeatureNetworkType* featureNetwork;
// Advantage stream A(s, a); contributes rows 1..end of the network output.
255 AdvantageNetworkType* advantageNetwork;
// Value stream V(s); contributes row 0 of the network output.
258 ValueNetworkType* valueNetwork;
// Positions of NoisyLinear layers inside the streams, used to reset noise.
264 std::vector<size_t> noisyLayerIndex;
// Action values cached by Forward() and consumed by Backward().
267 arma::mat actionValues;
Artificial Neural Network.
Linear algebra utility functions, generally performed on matrices or vectors.
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the states in real batch mode.
DuelingDQN(const int inputDim, const int h1, const int h2, const int outputDim, const bool isNoisy=false)
Construct an instance of DuelingDQN class.
void Forward(const arma::mat state, arma::mat &actionValue)
Perform the forward pass of the states in real batch mode.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of the Linear layer class.
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
Implementation of the Dueling Deep Q-Learning network.
The empty loss does nothing, letting the user calculate the loss outside the model.
Implementation of the base layer.
DuelingDQN()
Default constructor.
DuelingDQN(const DuelingDQN &model)
Copy constructor.
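DuelingDQN(FeatureNetworkType featureNetwork, AdvantageNetworkType advantageNetwork, ValueNetworkType valueNetwork, const bool isNoisy=false)
Construct an instance of DuelingDQN class from user-supplied feature, advantage and value networks.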
Implementation of the Concat class.
Implementation of the NoisyLinear layer class.
void ResetNoise()
Resets the noise of the network, if the network is noisy.
The mean squared error performance function measures the network's performance according to the mean of the squared errors.
void ResetParameters()
Resets the parameters of the network.
arma::mat & Parameters()
Modify the parameters.
Implementation of a standard feed forward network.
const arma::mat & Parameters() const
Return the parameters.
Implementation of the Sequential class.
This class is used to initialize the weight matrix with values drawn from a Gaussian distribution.