15 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP 16 #define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP 47 State(
const arma::colvec& data) : data(data)
51 arma::colvec&
Data() {
return data; }
64 double Angle()
const {
return data[2]; }
66 double&
Angle() {
return data[2]; }
74 const arma::colvec&
Encode()
const {
return data; }
112 const double massCart = 1.0,
113 const double massPole = 0.1,
114 const double length = 0.5,
115 const double forceMag = 10.0,
116 const double tau = 0.02,
117 const double thetaThresholdRadians = 12 * 2 * 3.1416 / 360,
118 const double xThreshold = 2.4,
119 const double doneReward = 0.0,
120 const size_t maxSteps = 0) :
124 totalMass(massCart + massPole),
126 poleMassLength(massPole * length),
129 thetaThresholdRadians(thetaThresholdRadians),
130 xThreshold(xThreshold),
131 doneReward(doneReward),
153 double force = action ? forceMag : -forceMag;
154 double cosTheta = std::cos(state.
Angle());
155 double sinTheta = std::sin(state.
Angle());
158 double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
159 (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
160 double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
172 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
195 return Sample(state, action, nextState);
206 return State((arma::randu<arma::colvec>(4) - 0.5) / 10.0);
217 if (maxSteps != 0 && stepsPerformed >= maxSteps)
219 Log::Info <<
"Episode terminated due to the maximum number of steps" 223 else if (std::abs(state.
Position()) > xThreshold ||
224 std::abs(state.
Angle()) > thetaThresholdRadians)
226 Log::Info <<
"Episode terminated due to agent failing.";
257 double poleMassLength;
266 double thetaThresholdRadians;
278 size_t stepsPerformed;
double Velocity() const
Get the velocity.
State(const arma::colvec &data)
Construct a state instance from given data.
double AngularVelocity() const
Get the angular velocity.
double & Velocity()
Modify the velocity.
double Sample(const State &state, const Action &action)
Dynamics of Cart Pole.
State()
Construct a state instance.
The core includes that mlpack expects: standard C++ includes and Armadillo.
Action
Implementation of the action of Cart Pole.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of the Cart Pole instance.
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
double Position() const
Get the position.
Implementation of the state of Cart Pole.
size_t & MaxSteps()
Set the maximum number of steps allowed.
CartPole(const double gravity=9.8, const double massCart=1.0, const double massPole=0.1, const double length=0.5, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=12 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0, const size_t maxSteps=0)
Construct a Cart Pole instance using the given constants.
double Angle() const
Get the angle.
double & Angle()
Modify the angle.
double & Position()
Modify the position.
double & AngularVelocity()
Modify the angular velocity.
arma::colvec & Data()
Modify the internal representation of the state.
static constexpr size_t dimension
Dimension of the encoded state.
const arma::colvec & Encode() const
Encode the state to a column vector.
Implementation of Cart Pole task.
size_t StepsPerformed() const
Get the number of steps performed.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
size_t MaxSteps() const
Get the maximum number of steps allowed.