NoDeath wrapper #374

Merged · 4 commits · Jul 5, 2023

6 changes: 6 additions & 0 deletions docs/api/wrappers.md
@@ -34,6 +34,12 @@ lastpage:
.. autoclass:: minigrid.wrappers.FullyObsWrapper
```

# No Death

```{eval-rst}
.. autoclass:: minigrid.wrappers.NoDeath
```

# Observation

```{eval-rst}
76 changes: 76 additions & 0 deletions minigrid/wrappers.py
@@ -788,3 +788,79 @@ def action(self, action):
            return self.np_random.integers(0, high=6)
        else:
            return self.random_action


class NoDeath(Wrapper):
    """
    Wrapper to prevent death in specific cells (e.g., lava cells).
    Instead of dying, the agent receives a negative reward.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import NoDeath
        >>>
        >>> env = gym.make("MiniGrid-LavaCrossingS9N1-v0")
        >>> _, _ = env.reset(seed=2)
        >>> _, _, _, _, _ = env.step(1)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (0, True)
        >>>
        >>> env = NoDeath(env, no_death_types=("lava",), death_cost=-1.0)
        >>> _, _ = env.reset(seed=2)
        >>> _, _, _, _, _ = env.step(1)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-1.0, False)
        >>>
        >>>
        >>> env = gym.make("MiniGrid-Dynamic-Obstacles-5x5-v0")
        >>> _, _ = env.reset(seed=2)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-1, True)
        >>>
        >>> # -2.0 below = the collision penalty (-1) plus death_cost (-1.0)
        >>> env = NoDeath(env, no_death_types=("ball",), death_cost=-1.0)
        >>> _, _ = env.reset(seed=2)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-2.0, False)
    """

    def __init__(self, env, no_death_types: tuple[str, ...], death_cost: float = -1.0):
        """A wrapper to prevent death in specific cells.

        Args:
            env: The environment to apply the wrapper
            no_death_types: Tuple of strings to identify death cells
            death_cost: The negative reward received in death cells

        """
        assert "goal" not in no_death_types, "goal cannot be a death cell"

        super().__init__(env)
        self.death_cost = death_cost
        self.no_death_types = no_death_types

    def step(self, action):
        # In Dynamic-Obstacles, obstacles move after the agent moves, so we
        # need to check for collision before self.env.step(): once the
        # obstacles have moved, the agent's cell may no longer contain the
        # obstacle that caused the termination.
        front_cell = self.grid.get(*self.front_pos)
        going_to_death = (
            action == self.actions.forward
            and front_cell is not None
            and front_cell.type in self.no_death_types
        )

        obs, reward, terminated, truncated, info = self.env.step(action)

        # We also check if the agent stays in death cells (e.g., lava)
        # without moving
        current_cell = self.grid.get(*self.agent_pos)
        in_death = current_cell is not None and current_cell.type in self.no_death_types

        if terminated and (going_to_death or in_death):
            terminated = False
            reward += self.death_cost

        return obs, reward, terminated, truncated, info
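
For context, a minimal usage sketch of the wrapper (illustrative only, not part of the diff; the environment id, penalty, and random-policy loop are arbitrary choices):

```python
import gymnasium as gym

from minigrid.wrappers import NoDeath

# Make lava non-fatal: stepping into it costs -0.5 instead of ending the episode.
env = NoDeath(
    gym.make("MiniGrid-LavaCrossingS9N1-v0"),
    no_death_types=("lava",),
    death_cost=-0.5,
)

obs, info = env.reset(seed=0)
for _ in range(100):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    # With the wrapper, only reaching the goal (or time-limit truncation)
    # ends the episode.
    if terminated or truncated:
        obs, info = env.reset()
env.close()
```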
33 changes: 33 additions & 0 deletions tests/test_wrappers.py
@@ -16,6 +16,7 @@
    FlatObsWrapper,
    FullyObsWrapper,
    ImgObsWrapper,
    NoDeath,
    OneHotPartialObsWrapper,
    PositionBonus,
    ReseedWrapper,
@@ -356,3 +357,35 @@ def test_dict_observation_space_doesnt_clash_with_one_hot():
    assert obs["image"].shape == (7, 7, 20)
    assert env.observation_space["image"].shape == (7, 7, 20)
    env.close()


def test_no_death_wrapper():
    death_cost = -1

    env = gym.make("MiniGrid-LavaCrossingS9N1-v0")
    _, _ = env.reset(seed=2)
    _, _, _, _, _ = env.step(1)
    _, reward, term, *_ = env.step(2)

    env_wrap = NoDeath(env, ("lava",), death_cost)
    _, _ = env_wrap.reset(seed=2)
    _, _, _, _, _ = env_wrap.step(1)
    _, reward_wrap, term_wrap, *_ = env_wrap.step(2)

    assert term and not term_wrap
    assert reward_wrap == reward + death_cost
    env.close()
    env_wrap.close()

    env = gym.make("MiniGrid-Dynamic-Obstacles-5x5-v0")
    _, _ = env.reset(seed=2)
    _, reward, term, *_ = env.step(2)

    env_wrap = NoDeath(env, ("ball",), death_cost)
    _, _ = env_wrap.reset(seed=2)
    _, reward_wrap, term_wrap, *_ = env_wrap.step(2)

    assert term and not term_wrap
    assert reward_wrap == reward + death_cost
    env.close()
    env_wrap.close()
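
The two scenarios above share the same structure; a possible deduplication (a hypothetical refactor, not part of this PR) using pytest.mark.parametrize:

```python
import gymnasium as gym
import pytest

from minigrid.wrappers import NoDeath


@pytest.mark.parametrize(
    "env_id, no_death_types, actions",
    [
        ("MiniGrid-LavaCrossingS9N1-v0", ("lava",), [1, 2]),
        ("MiniGrid-Dynamic-Obstacles-5x5-v0", ("ball",), [2]),
    ],
)
def test_no_death_wrapper_parametrized(env_id, no_death_types, actions):
    death_cost = -1

    # Unwrapped: the last action terminates the episode.
    env = gym.make(env_id)
    env.reset(seed=2)
    for action in actions:
        _, reward, term, *_ = env.step(action)

    # Wrapped: same trajectory, but the episode continues at a reward penalty.
    env_wrap = NoDeath(env, no_death_types, death_cost)
    env_wrap.reset(seed=2)
    for action in actions:
        _, reward_wrap, term_wrap, *_ = env_wrap.step(action)

    assert term and not term_wrap
    assert reward_wrap == reward + death_cost
    env.close()
    env_wrap.close()
```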