simple_dqn.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RL_SIMPLE_DQN_HPP
13 #define MLPACK_METHODS_RL_SIMPLE_DQN_HPP
14 
15 #include <mlpack/prereqs.hpp>
20 
21 namespace mlpack {
22 namespace rl {
23 
24 using namespace mlpack::ann;
25 
29 template <typename NetworkType = FFN<MeanSquaredError<>,
30  GaussianInitialization>>
31 class SimpleDQN
32 {
33  public:
37  SimpleDQN() : network(), isNoisy(false)
38  { /* Nothing to do here. */ }
39 
49  SimpleDQN(const int inputDim,
50  const int h1,
51  const int h2,
52  const int outputDim,
53  const bool isNoisy = false):
54  network(MeanSquaredError<>(), GaussianInitialization(0, 0.001)),
55  isNoisy(isNoisy)
56  {
57  network.Add(new Linear<>(inputDim, h1));
58  network.Add(new ReLULayer<>());
59  if (isNoisy)
60  {
61  noisyLayerIndex.push_back(network.Model().size());
62  network.Add(new NoisyLinear<>(h1, h2));
63  network.Add(new ReLULayer<>());
64  noisyLayerIndex.push_back(network.Model().size());
65  network.Add(new NoisyLinear<>(h2, outputDim));
66  }
67  else
68  {
69  network.Add(new Linear<>(h1, h2));
70  network.Add(new ReLULayer<>());
71  network.Add(new Linear<>(h2, outputDim));
72  }
73  }
74 
75  SimpleDQN(NetworkType network, const bool isNoisy = false):
76  network(std::move(network)),
77  isNoisy(isNoisy)
78  { /* Nothing to do here. */ }
79 
91  void Predict(const arma::mat state, arma::mat& actionValue)
92  {
93  network.Predict(state, actionValue);
94  }
95 
102  void Forward(const arma::mat state, arma::mat& target)
103  {
104  network.Forward(state, target);
105  }
106 
111  {
112  network.ResetParameters();
113  }
114 
118  void ResetNoise()
119  {
120  for (size_t i = 0; i < noisyLayerIndex.size(); i++)
121  {
122  boost::get<NoisyLinear<>*>
123  (network.Model()[noisyLayerIndex[i]])->ResetNoise();
124  }
125  }
126 
128  const arma::mat& Parameters() const { return network.Parameters(); }
130  arma::mat& Parameters() { return network.Parameters(); }
131 
139  void Backward(const arma::mat state, arma::mat& target, arma::mat& gradient)
140  {
141  network.Backward(state, target, gradient);
142  }
143 
144  private:
146  NetworkType network;
147 
149  bool isNoisy;
150 
152  std::vector<size_t> noisyLayerIndex;
153 };
154 
155 } // namespace rl
156 } // namespace mlpack
157 
158 #endif
Artificial Neural Network.
void ResetParameters()
Resets the parameters of the network.
Definition: simple_dqn.hpp:110
void ResetNoise()
Resets noise of the network, if the network is of type noisy.
Definition: simple_dqn.hpp:118
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: add_to_po.hpp:21
SimpleDQN()
Default constructor.
Definition: simple_dqn.hpp:37
arma::mat & Parameters()
Modify the Parameters.
Definition: simple_dqn.hpp:130
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of the Linear layer class.
Definition: layer_types.hpp:82
void Forward(const arma::mat state, arma::mat &target)
Perform the forward pass of the states in real batch mode.
Definition: simple_dqn.hpp:102
Definition: prereqs.hpp:55
const arma::mat & Parameters() const
Return the Parameters.
Definition: simple_dqn.hpp:128
Implementation of the base layer.
Definition: base_layer.hpp:65
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
Definition: simple_dqn.hpp:91
Implementation of the NoisyLinear layer class.
Definition: layer_types.hpp:96
SimpleDQN(NetworkType network, const bool isNoisy=false)
Definition: simple_dqn.hpp:75
SimpleDQN(const int inputDim, const int h1, const int h2, const int outputDim, const bool isNoisy=false)
Construct an instance of SimpleDQN class.
Definition: simple_dqn.hpp:49
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
Definition: simple_dqn.hpp:139
The mean squared error performance function measures the network's performance according to the mean ...
This class is used to initialize the weight matrix with a Gaussian.