double_pole_cart.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP
14 #define MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace rl {
20 
28 {
29  public:
35  class State
36  {
37  public:
41  State() : data(dimension)
42  { /* Nothing to do here. */ }
43 
49  State(const arma::colvec& data) : data(data)
50  { /* Nothing to do here */ }
51 
53  arma::colvec Data() const { return data; }
55  arma::colvec& Data() { return data; }
56 
58  double Position() const { return data[0]; }
60  double& Position() { return data[0]; }
61 
63  double Velocity() const { return data[1]; }
65  double& Velocity() { return data[1]; }
66 
68  double Angle(const size_t i) const { return data[2 * i]; }
70  double& Angle(const size_t i) { return data[2 * i]; }
71 
73  double AngularVelocity(const size_t i) const { return data[2 * i + 1]; }
75  double& AngularVelocity(const size_t i) { return data[2 * i + 1]; }
76 
78  const arma::colvec& Encode() const { return data; }
79 
81  static constexpr size_t dimension = 6;
82 
83  private:
85  arma::colvec data;
86  };
87 
91  enum Action
92  {
95 
96  // Track the size of the action space.
98  };
99 
117  DoublePoleCart(const double m1 = 0.1,
118  const double m2 = 0.01,
119  const double l1 = 0.5,
120  const double l2 = 0.05,
121  const double gravity = 9.8,
122  const double massCart = 1.0,
123  const double forceMag = 10.0,
124  const double tau = 0.02,
125  const double thetaThresholdRadians = 36 * 2 * 3.1416 / 360,
126  const double xThreshold = 2.4,
127  const double doneReward = 0.0,
128  const size_t maxSteps = 0) :
129  m1(m1),
130  m2(m2),
131  l1(l1),
132  l2(l2),
133  gravity(gravity),
134  massCart(massCart),
135  forceMag(forceMag),
136  tau(tau),
137  thetaThresholdRadians(thetaThresholdRadians),
138  xThreshold(xThreshold),
139  doneReward(doneReward),
140  maxSteps(maxSteps),
141  stepsPerformed(0)
142  { /* Nothing to do here */ }
143 
153  double Sample(const State& state,
154  const Action& action,
155  State& nextState)
156  {
157  // Update the number of steps performed.
158  stepsPerformed++;
159 
160  arma::vec dydx(6, arma::fill::zeros);
161  dydx[0] = state.Velocity();
162  dydx[2] = state.AngularVelocity(1);
163  dydx[4] = state.AngularVelocity(2);
164  Dsdt(state, action, dydx);
165  RK4(state, action, dydx, nextState);
166 
167  // Check if the episode has terminated.
168  bool done = IsTerminal(nextState);
169 
170  // Do not reward agent if it failed.
171  if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
172  return doneReward;
173  else if (done)
174  return 0;
175 
180  return 1.0;
181  }
182 
191  void Dsdt(const State& state,
192  const Action& action,
193  arma::vec& dydx)
194  {
195  double totalForce = action ? forceMag : -forceMag;
196  double totalMass = massCart;
197  double omega1 = state.AngularVelocity(1);
198  double omega2 = state.AngularVelocity(2);
199  double sinTheta1 = std::sin(state.Angle(1));
200  double sinTheta2 = std::sin(state.Angle(2));
201  double cosTheta1 = std::cos(state.Angle(1));
202  double cosTheta2 = std::cos(state.Angle(2));
203 
204  // Calculate total effective force.
205  totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
206  std::sin(2 * state.Angle(1));
207  totalForce += m2 * l2 * omega2 * omega2 * sinTheta1 + 0.375 * m2 * gravity *
208  std::sin(2 * state.Angle(2));
209 
210  // Calculate total effective mass.
211  totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
212  totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
213 
214  // Calculate acceleration.
215  double xAcc = totalForce / totalMass;
216  dydx[1] = xAcc;
217 
218  // Calculate angular acceleration.
219  dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
220  dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
221  }
222 
232  void RK4(const State& state,
233  const Action& action,
234  arma::vec& dydx,
235  State& nextState)
236  {
237  const double hh = tau * 0.5;
238  const double h6 = tau / 6;
239  arma::vec yt(6);
240  arma::vec dyt(6);
241  arma::vec dym(6);
242 
243  yt = state.Data() + (hh * dydx);
244  Dsdt(State(yt), action, dyt);
245  dyt[0] = yt[1];
246  dyt[2] = yt[3];
247  dyt[4] = yt[5];
248  yt = state.Data() + (hh * dyt);
249 
250  Dsdt(State(yt), action, dym);
251  dym[0] = yt[1];
252  dym[2] = yt[3];
253  dym[4] = yt[5];
254  yt = state.Data() + (tau * dym);
255  dym += dyt;
256 
257  Dsdt(State(yt), action, dyt);
258  dyt[0] = yt[1];
259  dyt[2] = yt[3];
260  dyt[4] = yt[5];
261  nextState.Data() = state.Data() + h6 * (dydx + dyt + 2 * dym);
262  }
263 
272  double Sample(const State& state, const Action& action)
273  {
274  State nextState;
275  return Sample(state, action, nextState);
276  }
277 
284  {
285  stepsPerformed = 0;
286  return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
287  }
288 
295  bool IsTerminal(const State& state) const
296  {
297  if (maxSteps != 0 && stepsPerformed >= maxSteps)
298  {
299  Log::Info << "Episode terminated due to the maximum number of steps"
300  "being taken.";
301  return true;
302  }
303  if (std::abs(state.Position()) > xThreshold)
304  {
305  Log::Info << "Episode terminated due to cart crossing threshold";
306  return true;
307  }
308  if (std::abs(state.Angle(1)) > thetaThresholdRadians ||
309  std::abs(state.Angle(2)) > thetaThresholdRadians)
310  {
311  Log::Info << "Episode terminated due to pole falling";
312  return true;
313  }
314  return false;
315  }
316 
318  size_t StepsPerformed() const { return stepsPerformed; }
319 
321  size_t MaxSteps() const { return maxSteps; }
323  size_t& MaxSteps() { return maxSteps; }
324 
325  private:
327  double m1;
328 
330  double m2;
331 
333  double l1;
334 
336  double l2;
337 
339  double gravity;
340 
342  double massCart;
343 
345  double forceMag;
346 
348  double tau;
349 
351  double thetaThresholdRadians;
352 
354  double xThreshold;
355 
357  double doneReward;
358 
360  size_t maxSteps;
361 
363  size_t stepsPerformed;
364 };
365 
366 } // namespace rl
367 } // namespace mlpack
368 
369 #endif
double Sample(const State &state, const Action &action)
Dynamics of Double Pole Cart.
State(const arma::colvec &data)
Construct a state instance from given data.
size_t MaxSteps() const
Get the maximum number of steps allowed.
double & Velocity()
Modify the velocity of the cart.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: add_to_po.hpp:21
Implementation of Double Pole Cart Balancing task.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
These are the ordinary differential equations required to estimate the next state through the RK4 method...
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Position() const
Get the position of the cart.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
size_t StepsPerformed() const
Get the number of steps performed.
size_t & MaxSteps()
Set the maximum number of steps allowed.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function calls the RK4 iterative method to estimate the next state based on given ordinary diffe...
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Double Pole Cart instance.
arma::colvec & Data()
Modify the internal representation of the state.
const arma::colvec & Encode() const
Encode the state to a vector.
Implementation of the state of Double Pole Cart.
State()
Construct a state instance.
DoublePoleCart(const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0, const size_t maxSteps=0)
Construct a Double Pole Cart instance using the given constants.
double Velocity() const
Get the velocity of the cart.
static constexpr size_t dimension
Dimension of the encoded state.
double & Position()
Modify the position of the cart.
Action
Implementation of action of Double Pole Cart.
arma::colvec Data() const
Get the internal representation of the state.