Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CartPole] Add sutton_barto_reward argument #958

Merged
merged 15 commits into from
Mar 12, 2024
36 changes: 28 additions & 8 deletions gymnasium/envs/classic_control/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,14 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
if the pole angle is not in the range `(-.2095, .2095)` (or **±12°**)

## Rewards
Since the goal is to keep the pole upright for as long as possible, by default, a reward of `+1` is given for every step taken, including the termination step. The default reward threshold is 500 for v1 and 200 for v0 due to the time limit on the environment.

Since the goal is to keep the pole upright for as long as possible, a reward of `+1` for every step taken,
including the termination step, is allotted. The threshold for rewards is 500 for v1 and 200 for v0.
If `sutton_barto_reward=True`, then a reward of `0` is awarded for every non-terminating step and `-1` for the terminating step. As a result, the reward threshold is 0 for v0 and v1.

## Starting State

All observations are assigned a uniformly random value in `(-0.05, 0.05)`

## Episode End

The episode ends if any one of the following occurs:

1. Termination: Pole Angle is greater than ±12°
Expand All @@ -87,6 +85,10 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):

Kallinteris-Andreas marked this conversation as resolved.
Show resolved Hide resolved
```

| Parameter | Type | Default | Description |
|-------------------------|------------|-------------------------|-----------------------------------------------------------------------------------------------|
| `sutton_barto_reward` | **bool** | `False` | If `True` the reward function matches the original sutton barto implementation |

## Vectorized environment

To increase steps per seconds, users can use a custom vector environment or with an environment vectorizor.
Expand All @@ -101,14 +103,23 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
SyncVectorEnv(CartPole-v1, num_envs=3)

```

## Version History
* v1: `max_time_steps` raised to 500.
- In Gymnasium `1.0.0a2` the `sutton_barto_reward` argument was added (related [GitHub issue](https://github.com/Farama-Foundation/Gymnasium/issues/790))
* v0: Initial versions release.
"""

metadata = {
"render_modes": ["human", "rgb_array"],
"render_fps": 50,
}

def __init__(self, render_mode: Optional[str] = None):
def __init__(
self, sutton_barto_reward: bool = False, render_mode: Optional[str] = None
):
self._sutton_barto_reward = sutton_barto_reward

self.gravity = 9.8
self.masscart = 1.0
self.masspole = 0.1
Expand Down Expand Up @@ -190,11 +201,17 @@ def step(self, action):
)

if not terminated:
reward = 1.0
if self._sutton_barto_reward:
reward = 0.0
elif not self._sutton_barto_reward:
reward = 1.0
elif self.steps_beyond_terminated is None:
# Pole just fell!
self.steps_beyond_terminated = 0
reward = 1.0
if self._sutton_barto_reward:
reward = -1.0
elif not self._sutton_barto_reward:
reward = 1.0
else:
if self.steps_beyond_terminated == 0:
logger.warn(
Expand All @@ -204,7 +221,10 @@ def step(self, action):
"True' -- any further steps are undefined behavior."
)
self.steps_beyond_terminated += 1
reward = 0.0
if self._sutton_barto_reward:
reward = -1.0
elif not self._sutton_barto_reward:
reward = 0.0

if self.render_mode == "human":
self.render()
Expand Down
Loading