
[BugFix] ParallelEnv handling of done flag #788

Merged — 7 commits merged into pytorch:main on Jan 4, 2023
Conversation

@matteobettini (Contributor) commented Jan 3, 2023

Description

This PR addresses issue #776.

It does two things:

  1. The ParallelEnv workers no longer test the "done" flag coming out of a wrapped environment. That test only worked when the done flag had no batch_size; the workers now report the step result as-is (the done message was not used anywhere anyway).
  2. Very important issue. The workers now call env.reset() instead of env._reset(). Some environments do not return the done key in the tensordict returned by _reset(), so the worker's later check if _td.get("done").any(): fails for them. reset(), on the other hand, wraps _reset() and automatically adds the default done value, solving the issue.

@facebook-github-bot added the "CLA Signed" label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Jan 3, 2023
@vmoens (Contributor) commented Jan 3, 2023

> 2. Very important issue. The workers now call env.reset() instead of env._reset(). Some environments do not return the done key in the tensordict returned by _reset(), so the worker's later check if _td.get("done").any(): fails for them. reset(), on the other hand, wraps _reset() and automatically adds the default done value, solving the issue.

I'm curious: can't we use a default?
As of now the default can't be a callable, so we could do:

done = _td.get("done", None)
if done is None:
    done = torch.zeros(_td.shape, dtype=torch.bool, device=device)

Since all we want is to call any(), we can do something even simpler:

if _td.get("done", torch.zeros([], dtype=torch.bool)).any():
    etc.
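The scalar-default trick works because .any() on an all-false (or empty) value is falsy regardless of shape. The pattern can be sketched with a plain dict standing in for the tensordict (an empty list replaces torch.zeros([], dtype=torch.bool) so the sketch is self-contained; should_reset is an illustrative name, not the torchrl API):

```python
# dict.get with a falsy default: the reset branch is skipped when "done"
# is absent, and taken only when a genuinely truthy done flag is present.
def should_reset(td):
    # any([]) is False, mirroring .any() on a zero boolean tensor.
    return any(td.get("done", []))

assert should_reset({"done": [False, True]}) is True   # one sub-env done
assert should_reset({"done": [False, False]}) is False
assert should_reset({}) is False                       # key missing: default kicks in
```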

@vmoens left a comment:

Some tests are failing, I suspect that it's because of the _reset to reset change.
Can we try with the False default?

@matteobettini (Contributor, Author) commented Jan 3, 2023

> Some tests are failing, I suspect that it's because of the _reset to reset change.
> Can we try with the False default?

Yes, your options definitely work! I thought that defaulting was all that reset() does besides some simple checks, but maybe I was wrong. Will fix.

@matteobettini (Contributor, Author) commented Jan 3, 2023

@vmoens
Apparently just setting the default in the if is not enough, as done is checked again later in

    while self.shared_tensordict_parent.get("done").any():
        if check_count == 4:
            raise RuntimeError("Envs have just been reset but some are still done")

Thus I have now set the default the same way reset() does.

In any case, I plan to propose some bigger modifications to this kind of any()-based checking of the done flag, because it doesn't make sense for vectorized environments. So these parts might change in the near future.
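The objection can be illustrated with a stand-in sketch (plain lists instead of tensors; the names are illustrative, not the torchrl API): a global any() check conflates "some sub-envs are done" with "the whole batch needs resetting", whereas a per-env mask resets only the finished sub-envs.

```python
# A batched done flag: only sub-env 2 has terminated.
done = [False, False, True, False]

# any()-style check: a single True forces a decision for the whole batch.
needs_global_reset = any(done)
assert needs_global_reset is True  # even though 3 of 4 envs are mid-episode

# Mask-style alternative: identify and reset only the finished sub-envs.
envs_to_reset = [i for i, d in enumerate(done) if d]
assert envs_to_reset == [2]
```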

@codecov bot commented Jan 3, 2023

Codecov Report

Merging #788 (7602e09) into main (5b9ff55) will increase coverage by 0.01%.
The diff coverage is 0.00%.

❗ Current head 7602e09 differs from pull request most recent head d21705d. Consider uploading reports for the commit d21705d to get more accurate results

@@            Coverage Diff             @@
##             main     #788      +/-   ##
==========================================
+ Coverage   88.74%   88.76%   +0.01%     
==========================================
  Files         123      123              
  Lines       21170    21168       -2     
==========================================
+ Hits        18787    18789       +2     
+ Misses       2383     2379       -4     
Flag Coverage Δ
habitat-gpu 24.76% <0.00%> (+<0.01%) ⬆️
linux-brax 29.36% <0.00%> (+<0.01%) ⬆️
linux-cpu 85.23% <0.00%> (+<0.01%) ⬆️
linux-gpu 86.22% <0.00%> (+0.01%) ⬆️
linux-jumanji 30.13% <0.00%> (+<0.01%) ⬆️
linux-outdeps-gpu 72.26% <0.00%> (+<0.01%) ⬆️
linux-stable-cpu 85.09% <0.00%> (+<0.01%) ⬆️
linux-stable-gpu 85.87% <0.00%> (+0.01%) ⬆️
linux_examples-gpu 42.70% <0.00%> (+<0.01%) ⬆️
macos-cpu 84.99% <0.00%> (+<0.01%) ⬆️
olddeps-gpu 76.03% <0.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
torchrl/envs/vec_env.py 69.48% <0.00%> (+0.56%) ⬆️


@@ -1004,6 +1002,10 @@ def _run_worker_pipe_shared_mem(
        raise RuntimeError("call 'init' before resetting")
    # _td = tensordict.select("observation").to(env.device).clone()
    _td = env._reset(**reset_kwargs)
    _td.set_default(
        "done",
        torch.zeros(*_td.batch_size, 1, dtype=torch.bool, device=_td.device),
    )
A Contributor commented on this diff:
One issue with set_default is that we're creating the tensor at every single call, thereby allocating memory that is then collected if unused.
Depending on the kind of env you're running, this may or may not cause some overhead. In practice a boolean tensor is quite cheap, so I might be overthinking it.
The solutions might be:

  1. leave it as it is
  2. use torch.zeros(1, ...).expand(*_td.batch_size, 1), which uses less memory but requires more operations
  3. (related to 1) implement a setdefault that accepts a callable, though this has been discussed by the Python community and they decided not to implement it for some valid reasons. Since we'd like to stay as close as we can to dict, I would not vote for it (besides, we have a set_default method when it should be setdefault 😱)
  4. do a try / except KeyError, but try / except causes an overhead of its own when the except branch is reached (which should happen less often than not, thereby turning a memory overhead that occurs say 90% of the time into a compute-time overhead that occurs 10% of the time)
  5. check if the key is there, and if not populate it. This is basically what set_default does, so I would not expect this to be slower, but it is obviously less 'atomic'.

Also, tensordicts do not always have a device (i.e. device can be None). This is to enable tensordicts that collect the state_dict of a module distributed across several GPUs, for instance.
Not sure it is relevant here, as the tensordict will have the device of the env and the env must have a device, so it's unlikely to cause any bug any time soon.
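The trade-off between options 1 and 5 above comes down to eager evaluation of the default argument, which a stdlib sketch can demonstrate (a plain dict stands in for the tensordict, a list for the tensor; make_default is an illustrative helper):

```python
# setdefault evaluates its default argument before the call is made, so the
# stand-in "tensor" is constructed on every invocation even if the key exists.
allocations = []

def make_default():
    allocations.append(1)          # count each construction
    return [False]                 # stand-in for torch.zeros(..., dtype=bool)

td = {"done": [True]}              # key already present

td.setdefault("done", make_default())
assert len(allocations) == 1       # default was built, then thrown away

# Option 5: check-then-populate builds the default only when actually needed.
allocations.clear()
if "done" not in td:
    td["done"] = make_default()
assert len(allocations) == 0       # nothing allocated on the hot path
```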

@matteobettini (Contributor, Author) replied Jan 4, 2023:

I leave the choice to you. Whatever we choose, we should also update the normal EnvBase.reset(), since this is what is done there: https://github.com/pytorch/rl/blob/main/torchrl/envs/common.py#L454

My personal opinion is that all these checks on the done flag should be removed.

@vmoens merged commit 51ab9ab into pytorch:main on Jan 4, 2023
@vmoens added the "bug" label (Something isn't working) on Jan 4, 2023
@matteobettini deleted the done branch on January 4, 2023 at 14:20