cart_pole.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
16 #define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
17 
18 #include <mlpack/prereqs.hpp>
19 
20 namespace mlpack {
21 namespace rl {
22 
26 class CartPole
27 {
28  public:
33  class State
34  {
35  public:
39  State() : data(dimension)
40  { /* Nothing to do here. */ }
41 
47  State(const arma::colvec& data) : data(data)
48  { /* Nothing to do here */ }
49 
51  arma::colvec& Data() { return data; }
52 
54  double Position() const { return data[0]; }
56  double& Position() { return data[0]; }
57 
59  double Velocity() const { return data[1]; }
61  double& Velocity() { return data[1]; }
62 
64  double Angle() const { return data[2]; }
66  double& Angle() { return data[2]; }
67 
69  double AngularVelocity() const { return data[3]; }
71  double& AngularVelocity() { return data[3]; }
72 
74  const arma::colvec& Encode() const { return data; }
75 
77  static constexpr size_t dimension = 4;
78 
79  private:
81  arma::colvec data;
82  };
83 
87  enum Action
88  {
91 
92  // Track the size of the action space.
94  };
95 
108  CartPole(const double gravity = 9.8,
109  const double massCart = 1.0,
110  const double massPole = 0.1,
111  const double length = 0.5,
112  const double forceMag = 10.0,
113  const double tau = 0.02,
114  const double thetaThresholdRadians = 12 * 2 * 3.1416 / 360,
115  const double xThreshold = 2.4,
116  const double doneReward = 0.0) :
117  gravity(gravity),
118  massCart(massCart),
119  massPole(massPole),
120  totalMass(massCart + massPole),
121  length(length),
122  poleMassLength(massPole * length),
123  forceMag(forceMag),
124  tau(tau),
125  thetaThresholdRadians(thetaThresholdRadians),
126  xThreshold(xThreshold),
127  doneReward(doneReward)
128  { /* Nothing to do here */ }
129 
139  double Sample(const State& state,
140  const Action& action,
141  State& nextState) const
142  {
143  // Calculate acceleration.
144  double force = action ? forceMag : -forceMag;
145  double cosTheta = std::cos(state.Angle());
146  double sinTheta = std::sin(state.Angle());
147  double temp = (force + poleMassLength * state.AngularVelocity() *
148  state.AngularVelocity() * sinTheta) / totalMass;
149  double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
150  (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
151  double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
152 
153  // Update states.
154  nextState.Position() = state.Position() + tau * state.Velocity();
155  nextState.Velocity() = state.Velocity() + tau * xAcc;
156  nextState.Angle() = state.Angle() + tau * state.AngularVelocity();
157  nextState.AngularVelocity() = state.AngularVelocity() + tau * thetaAcc;
158 
163  bool done = IsTerminal(nextState);
164  if (done)
165  return doneReward;
170  return 1.0;
171  }
172 
181  double Sample(const State& state, const Action& action) const
182  {
183  State nextState;
184  return Sample(state, action, nextState);
185  }
186 
193  {
194  return State((arma::randu<arma::colvec>(4) - 0.5) / 10.0);
195  }
196 
203  bool IsTerminal(const State& state) const
204  {
205  return std::abs(state.Position()) > xThreshold ||
206  std::abs(state.Angle()) > thetaThresholdRadians;
207  }
208 
209  private:
211  double gravity;
212 
214  double massCart;
215 
217  double massPole;
218 
220  double totalMass;
221 
223  double length;
224 
226  double poleMassLength;
227 
229  double forceMag;
230 
232  double tau;
233 
235  double thetaThresholdRadians;
236 
238  double xThreshold;
239 
241  double doneReward;
242 };
243 
244 } // namespace rl
245 } // namespace mlpack
246 
247 #endif
double Velocity() const
Get the velocity.
Definition: cart_pole.hpp:59
State(const arma::colvec &data)
Construct a state instance from given data.
Definition: cart_pole.hpp:47
double AngularVelocity() const
Get the angular velocity.
Definition: cart_pole.hpp:69
double & Velocity()
Modify the velocity.
Definition: cart_pole.hpp:61
double Sample(const State &state, const Action &action, State &nextState) const
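Dynamics of the Cart Pole instance: compute the next state reached by taking the given action in the given state, and return the reward (doneReward if the next state is terminal, 1.0 otherwise).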
Definition: cart_pole.hpp:139
.hpp
Definition: add_to_po.hpp:21
double Sample(const State &state, const Action &action) const
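Dynamics of Cart Pole: return the reward for taking the given action in the given state, discarding the computed next state.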
Definition: cart_pole.hpp:181
State()
Construct a state instance.
Definition: cart_pole.hpp:39
The core includes that mlpack expects; standard C++ includes and Armadillo.
Action
Implementation of the actions of Cart Pole.
Definition: cart_pole.hpp:87
State InitialSample() const
Initial state representation is randomly generated within [-0.05, 0.05].
Definition: cart_pole.hpp:192
bool IsTerminal(const State &state) const
Whether given state is a terminal state.
Definition: cart_pole.hpp:203
double Position() const
Get the position.
Definition: cart_pole.hpp:54
State
Implementation of the state of Cart Pole.
Definition: cart_pole.hpp:33
CartPole(const double gravity=9.8, const double massCart=1.0, const double massPole=0.1, const double length=0.5, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=12 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0)
Construct a Cart Pole instance using the given constants.
Definition: cart_pole.hpp:108
double Angle() const
Get the angle.
Definition: cart_pole.hpp:64
double & Angle()
Modify the angle.
Definition: cart_pole.hpp:66
double & Position()
Modify the position.
Definition: cart_pole.hpp:56
double & AngularVelocity()
Modify the angular velocity.
Definition: cart_pole.hpp:71
arma::colvec & Data()
Modify the internal representation of the state.
Definition: cart_pole.hpp:51
static constexpr size_t dimension
Dimension of the encoded state.
Definition: cart_pole.hpp:77
const arma::colvec & Encode() const
Encode the state to a column vector.
Definition: cart_pole.hpp:74
CartPole
Implementation of the Cart Pole task.
Definition: cart_pole.hpp:26