Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make autoreset wrapper return 2 on reset #123

Merged
merged 5 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions jumanji/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def render(self, state: State) -> Any:
class AutoResetWrapper(Wrapper):
"""Automatically resets environments that are done. Once the terminal state is reached,
the state, observation, and step_type are reset. The observation and step_type of the
terminal TimeStep is reset to the reset observation and StepType.FIRST, respectively.
terminal TimeStep is reset to the reset observation and StepType.LAST, respectively.
The reward, discount, and extras retrieved from the transition to the terminal state.
WARNING: do not `jax.vmap` the wrapped environment (e.g. do not use with the `VmapWrapper`),
which would lead to inefficient computation due to both the `step` and `reset` functions
Expand All @@ -380,8 +380,7 @@ def _auto_reset(

# Replace observation with reset observation.
timestep = timestep.replace( # type: ignore
observation=reset_timestep.observation,
step_type=reset_timestep.step_type,
observation=reset_timestep.observation
)

return state, timestep
Expand Down Expand Up @@ -481,8 +480,7 @@ def _auto_reset(

# Replace observation with reset observation.
timestep = timestep.replace( # type: ignore
observation=reset_timestep.observation,
step_type=reset_timestep.step_type,
observation=reset_timestep.observation
)

return state, timestep
Expand Down
5 changes: 2 additions & 3 deletions jumanji/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def test_auto_reset_wrapper__step_reset(
action = fake_auto_reset_environment.action_spec().generate_value()
state, timestep = jax.jit(fake_auto_reset_environment.step)(state, action)

assert timestep.step_type == first_timestep.step_type == StepType.FIRST
assert timestep.step_type == StepType.LAST
chex.assert_trees_all_equal(timestep.observation, first_timestep.observation)


Expand Down Expand Up @@ -676,8 +676,7 @@ def test_vmap_auto_reset_wrapper__step_reset(
state, action
)

assert jnp.all(timestep.step_type == first_timestep.step_type)
assert jnp.all(timestep.step_type == StepType.FIRST)
assert jnp.all(timestep.step_type == StepType.LAST)
chex.assert_trees_all_equal(timestep.observation, first_timestep.observation)

def test_vmap_auto_reset_wrapper__step(
Expand Down