cart_pole.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
16 #define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
17 
18 #include <mlpack/prereqs.hpp>
19 
20 namespace mlpack {
21 namespace rl {
22 
26 class CartPole
27 {
28  public:
33  class State
34  {
35  public:
39  State() : data(dimension)
40  { /* Nothing to do here. */ }
41 
47  State(const arma::colvec& data) : data(data)
48  { /* Nothing to do here */ }
49 
51  arma::colvec& Data() { return data; }
52 
54  double Position() const { return data[0]; }
56  double& Position() { return data[0]; }
57 
59  double Velocity() const { return data[1]; }
61  double& Velocity() { return data[1]; }
62 
64  double Angle() const { return data[2]; }
66  double& Angle() { return data[2]; }
67 
69  double AngularVelocity() const { return data[3]; }
71  double& AngularVelocity() { return data[3]; }
72 
74  const arma::colvec& Encode() const { return data; }
75 
77  static constexpr size_t dimension = 4;
78 
79  private:
81  arma::colvec data;
82  };
83 
87  enum Action
88  {
91 
92  // Track the size of the action space.
94  };
95 
108  CartPole(const double gravity = 9.8,
109  const double massCart = 1.0,
110  const double massPole = 0.1,
111  const double length = 0.5,
112  const double forceMag = 10.0,
113  const double tau = 0.02,
114  const double thetaThresholdRadians = 12 * 2 * 3.1416 / 360,
115  const double xThreshold = 2.4) :
116  gravity(gravity),
117  massCart(massCart),
118  massPole(massPole),
119  totalMass(massCart + massPole),
120  length(length),
121  poleMassLength(massPole * length),
122  forceMag(forceMag),
123  tau(tau),
124  thetaThresholdRadians(thetaThresholdRadians),
125  xThreshold(xThreshold)
126  { /* Nothing to do here */ }
127 
137  double Sample(const State& state,
138  const Action& action,
139  State& nextState) const
140  {
141  // Calculate acceleration.
142  double force = action ? forceMag : -forceMag;
143  double cosTheta = std::cos(state.Angle());
144  double sinTheta = std::sin(state.Angle());
145  double temp = (force + poleMassLength * state.AngularVelocity() *
146  state.AngularVelocity() * sinTheta) / totalMass;
147  double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
148  (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
149  double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
150 
151  // Update states.
152  nextState.Position() = state.Position() + tau * state.Velocity();
153  nextState.Velocity() = state.Velocity() + tau * xAcc;
154  nextState.Angle() = state.Angle() + tau * state.AngularVelocity();
155  nextState.AngularVelocity() = state.AngularVelocity() + tau * thetaAcc;
156 
157  return 1.0;
158  }
159 
168  double Sample(const State& state, const Action& action) const
169  {
170  State nextState;
171  return Sample(state, action, nextState);
172  }
173 
180  {
181  return State((arma::randu<arma::colvec>(4) - 0.5) / 10.0);
182  }
183 
190  bool IsTerminal(const State& state) const
191  {
192  return std::abs(state.Position()) > xThreshold ||
193  std::abs(state.Angle()) > thetaThresholdRadians;
194  }
195 
196  private:
198  double gravity;
199 
201  double massCart;
202 
204  double massPole;
205 
207  double totalMass;
208 
210  double length;
211 
213  double poleMassLength;
214 
216  double forceMag;
217 
219  double tau;
220 
222  double thetaThresholdRadians;
223 
225  double xThreshold;
226 };
227 
228 } // namespace rl
229 } // namespace mlpack
230 
231 #endif
double Velocity() const
Get the velocity.
Definition: cart_pole.hpp:59
State(const arma::colvec &data)
Construct a state instance from given data.
Definition: cart_pole.hpp:47
double AngularVelocity() const
Get the angular velocity.
Definition: cart_pole.hpp:69
double & Velocity()
Modify the velocity.
Definition: cart_pole.hpp:61
double Sample(const State &state, const Action &action, State &nextState) const
Dynamics of Cart Pole instance: writes the state after one Euler step into nextState and returns the reward (always 1.0).
Definition: cart_pole.hpp:137
.hpp
Definition: add_to_po.hpp:21
double Sample(const State &state, const Action &action) const
Dynamics of Cart Pole: returns the reward for taking action in state, discarding the resulting next state.
Definition: cart_pole.hpp:168
State()
Construct a state instance.
Definition: cart_pole.hpp:39
The core includes that mlpack expects; standard C++ includes and Armadillo.
Action
Implementation of action of Cart Pole.
Definition: cart_pole.hpp:87
State InitialSample() const
Initial state representation is randomly generated within [-0.05, 0.05].
Definition: cart_pole.hpp:179
bool IsTerminal(const State &state) const
Whether given state is a terminal state.
Definition: cart_pole.hpp:190
double Position() const
Get the position.
Definition: cart_pole.hpp:54
Implementation of the state of Cart Pole.
Definition: cart_pole.hpp:33
CartPole(const double gravity=9.8, const double massCart=1.0, const double massPole=0.1, const double length=0.5, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=12 *2 *3.1416/360, const double xThreshold=2.4)
Construct a Cart Pole instance using the given constants.
Definition: cart_pole.hpp:108
double Angle() const
Get the angle.
Definition: cart_pole.hpp:64
double & Angle()
Modify the angle.
Definition: cart_pole.hpp:66
double & Position()
Modify the position.
Definition: cart_pole.hpp:56
double & AngularVelocity()
Modify the angular velocity.
Definition: cart_pole.hpp:71
arma::colvec & Data()
Modify the internal representation of the state.
Definition: cart_pole.hpp:51
static constexpr size_t dimension
Dimension of the encoded state.
Definition: cart_pole.hpp:77
const arma::colvec & Encode() const
Encode the state to a column vector.
Definition: cart_pole.hpp:74
Implementation of Cart Pole task.
Definition: cart_pole.hpp:26