mountain_car.hpp
Go to the documentation of this file.
1 
16 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP
17 #define MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP
18 
19 #include <mlpack/prereqs.hpp>
20 
21 namespace mlpack {
22 namespace rl {
23 
28 {
29  public:
34  class State
35  {
36  public:
40  State(): data(dimension, arma::fill::zeros)
41  { /* Nothing to do here. */ }
42 
48  State(const arma::colvec& data): data(data)
49  { /* Nothing to do here. */ }
50 
52  arma::colvec& Data() { return data; }
53 
55  double Velocity() const { return data[0]; }
57  double& Velocity() { return data[0]; }
58 
60  double Position() const { return data[1]; }
62  double& Position() { return data[1]; }
63 
65  const arma::colvec& Encode() const { return data; }
66 
68  static constexpr size_t dimension = 2;
69 
70  private:
72  arma::colvec data;
73  };
74 
78  enum Action
79  {
83 
86  };
87 
96  MountainCar(const double positionMin = -1.2,
97  const double positionMax = 0.5,
98  const double velocityMin = -0.07,
99  const double velocityMax = 0.07,
100  const double doneReward = 0) :
101  positionMin(positionMin),
102  positionMax(positionMax),
103  velocityMin(velocityMin),
104  velocityMax(velocityMax),
105  doneReward(doneReward)
106  { /* Nothing to do here */ }
107 
117  double Sample(const State& state,
118  const Action& action,
119  State& nextState) const
120  {
121  // Calculate acceleration.
122  int direction = action - 1;
123  nextState.Velocity() = state.Velocity() + 0.001 * direction - 0.0025 *
124  std::cos(3 * state.Position());
125  nextState.Velocity() = std::min(
126  std::max(nextState.Velocity(), velocityMin), velocityMax);
127 
128  // Update states.
129  nextState.Position() = state.Position() + nextState.Velocity();
130  nextState.Position() = std::min(
131  std::max(nextState.Position(), positionMin), positionMax);
132 
133  if (std::abs(nextState.Position() - positionMin) <= 1e-5)
134  {
135  nextState.Velocity() = 0.0;
136  }
137  bool done = IsTerminal(nextState);
145  if (done)
146  return doneReward;
147  return -1.0;
148  }
149 
158  double Sample(const State& state, const Action& action) const
159  {
160  State nextState;
161  return Sample(state, action, nextState);
162  }
163 
171  {
172  State state;
173  state.Velocity() = 0.0;
174  state.Position() = arma::as_scalar(arma::randu(1)) * 0.2 - 0.6;
175  return state;
176  }
177 
184  bool IsTerminal(const State& state) const
185  {
186  return std::abs(state.Position() - positionMax) <= 1e-5;
187  }
188 
189  private:
191  double positionMin;
192 
194  double positionMax;
195 
197  double velocityMin;
198 
200  double velocityMax;
201 
203  double doneReward;
204 };
205 
206 } // namespace rl
207 } // namespace mlpack
208 
209 #endif
double Sample(const State &state, const Action &action, State &nextState) const
Dynamics of Mountain Car.
bool IsTerminal(const State &state) const
Whether the given state is a terminal state.
.hpp
Definition: add_to_po.hpp:21
MountainCar(const double positionMin=-1.2, const double positionMax=0.5, const double velocityMin=-0.07, const double velocityMax=0.07, const double doneReward=0)
Construct a Mountain Car instance using the given constants.
The core includes that mlpack expects: standard C++ includes and Armadillo.
const arma::colvec & Encode() const
Encode the state to a column vector.
double & Position()
Modify the position.
double & Velocity()
Modify the velocity.
State InitialSample() const
Initial position is randomly generated within [-0.6, -0.4].
State(const arma::colvec &data)
Construct a state based on the given data.
Track the size of the action space.
Action
Implementation of the Mountain Car action.
arma::colvec & Data()
Modify the internal representation of the state.
Implementation of the Mountain Car state.
State()
Construct a state instance.
Implementation of Mountain Car task.
double Sample(const State &state, const Action &action) const
Dynamics of Mountain Car.
double Position() const
Get the position.
static constexpr size_t dimension
Dimension of the encoded state.
double Velocity() const
Get the velocity.