mountain_car.hpp
Go to the documentation of this file.
1 
16 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP
17 #define MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP
18 
19 #include <mlpack/prereqs.hpp>
20 
21 namespace mlpack {
22 namespace rl {
23 
28 {
29  public:
34  class State
35  {
36  public:
40  State(): data(dimension, arma::fill::zeros)
41  { /* Nothing to do here. */ }
42 
48  State(const arma::colvec& data): data(data)
49  { /* Nothing to do here. */ }
50 
52  arma::colvec& Data() { return data; }
53 
55  double Velocity() const { return data[0]; }
57  double& Velocity() { return data[0]; }
58 
60  double Position() const { return data[1]; }
62  double& Position() { return data[1]; }
63 
65  const arma::colvec& Encode() const { return data; }
66 
68  static constexpr size_t dimension = 2;
69 
70  private:
72  arma::colvec data;
73  };
74 
78  enum Action
79  {
83 
86  };
87 
97  MountainCar(const double positionMin = -1.2,
98  const double positionMax = 0.6,
99  const double positionGoal = 0.5,
100  const double velocityMin = -0.07,
101  const double velocityMax = 0.07,
102  const double doneReward = 0) :
103  positionMin(positionMin),
104  positionMax(positionMax),
105  positionGoal(positionGoal),
106  velocityMin(velocityMin),
107  velocityMax(velocityMax),
108  doneReward(doneReward)
109  { /* Nothing to do here */ }
110 
120  double Sample(const State& state,
121  const Action& action,
122  State& nextState) const
123  {
124  // Calculate acceleration.
125  int direction = action - 1;
126  nextState.Velocity() = state.Velocity() + 0.001 * direction - 0.0025 *
127  std::cos(3 * state.Position());
128  nextState.Velocity() = std::min(
129  std::max(nextState.Velocity(), velocityMin), velocityMax);
130 
131  // Update states.
132  nextState.Position() = state.Position() + nextState.Velocity();
133  nextState.Position() = std::min(
134  std::max(nextState.Position(), positionMin), positionMax);
135 
136  if (nextState.Position() == positionMin && nextState.Velocity() < 0)
137  nextState.Velocity() = 0.0;
138 
139  bool done = IsTerminal(nextState);
147  if (done)
148  return doneReward;
149  return -1.0;
150  }
151 
160  double Sample(const State& state, const Action& action) const
161  {
162  State nextState;
163  return Sample(state, action, nextState);
164  }
165 
173  {
174  State state;
175  state.Velocity() = 0.0;
176  state.Position() = arma::as_scalar(arma::randu(1)) * 0.2 - 0.6;
177  return state;
178  }
179 
186  bool IsTerminal(const State& state) const
187  {
188  return state.Position() >= positionGoal;
189  }
190 
191  private:
193  double positionMin;
194 
196  double positionMax;
197 
199  double positionGoal;
200 
202  double velocityMin;
203 
205  double velocityMax;
206 
208  double doneReward;
209 };
210 
211 } // namespace rl
212 } // namespace mlpack
213 
214 #endif
double Sample(const State &state, const Action &action, State &nextState) const
Dynamics of Mountain Car.
bool IsTerminal(const State &state) const
Whether given state is a terminal state.
prereqs.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
MountainCar(const double positionMin=-1.2, const double positionMax=0.6, const double positionGoal=0.5, const double velocityMin=-0.07, const double velocityMax=0.07, const double doneReward=0)
Construct a Mountain Car instance using the given constants.
const arma::colvec & Encode() const
Encode the state to a column vector.
double & Position()
Modify the position.
double & Velocity()
Modify the velocity.
State InitialSample() const
Initial position is randomly generated within [-0.6, -0.4].
State(const arma::colvec &data)
Construct a state based on the given data.
Track the size of the action space.
Action
Implementation of action of Mountain Car.
arma::colvec & Data()
Modify the internal representation of the state.
Implementation of state of Mountain Car.
State()
Construct a state instance.
Implementation of Mountain Car task.
double Sample(const State &state, const Action &action) const
Dynamics of Mountain Car.
double Position() const
Get the position.
static constexpr size_t dimension
Dimension of the encoded state.
double Velocity() const
Get the velocity.