diff --git a/test/test_transforms.py b/test/test_transforms.py index 36408bf4964..c9d2fb8c031 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -10583,6 +10583,7 @@ def test_multistep_transform(self): outs_2 = [] td = env.reset().contiguous() + assert "reward" not in td for _ in range(1): rollout = env.rollout( 250, auto_reset=False, tensordict=td, break_when_any_done=False @@ -10626,7 +10627,7 @@ def test_multistep_transform(self): ).contiguous() assert "reward" not in rollout.keys() out = t._inv_call(rollout) - td = rollout[..., -1]["next"] + td = rollout[..., -1]["next"].exclude("reward") if out is not None: outs_3.append(out)