continuous_double_pole_cart.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_CONTINUOUS_DOUBLE_POLE_CART_HPP
15 #define MLPACK_METHODS_RL_ENVIRONMENT_CONTINUOUS_DOUBLE_POLE_CART_HPP
16 
17 #include <mlpack/prereqs.hpp>
18 
19 namespace mlpack {
20 namespace rl {
21 
29 {
30  public:
36  class State
37  {
38  public:
42  State() : data(dimension)
43  { /* Nothing to do here. */ }
44 
50  State(const arma::colvec& data) : data(data)
51  { /* Nothing to do here */ }
52 
54  arma::colvec Data() const { return data; }
56  arma::colvec& Data() { return data; }
57 
59  double Position() const { return data[0]; }
61  double& Position() { return data[0]; }
62 
64  double Velocity() const { return data[1]; }
66  double& Velocity() { return data[1]; }
67 
69  double Angle(const size_t i) const { return data[2 * i]; }
71  double& Angle(const size_t i) { return data[2 * i]; }
72 
74  double AngularVelocity(const size_t i) const { return data[2 * i + 1]; }
76  double& AngularVelocity(const size_t i) { return data[2 * i + 1]; }
77 
79  const arma::colvec& Encode() const { return data; }
80 
82  static constexpr size_t dimension = 6;
83 
84  private:
86  arma::colvec data;
87  };
88 
/**
 * Action of the Continuous Double Pole Cart: one continuous value.
 */
struct Action
{
  // The single action value; Dsdt() reads action[0] as the applied force.
  double action[1];
  // Number of degrees of freedom of the action. Declared static so the value
  // is shared by every instance and Action remains copy-assignable (a
  // non-static const member would implicitly delete copy assignment).
  static constexpr int size = 1;
};
98 
117  ContinuousDoublePoleCart(const double m1 = 0.1,
118  const double m2 = 0.01,
119  const double l1 = 0.5,
120  const double l2 = 0.05,
121  const double gravity = 9.8,
122  const double massCart = 1.0,
123  const double forceMag = 10.0,
124  const double tau = 0.02,
125  const double thetaThresholdRadians = 36 * 2 *
126  3.1416 / 360,
127  const double xThreshold = 2.4,
128  const double doneReward = 0.0,
129  const size_t maxSteps = 0) :
130  m1(m1),
131  m2(m2),
132  l1(l1),
133  l2(l2),
134  gravity(gravity),
135  massCart(massCart),
136  forceMag(forceMag),
137  tau(tau),
138  thetaThresholdRadians(thetaThresholdRadians),
139  xThreshold(xThreshold),
140  doneReward(doneReward),
141  maxSteps(maxSteps),
142  stepsPerformed(0)
143  { /* Nothing to do here */ }
144 
154  double Sample(const State& state,
155  const Action& action,
156  State& nextState)
157  {
158  // Update the number of steps performed.
159  stepsPerformed++;
160 
161  arma::vec dydx(6, arma::fill::zeros);
162  dydx[0] = state.Velocity();
163  dydx[2] = state.AngularVelocity(1);
164  dydx[4] = state.AngularVelocity(2);
165  Dsdt(state, action, dydx);
166  RK4(state, action, dydx, nextState);
167 
168  // Check if the episode has terminated.
169  bool done = IsTerminal(nextState);
170 
171  // Do not reward agent if it failed.
172  if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
173  return doneReward;
174  else if (done)
175  return 0;
176 
181  return 1.0;
182  }
183 
192  void Dsdt(const State& state,
193  const Action& action,
194  arma::vec& dydx)
195  {
196  double totalForce = action.action[0];
197  double totalMass = massCart;
198  double omega1 = state.AngularVelocity(1);
199  double omega2 = state.AngularVelocity(2);
200  double sinTheta1 = std::sin(state.Angle(1));
201  double sinTheta2 = std::sin(state.Angle(2));
202  double cosTheta1 = std::cos(state.Angle(1));
203  double cosTheta2 = std::cos(state.Angle(2));
204 
205  // Calculate total effective force.
206  totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
207  std::sin(2 * state.Angle(1));
208  totalForce += m2 * l2 * omega2 * omega2 * sinTheta1 + 0.375 * m2 * gravity *
209  std::sin(2 * state.Angle(2));
210 
211  // Calculate total effective mass.
212  totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
213  totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
214 
215  // Calculate acceleration.
216  double xAcc = totalForce / totalMass;
217  dydx[1] = xAcc;
218 
219  // Calculate angular acceleration.
220  dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
221  dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
222  }
223 
233  void RK4(const State& state,
234  const Action& action,
235  arma::vec& dydx,
236  State& nextState)
237  {
238  const double hh = tau * 0.5;
239  const double h6 = tau / 6;
240  arma::vec yt(6);
241  arma::vec dyt(6);
242  arma::vec dym(6);
243 
244  yt = state.Data() + (hh * dydx);
245  Dsdt(State(yt), action, dyt);
246  dyt[0] = yt[1];
247  dyt[2] = yt[3];
248  dyt[4] = yt[5];
249  yt = state.Data() + (hh * dyt);
250 
251  Dsdt(State(yt), action, dym);
252  dym[0] = yt[1];
253  dym[2] = yt[3];
254  dym[4] = yt[5];
255  yt = state.Data() + (tau * dym);
256  dym += dyt;
257 
258  Dsdt(State(yt), action, dyt);
259  dyt[0] = yt[1];
260  dyt[2] = yt[3];
261  dyt[4] = yt[5];
262  nextState.Data() = state.Data() + h6 * (dydx + dyt + 2 * dym);
263  }
264 
273  double Sample(const State& state, const Action& action)
274  {
275  State nextState;
276  return Sample(state, action, nextState);
277  }
278 
285  {
286  stepsPerformed = 0;
287  return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
288  }
289 
296  bool IsTerminal(const State& state) const
297  {
298  if (maxSteps != 0 && stepsPerformed >= maxSteps)
299  {
300  Log::Info << "Episode terminated due to the maximum number of steps"
301  "being taken.";
302  return true;
303  }
304  if (std::abs(state.Position()) > xThreshold)
305  {
306  Log::Info << "Episode terminated due to cart crossing threshold";
307  return true;
308  }
309  if (std::abs(state.Angle(1)) > thetaThresholdRadians ||
310  std::abs(state.Angle(2)) > thetaThresholdRadians)
311  {
312  Log::Info << "Episode terminated due to pole falling";
313  return true;
314  }
315  return false;
316  }
317 
319  size_t StepsPerformed() const { return stepsPerformed; }
320 
322  size_t MaxSteps() const { return maxSteps; }
324  size_t& MaxSteps() { return maxSteps; }
325 
326  private:
// Mass constant of the first pole; scales the first pole's force and mass
// terms in Dsdt().
328  double m1;
329 
// Mass constant of the second pole; scales the second pole's force and mass
// terms in Dsdt().
331  double m2;
332 
// Length constant of the first pole; used in the first pole's force term and
// as the divisor of its angular acceleration in Dsdt().
334  double l1;
335 
// Length constant of the second pole; used in the second pole's force term
// and as the divisor of its angular acceleration in Dsdt().
337  double l2;
338 
// Gravitational acceleration used in Dsdt().
340  double gravity;
341 
// Mass of the cart; the starting value of the total mass in Dsdt().
343  double massCart;
344 
// Set by the constructor; not read by any method shown in this listing.
346  double forceMag;
347 
// Integration step size used by RK4().
349  double tau;
350 
// Largest absolute pole angle allowed before IsTerminal() returns true.
352  double thetaThresholdRadians;
353 
// Largest absolute cart position allowed before IsTerminal() returns true.
355  double xThreshold;
356 
// Reward returned by Sample() when the episode ends with the step limit
// reached.
358  double doneReward;
359 
// Step limit per episode; 0 disables the limit.
361  size_t maxSteps;
362 
// Steps taken so far; incremented by Sample(), reset when a new initial
// state is sampled.
364  size_t stepsPerformed;
365 };
366 
367 } // namespace rl
368 } // namespace mlpack
369 
370 #endif
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
size_t & MaxSteps()
Set the maximum number of steps allowed.
State(const arma::colvec &data)
Construct a state instance from given data.
size_t MaxSteps() const
Get the maximum number of steps allowed.
double & Velocity()
Modify the velocity of the cart.
strip_type.hpp
Definition: add_to_po.hpp:21
Implementation of action of Continuous Double Pole Cart.
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t StepsPerformed() const
Get the number of steps performed.
const arma::colvec & Encode() const
Encode the state to a vector.
Implementation of the state of Continuous Double Pole Cart.
static constexpr size_t dimension
Dimension of the encoded state.
double & Position()
Modify the position of the cart.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function calls the RK4 iterative method to estimate the next state based on the given ordinary differential equations.
double Position() const
Get the position of the cart.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
These are the ordinary differential equations required to estimate the next state through the RK4 method.
double Velocity() const
Get the velocity of the cart.
arma::colvec Data() const
Get the internal representation of the state.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Continuous Double Pole Cart instance.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
ContinuousDoublePoleCart(const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0, const size_t maxSteps=0)
Construct a Double Pole Cart instance using the given constants.
arma::colvec & Data()
Modify the internal representation of the state.
Implementation of Continuous Double Pole Cart Balancing task.
double Sample(const State &state, const Action &action)
Dynamics of Continuous Double Pole Cart.