12 #ifndef MLPACK_METHODS_RL_SIMPLE_DQN_HPP 13 #define MLPACK_METHODS_RL_SIMPLE_DQN_HPP 29 template <
typename NetworkType = FFN<MeanSquaredError<>,
30 GaussianInitialization>>
53 const bool isNoisy =
false):
57 network.Add(
new Linear<>(inputDim, h1));
61 noisyLayerIndex.push_back(network.Model().size());
64 noisyLayerIndex.push_back(network.Model().size());
71 network.Add(
new Linear<>(h2, outputDim));
75 SimpleDQN(NetworkType network,
const bool isNoisy =
false):
76 network(
std::move(network)),
91 void Predict(
const arma::mat state, arma::mat& actionValue)
93 network.Predict(state, actionValue);
102 void Forward(
const arma::mat state, arma::mat& target)
104 network.Forward(state, target);
112 network.ResetParameters();
120 for (
size_t i = 0; i < noisyLayerIndex.size(); i++)
122 boost::get<NoisyLinear<>*>
123 (network.Model()[noisyLayerIndex[i]])->ResetNoise();
128 const arma::mat&
Parameters()
const {
return network.Parameters(); }
139 void Backward(
const arma::mat state, arma::mat& target, arma::mat& gradient)
141 network.Backward(state, target, gradient);
152 std::vector<size_t> noisyLayerIndex;
Artificial Neural Network.
void ResetParameters()
Resets the parameters of the network.
void ResetNoise()
Resets the noise of the network, if the network is of the noisy type.
Linear algebra utility functions, generally performed on matrices or vectors.
SimpleDQN()
Default constructor.
arma::mat & Parameters()
Modify the parameters.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of the Linear layer class.
void Forward(const arma::mat state, arma::mat &target)
Perform the forward pass of the states in real batch mode.
const arma::mat & Parameters() const
Return the parameters.
Implementation of the base layer.
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
Implementation of the NoisyLinear layer class.
SimpleDQN(NetworkType network, const bool isNoisy=false)
SimpleDQN(const int inputDim, const int h1, const int h2, const int outputDim, const bool isNoisy=false)
Construct an instance of the SimpleDQN class.
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
The mean squared error performance function measures the network's performance according to the mean of squared errors.
This class is used to initialize a weight matrix with a Gaussian distribution.