[BugFix] ParallelEnv handling of done flag #788

Conversation
I'm curious: can't we use a default?

```python
done = _td.get("done", None)
if done is None:
    done = torch.zeros(_td.shape, dtype=torch.bool, device=device)
```

since what we want is to call

```python
if _td.get("done", torch.zeros([], dtype=torch.bool)).any():
```

etc.
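As a self-contained illustration of the defaulting pattern suggested here (a rough sketch, not the PR code, assuming the standalone `tensordict` package and omitting the `device` argument):

```python
import torch
from tensordict import TensorDict

# A reset output that, like some wrapped envs, does not contain a "done" entry.
_td = TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4])

# get() with a default avoids the KeyError that a bare get("done") would raise.
done = _td.get("done", None)
if done is None:
    done = torch.zeros(*_td.batch_size, 1, dtype=torch.bool)

if done.any():
    print("some sub-envs report done right after reset")
```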
Some tests are failing, I suspect that it's because of the `_reset` to `reset` change.
Can we try with the `False` default?
Yes, your options definitely work! I thought that defaulting was all that `reset()` does, plus simple checks, but maybe I was wrong. Will fix.
@vmoens

```python
while self.shared_tensordict_parent.get("done").any():
    if check_count == 4:
        raise RuntimeError("Envs have just been reset but some are still done")
```

Thus I now tried doing it the way it is done there. In any case, I plan to propose some bigger modifications to this type of checking of the done flag.
Codecov Report

```
@@           Coverage Diff            @@
##             main     #788    +/-   ##
========================================
+ Coverage   88.74%   88.76%   +0.01%
========================================
  Files         123      123
  Lines       21170    21168       -2
========================================
+ Hits        18787    18789       +2
+ Misses       2383     2379       -4
```
torchrl/envs/vec_env.py (outdated)

```
@@ -1004,6 +1002,10 @@ def _run_worker_pipe_shared_mem(
        raise RuntimeError("call 'init' before resetting")
    # _td = tensordict.select("observation").to(env.device).clone()
    _td = env._reset(**reset_kwargs)
    _td.set_default(
        "done",
        torch.zeros(*_td.batch_size, 1, dtype=torch.bool, device=_td.device),
```
One issue with `set_default` is that we're creating the tensor at every single call, thereby allocating memory that is then collected if unused.
Depending on the kind of env you're running, this may or may not cause some overhead. In practice, a boolean tensor is quite cheap, so I might be overthinking it.

The solutions might be:

1. leave it as it is;
2. use `torch.zeros(1, ...).expand(*_td.batch_size, 1)`, which uses less memory but requires more operations;
3. (related to 1) implement a `setdefault` with a callable, though this has been discussed by the Python community and they decided not to implement it for some valid reasons. Since we'd like to stay as close as we can to `dict`, I would not vote for it (besides, we have a `set_default` method when it should be `setdefault` 😱);
4. do a `try / except KeyError`, but the issue here is that `try / except` causes an overhead of its own when the `except` branch is reached (which should happen less often than not, thereby reversing a memory overhead that occurs, say, 90% of the time into a compute-time overhead that occurs 10% of the time);
5. check if the key is there, and if not, populate it. This is basically what `set_default` does, so I would not expect this to be slower, but it is obviously less 'atomic'.
Also, tensordicts do not always have a device (i.e. the device can be `None`). This is to enable tensordicts that collect all the `state_dict`s of a module distributed across several GPUs, for instance.
Not sure it is relevant here, as the tensordict will have the device of the env, and the env must have a device, so it's unlikely that it will cause any bug any time soon.
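To make the allocation trade-off concrete, here is a rough sketch (not from the PR, assuming the standalone `tensordict` package) of options 2 and 5 from the list above:

```python
import torch
from tensordict import TensorDict

_td = TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4])

# Option 2: expand a one-element tensor; the expanded view adds no
# (batch_size, 1) allocation, at the cost of a couple of extra ops.
default_done = torch.zeros(1, dtype=torch.bool).expand(*_td.batch_size, 1)
done = _td.get("done", default_done)

# Option 5: check for the key and only materialize the default when missing.
if "done" not in _td.keys():
    _td.set("done", torch.zeros(*_td.batch_size, 1, dtype=torch.bool))
```

Either way the default is all-False, so a subsequent `.any()` check behaves the same as with `set_default`.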
I leave the choice to you. Whatever we choose, we should also update the normal `EnvBase.reset()` accordingly, as this is what is done there: https://github.com/pytorch/rl/blob/main/torchrl/envs/common.py#L454

My personal opinion is that all these checks on the done flag should be removed.
Description

This PR addresses issue #776.

It does two things:

- `ParallelEnv` workers will no longer try to test the "done" flag coming out of a wrapped environment. This only worked when the done flag had no batch_size; now the workers report the step result as is (the done message was not even used anyway).
- Workers call `env.reset()` instead of `env._reset()`. This is because some environments do not return the done key in the tensordict returned by `_reset()`. Since the worker later calls `if _td.get("done").any():`, this call fails if the environment does not return the done key from `_reset()`. `reset()`, on the other hand, wraps `_reset()` and automatically adds the default done value, solving the issue.
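To illustrate the second point, here is a minimal, hedged sketch (not the actual torchrl code; `ToyEnv` and its internals are purely illustrative) of why the worker's `_td.get("done").any()` check fails on raw `_reset()` output but succeeds once `reset()` has filled in the default:

```python
import torch
from tensordict import TensorDict


class ToyEnv:
    """Illustrative stand-in for a wrapped env whose _reset() omits "done"."""

    batch_size = torch.Size([4])

    def _reset(self, **kwargs):
        # Some envs only return observations here, with no "done" entry.
        return TensorDict(
            {"observation": torch.randn(4, 3)}, batch_size=self.batch_size
        )

    def reset(self, **kwargs):
        td = self._reset(**kwargs)
        # The public reset() wrapper fills in the default done flag.
        if "done" not in td.keys():
            td.set("done", torch.zeros(*self.batch_size, 1, dtype=torch.bool))
        return td


env = ToyEnv()

# Worker-style check on _reset() output: raises, because "done" is missing.
try:
    env._reset().get("done").any()
except KeyError:
    print("_reset() output has no 'done' entry")

# The same check on reset() output works, since the default was added.
assert not env.reset().get("done").any()
```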