cart_pole.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
16 #define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
17 
18 #include <mlpack/prereqs.hpp>
19 
20 namespace mlpack {
21 namespace rl {
22 
26 class CartPole
27 {
28  public:
33  class State
34  {
35  public:
39  State() : data(dimension)
40  { /* Nothing to do here. */ }
41 
47  State(const arma::colvec& data) : data(data)
48  { /* Nothing to do here */ }
49 
51  arma::colvec& Data() { return data; }
52 
54  double Position() const { return data[0]; }
56  double& Position() { return data[0]; }
57 
59  double Velocity() const { return data[1]; }
61  double& Velocity() { return data[1]; }
62 
64  double Angle() const { return data[2]; }
66  double& Angle() { return data[2]; }
67 
69  double AngularVelocity() const { return data[3]; }
71  double& AngularVelocity() { return data[3]; }
72 
74  const arma::colvec& Encode() const { return data; }
75 
77  static constexpr size_t dimension = 4;
78 
79  private:
81  arma::colvec data;
82  };
83 
87  enum Action
88  {
91 
92  // Track the size of the action space.
94  };
95 
111  CartPole(const double gravity = 9.8,
112  const double massCart = 1.0,
113  const double massPole = 0.1,
114  const double length = 0.5,
115  const double forceMag = 10.0,
116  const double tau = 0.02,
117  const double thetaThresholdRadians = 12 * 2 * 3.1416 / 360,
118  const double xThreshold = 2.4,
119  const double doneReward = 0.0,
120  const size_t maxSteps = 0) :
121  gravity(gravity),
122  massCart(massCart),
123  massPole(massPole),
124  totalMass(massCart + massPole),
125  length(length),
126  poleMassLength(massPole * length),
127  forceMag(forceMag),
128  tau(tau),
129  thetaThresholdRadians(thetaThresholdRadians),
130  xThreshold(xThreshold),
131  doneReward(doneReward),
132  maxSteps(maxSteps),
133  stepsPerformed(0)
134  { /* Nothing to do here */ }
135 
145  double Sample(const State& state,
146  const Action& action,
147  State& nextState)
148  {
149  // Update the number of steps performed.
150  stepsPerformed++;
151 
152  // Calculate acceleration.
153  double force = action ? forceMag : -forceMag;
154  double cosTheta = std::cos(state.Angle());
155  double sinTheta = std::sin(state.Angle());
156  double temp = (force + poleMassLength * state.AngularVelocity() *
157  state.AngularVelocity() * sinTheta) / totalMass;
158  double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
159  (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
160  double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
161 
162  // Update states.
163  nextState.Position() = state.Position() + tau * state.Velocity();
164  nextState.Velocity() = state.Velocity() + tau * xAcc;
165  nextState.Angle() = state.Angle() + tau * state.AngularVelocity();
166  nextState.AngularVelocity() = state.AngularVelocity() + tau * thetaAcc;
167 
168  // Check if the episode has terminated.
169  bool done = IsTerminal(nextState);
170 
171  // Do not reward agent if it failed.
172  if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
173  return doneReward;
174  else if (done)
175  return 0;
176 
181  return 1.0;
182  }
183 
192  double Sample(const State& state, const Action& action)
193  {
194  State nextState;
195  return Sample(state, action, nextState);
196  }
197 
204  {
205  stepsPerformed = 0;
206  return State((arma::randu<arma::colvec>(4) - 0.5) / 10.0);
207  }
208 
215  bool IsTerminal(const State& state) const
216  {
217  if (maxSteps != 0 && stepsPerformed >= maxSteps)
218  {
219  Log::Info << "Episode terminated due to the maximum number of steps"
220  "being taken.";
221  return true;
222  }
223  else if (std::abs(state.Position()) > xThreshold ||
224  std::abs(state.Angle()) > thetaThresholdRadians)
225  {
226  Log::Info << "Episode terminated due to agent failing.";
227  return true;
228  }
229  return false;
230  }
231 
233  size_t StepsPerformed() const { return stepsPerformed; }
234 
236  size_t MaxSteps() const { return maxSteps; }
238  size_t& MaxSteps() { return maxSteps; }
239 
240  private:
242  double gravity;
243 
245  double massCart;
246 
248  double massPole;
249 
251  double totalMass;
252 
254  double length;
255 
257  double poleMassLength;
258 
260  double forceMag;
261 
263  double tau;
264 
266  double thetaThresholdRadians;
267 
269  double xThreshold;
270 
272  double doneReward;
273 
275  size_t maxSteps;
276 
278  size_t stepsPerformed;
279 };
280 
281 } // namespace rl
282 } // namespace mlpack
283 
284 #endif
double Velocity() const
Get the velocity.
Definition: cart_pole.hpp:59
State(const arma::colvec &data)
Construct a state instance from given data.
Definition: cart_pole.hpp:47
double AngularVelocity() const
Get the angular velocity.
Definition: cart_pole.hpp:69
double & Velocity()
Modify the velocity.
Definition: cart_pole.hpp:61
strip_type.hpp
Definition: add_to_po.hpp:21
double Sample(const State &state, const Action &action)
Dynamics of Cart Pole.
Definition: cart_pole.hpp:192
State()
Construct a state instance.
Definition: cart_pole.hpp:39
The core includes that mlpack expects; standard C++ includes and Armadillo.
Action
Implementation of action of Cart Pole.
Definition: cart_pole.hpp:87
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Cart Pole instance.
Definition: cart_pole.hpp:145
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
Definition: cart_pole.hpp:215
double Position() const
Get the position.
Definition: cart_pole.hpp:54
Implementation of the state of Cart Pole.
Definition: cart_pole.hpp:33
size_t & MaxSteps()
Set the maximum number of steps allowed.
Definition: cart_pole.hpp:238
CartPole(const double gravity=9.8, const double massCart=1.0, const double massPole=0.1, const double length=0.5, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=12 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0, const size_t maxSteps=0)
Construct a Cart Pole instance using the given constants.
Definition: cart_pole.hpp:111
double Angle() const
Get the angle.
Definition: cart_pole.hpp:64
double & Angle()
Modify the angle.
Definition: cart_pole.hpp:66
double & Position()
Modify the position.
Definition: cart_pole.hpp:56
double & AngularVelocity()
Modify the angular velocity.
Definition: cart_pole.hpp:71
arma::colvec & Data()
Modify the internal representation of the state.
Definition: cart_pole.hpp:51
static constexpr size_t dimension
Dimension of the encoded state.
Definition: cart_pole.hpp:77
const arma::colvec & Encode() const
Encode the state to a column vector.
Definition: cart_pole.hpp:74
Implementation of Cart Pole task.
Definition: cart_pole.hpp:26
size_t StepsPerformed() const
Get the number of steps performed.
Definition: cart_pole.hpp:233
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
Definition: cart_pole.hpp:203
size_t MaxSteps() const
Get the maximum number of steps allowed.
Definition: cart_pole.hpp:236