diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py index e70bd666c..72f38be0b 100644 --- a/jumanji/wrappers.py +++ b/jumanji/wrappers.py @@ -354,7 +354,7 @@ def render(self, state: State) -> Any: class AutoResetWrapper(Wrapper): """Automatically resets environments that are done. Once the terminal state is reached, the state, observation, and step_type are reset. The observation and step_type of the - terminal TimeStep is reset to the reset observation and StepType.FIRST, respectively. + terminal TimeStep is reset to the reset observation and StepType.LAST, respectively. The reward, discount, and extras retrieved from the transition to the terminal state. WARNING: do not `jax.vmap` the wrapped environment (e.g. do not use with the `VmapWrapper`), which would lead to inefficient computation due to both the `step` and `reset` functions @@ -380,8 +380,7 @@ def _auto_reset( # Replace observation with reset observation. timestep = timestep.replace( # type: ignore - observation=reset_timestep.observation, - step_type=reset_timestep.step_type, + observation=reset_timestep.observation ) return state, timestep @@ -481,8 +480,7 @@ def _auto_reset( # Replace observation with reset observation. timestep = timestep.replace( # type: ignore - observation=reset_timestep.observation, - step_type=reset_timestep.step_type, + observation=reset_timestep.observation ) return state, timestep diff --git a/jumanji/wrappers_test.py b/jumanji/wrappers_test.py index 27d77a87c..4dc339694 100644 --- a/jumanji/wrappers_test.py +++ b/jumanji/wrappers_test.py @@ -576,7 +576,7 @@ def test_auto_reset_wrapper__step_reset( action = fake_auto_reset_environment.action_spec().generate_value() state, timestep = jax.jit(fake_auto_reset_environment.step)(state, action) - assert timestep.step_type == first_timestep.step_type == StepType.FIRST + assert timestep.step_type == StepType.LAST chex.assert_trees_all_equal(timestep.observation, first_timestep.observation) @@ -676,8 +676,7 @@ def test_vmap_auto_reset_wrapper__step_reset( state, action ) - assert jnp.all(timestep.step_type == first_timestep.step_type) - assert jnp.all(timestep.step_type == StepType.FIRST) + assert jnp.all(timestep.step_type == StepType.LAST) chex.assert_trees_all_equal(timestep.observation, first_timestep.observation) def test_vmap_auto_reset_wrapper__step(