15 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP 16 #define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP 47 State(
const arma::colvec& data) : data(data)
//! Modify the internal representation of the state. Returns a mutable
//! reference to the stored column vector, so writes through it change the
//! state directly.
51 arma::colvec&
Data() {
return data; }
//! Get the angle (read-only). The angle is the element at index 2 of the
//! state vector; it is returned by value.
64 double Angle()
const {
return data[2]; }
//! Modify the angle. Returns a mutable reference to the element at index 2
//! of the state vector, so assigning through it updates the state in place.
66 double&
Angle() {
return data[2]; }
//! Encode the state to a column vector. The stored vector is returned as-is
//! by const reference; no copy or transformation is performed here.
74 const arma::colvec&
Encode()
const {
return data; }
109 const double massCart = 1.0,
110 const double massPole = 0.1,
111 const double length = 0.5,
112 const double forceMag = 10.0,
113 const double tau = 0.02,
114 const double thetaThresholdRadians = 12 * 2 * 3.1416 / 360,
115 const double xThreshold = 2.4,
116 const double doneReward = 0.0) :
120 totalMass(massCart + massPole),
122 poleMassLength(massPole * length),
125 thetaThresholdRadians(thetaThresholdRadians),
126 xThreshold(xThreshold),
127 doneReward(doneReward)
141 State& nextState)
const 144 double force = action ? forceMag : -forceMag;
145 double cosTheta = std::cos(state.
Angle());
146 double sinTheta = std::sin(state.
Angle());
149 double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
150 (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
151 double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
184 return Sample(state, action, nextState);
194 return State((arma::randu<arma::colvec>(4) - 0.5) / 10.0);
205 return std::abs(state.
Position()) > xThreshold ||
206 std::abs(state.
Angle()) > thetaThresholdRadians;
226 double poleMassLength;
235 double thetaThresholdRadians;
double Velocity() const
Get the velocity.
State(const arma::colvec &data)
Construct a state instance from given data.
double AngularVelocity() const
Get the angular velocity.
double & Velocity()
Modify the velocity.
double Sample(const State &state, const Action &action, State &nextState) const
Dynamics of Cart Pole instance.
double Sample(const State &state, const Action &action) const
Dynamics of Cart Pole.
State()
Construct a state instance.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Action
Implementation of action of Cart Pole.
State InitialSample() const
The initial state is randomly generated: each of its four components is drawn uniformly from [-0.05, 0.05].
bool IsTerminal(const State &state) const
Whether given state is a terminal state.
double Position() const
Get the position.
Implementation of the state of Cart Pole.
CartPole(const double gravity=9.8, const double massCart=1.0, const double massPole=0.1, const double length=0.5, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=12 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0)
Construct a Cart Pole instance using the given constants.
double Angle() const
Get the angle.
double & Angle()
Modify the angle.
double & Position()
Modify the position.
double & AngularVelocity()
Modify the angular velocity.
arma::colvec & Data()
Modify the internal representation of the state.
static constexpr size_t dimension
Dimension of the encoded state.
const arma::colvec & Encode() const
Encode the state to a column vector.
Implementation of Cart Pole task.